Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
a167a22
feat: tensorrt support
matejpekar Jan 17, 2026
1d3310f
fix: remove flush
matejpekar Jan 17, 2026
4e27a48
feat: add docker files for cpu/gpu
Feb 8, 2026
fd3154d
feat: add PVC for TensorRT
Feb 8, 2026
eaac807
feat: add support of TensorRT for models
Feb 8, 2026
46fe8b1
feat: add TensorRT cache to workers
Feb 8, 2026
f07723e
add Jiri as coauthor
Feb 8, 2026
9d6e265
fix: remove gpu number from serve.deployment in code
Feb 8, 2026
e7612f9
fix: warning suppress
Feb 9, 2026
5945f10
feat: add jobs to download virchow2
Feb 10, 2026
8ef4cc5
feat: add model provider for hf
Feb 10, 2026
a6c427e
feat: add pvc for huggingface
Feb 10, 2026
27c7801
feat: add virchow2 model
Feb 10, 2026
e5d84cb
fix
Feb 10, 2026
e1fcb6c
fix: fine tune
Feb 10, 2026
e7ac073
feat: add into dockerfile
Feb 14, 2026
51f07a4
fix: remove installs from model
Feb 14, 2026
178f226
fix: based on official docs
Feb 14, 2026
5cc123f
fix
Feb 14, 2026
964114e
fix: remove comment
Feb 14, 2026
181f79e
chore: update docker gpu file
Jurgee Mar 13, 2026
156a7d4
feat: optimize virchow2 deployment
Jurgee Mar 13, 2026
57176d6
fix: remove hf token, create new secret
Jurgee Mar 14, 2026
fe51ee2
fix
Jurgee Mar 14, 2026
9f37d03
Merge branch 'main' into feature/virchow2-model
Jurgee Mar 14, 2026
210c7e6
fix: remove intra threads
Jurgee Mar 14, 2026
bf7cff1
fix: lint
Jurgee Mar 14, 2026
6813264
fix: remove duplicity
Jurgee Mar 14, 2026
7510c9f
fixes
Jurgee Mar 14, 2026
2eae503
docker files
Jurgee Mar 14, 2026
c5095bd
fix: docker
Jurgee Mar 14, 2026
e94baec
chore: new docker image
Jurgee Mar 14, 2026
7cdd290
chore: cpu docker
Jurgee Mar 14, 2026
8cce2cb
fix
Jurgee Mar 14, 2026
7e329a8
final changes
Jurgee Mar 14, 2026
8dfea82
fix: usage of master branch
Jurgee Mar 15, 2026
e6f8603
Potential fix for pull request finding
Jurgee Mar 15, 2026
fb646c4
Potential fix for pull request finding
Jurgee Mar 15, 2026
b2d083c
Potential fix for pull request finding
Jurgee Mar 15, 2026
bfd90a9
Potential fix for pull request finding
Jurgee Mar 15, 2026
75ed923
fix: comments
Jurgee Mar 18, 2026
bfcbe5b
fix
Jurgee Mar 18, 2026
1247e72
fix: comment remove
Jurgee Mar 18, 2026
d1a3d97
chore: new docker image
Jurgee Mar 18, 2026
3d26432
fix: simple pip command
Jurgee Mar 22, 2026
5ccf860
feat: newly optimized model
Jurgee Mar 23, 2026
c62f2ef
fix: remove hf token from gpu workers
Jurgee Mar 23, 2026
cea82e7
refactor: replace read_region_relative with read_tile in fetch_tissue…
Jurgee Mar 23, 2026
5816940
fix: Type error
Jurgee Mar 23, 2026
ae96ca1
fix: remove index url
Jurgee Mar 28, 2026
5cd8344
refactor: clean up Virchow2 deployment and simplify model loading
Jurgee Mar 28, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docker/Dockerfile.cpu
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,5 @@ RUN sudo apt-get update && sudo apt-get -y upgrade && \
# Cleanup
RUN sudo apt-get remove -y --purge systemd systemd-sysv && sudo apt-get autoremove --purge -y && sudo apt-get clean && sudo rm -rf /var/lib/apt/lists/*

RUN pip install --no-cache-dir onnxruntime lz4 ratiopath "mlflow<3.0"
RUN pip install --no-cache-dir \
onnxruntime lz4 ratiopath "mlflow<3.0"
5 changes: 4 additions & 1 deletion docker/Dockerfile.gpu
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,7 @@ RUN sudo sh -c 'echo "/usr/local/lib" > /etc/ld.so.conf.d/custom-libs.conf' && \
sudo sh -c 'echo "/home/ray/anaconda3/lib/python3.12/site-packages/nvidia/cudnn/lib" > /etc/ld.so.conf.d/nvidia-libs.conf' && \
sudo ldconfig

RUN pip install --no-cache-dir onnxruntime-gpu tensorrt lz4 ratiopath "mlflow<3.0"
RUN pip install --no-cache-dir \
onnxruntime-gpu tensorrt-cu12 lz4 ratiopath "mlflow<3.0" torch torchvision \
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you really need TensorRT for CUDA 12? By default it uses CUDA 13, the same as torch. This way you are using two different CUDA versions.

"timm>=1.0.0" \
"huggingface-hub>=0.23.0"
11 changes: 7 additions & 4 deletions misc/fetch_tissue_tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@ def fetch_tissue_tile(
)
if not np.asarray(tile.convert("L")).any():
return None

tile = slide.read_region_relative((x, y), level, (tile_size, tile_size)).convert(
"RGB"
tile = slide.read_tile(
x=x,
y=y,
extent_x=tile_size,
extent_y=tile_size,
level=level,
)
return np.asarray(tile).transpose(2, 0, 1)
return tile.transpose(2, 0, 1)
9 changes: 5 additions & 4 deletions models/binary_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ class Config(TypedDict):
model: dict[str, Any]
max_batch_size: int
batch_wait_timeout_s: float
intra_op_num_threads: int
trt_cache_path: str


Expand Down Expand Up @@ -69,7 +68,7 @@ def reconfigure(self, config: Config) -> None:
"trt_engine_cache_path": cache_path,
"trt_max_workspace_size": config.get(
"trt_max_workspace_size", 8 * 1024 * 1024 * 1024
), # type: ignore[typeddict-item]
),
"trt_builder_optimization_level": 5,
"trt_timing_cache_enable": True,
"trt_profile_min_shapes": min_shape,
Expand All @@ -79,7 +78,6 @@ def reconfigure(self, config: Config) -> None:

# Configure ONNX Runtime session
sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = config["intra_op_num_threads"]
sess_options.inter_op_num_threads = 1

# Enable all graph optimizations (constant folding, node fusion, etc.) for maximum inference performance.
Expand Down Expand Up @@ -118,7 +116,10 @@ async def predict(self, images: list[NDArray[np.uint8]]) -> list[float]:
"""Run inference on a batch of images."""
batch = np.stack(images, axis=0, dtype=np.uint8)

outputs = self.session.run([self.output_name], {self.input_name: batch})
outputs = self.session.run(
[self.output_name],
{self.input_name: batch},
)

return outputs[0].flatten().tolist() # pyright: ignore[reportAttributeAccessIssue]

Expand Down
4 changes: 1 addition & 3 deletions models/semantic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ class Config(TypedDict):
model: dict[str, Any]
max_batch_size: int
batch_wait_timeout_s: float
intra_op_num_threads: int
trt_cache_path: str


Expand Down Expand Up @@ -65,7 +64,7 @@ def reconfigure(self, config: Config) -> None:
"trt_engine_cache_path": cache_path,
"trt_max_workspace_size": config.get(
"trt_max_workspace_size", 8 * 1024 * 1024 * 1024
), # type: ignore[typeddict-item]
),
"trt_builder_optimization_level": 5,
"trt_timing_cache_enable": True,
"trt_profile_min_shapes": min_shape,
Expand All @@ -75,7 +74,6 @@ def reconfigure(self, config: Config) -> None:

# Configure ONNX Runtime session
sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = config["intra_op_num_threads"]
sess_options.inter_op_num_threads = 1

# Enable all graph optimizations (constant folding, node fusion, etc.) for maximum inference performance.
Expand Down
121 changes: 121 additions & 0 deletions models/virchow2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import asyncio
import logging
from typing import Any, TypedDict, cast

import lz4.frame
import numpy as np
import torch
from fastapi import FastAPI, Request, Response
from numpy.typing import NDArray
from PIL import Image
from ray import serve


class Config(TypedDict):
    """Ray Serve ``user_config`` schema for the Virchow2 deployment."""

    # Edge length (px) of the square RGB tile the HTTP endpoint expects
    # (used to reshape the decompressed request body).
    tile_size: int
    # Model-provider spec: a "_target_" key of the form "module.path:attr"
    # plus keyword arguments for that provider (see reconfigure()).
    model: dict[str, Any]
    # Upper bound on how many requests @serve.batch groups per forward pass.
    max_batch_size: int
    # Seconds the batcher waits for more requests before running a batch.
    batch_wait_timeout_s: float


fastapi = FastAPI()
logger = logging.getLogger("ray.serve")


@serve.deployment(num_replicas="auto")
@serve.ingress(fastapi)
class Virchow2:
"""Virchow2 foundation model for pathology."""

def __init__(self) -> None:
    """Initialize empty state; the actual model is loaded in ``reconfigure``."""
    # Backbone network; populated by reconfigure() from user_config.
    self.model: torch.nn.Module | None = None
    # timm preprocessing pipeline matching the checkpoint (set in reconfigure).
    self.transforms: Any = None
    # Prefer GPU when available; predict() picks its autocast dtype from this.
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Tile edge length expected by the HTTP endpoint (set in reconfigure).
    self.tile_size: int = 0
Comment on lines +30 to +34
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def __init__(self) -> None:
self.model: torch.nn.Module | None = None
self.transforms: Any = None
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.tile_size: int = 0
model: torch.nn.Module
transforms: Any
tile_size: int
def __init__(self) -> None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def reconfigure(self, config: Config) -> None:
    """Load the Virchow2 backbone and its preprocessing from ``user_config``.

    Ray Serve calls this on deployment start and on every user_config
    update. Heavy imports are kept local so the module imports cheaply on
    processes that never instantiate this deployment.

    Args:
        config: Deployment ``user_config`` (see :class:`Config`).
    """
    import importlib

    import timm
    from timm.data.config import resolve_data_config
    from timm.data.transforms_factory import create_transform
    from timm.layers.mlp import SwiGLUPacked

    self.tile_size = config["tile_size"]

    # "_target_" names a provider callable as "module.path:attr"; the
    # remaining keys are forwarded to it as keyword arguments.
    model_config = dict(config["model"])
    module_path, attr_name = model_config.pop("_target_").split(":")
    provider = getattr(importlib.import_module(module_path), attr_name)
    repo_id = model_config["repo_id"]

    # Lazy %-style args avoid string formatting when INFO is disabled.
    logger.info("Loading Virchow2 model from %s...", repo_id)
    # NOTE(review): the provider call appears to pre-download the repo into
    # the shared HF cache that timm then reads from — confirm timm cannot
    # simply handle the download itself.
    provider(**model_config)

    # Virchow2 needs SwiGLU MLP layers and SiLU activations; num_classes=0
    # strips the classification head so the model emits raw tokens.
    self.model = timm.create_model(
        f"hf-hub:{repo_id}",
        pretrained=True,
        num_classes=0,
        mlp_layer=SwiGLUPacked,
        act_layer=torch.nn.SiLU,
    )
    self.model = self.model.to(self.device).eval()

    # Build the eval transform (resize/normalize) the checkpoint was
    # published with.
    self.transforms = create_transform(
        **resolve_data_config(self.model.pretrained_cfg, model=self.model)
    )

    # Batching knobs are runtime-tunable via user_config updates.
    self.predict.set_max_batch_size(config["max_batch_size"])  # type: ignore[attr-defined]
    self.predict.set_batch_wait_timeout_s(config["batch_wait_timeout_s"])  # type: ignore[attr-defined]

@serve.batch
async def predict(
    self, inputs: list[torch.Tensor | NDArray[np.float16] | NDArray[np.float32]]
) -> list[NDArray[np.float32]]:
    """Embed a batch of preprocessed tiles.

    Each input is a tensor/array already normalized by ``self.transforms``
    (assumed — confirm with callers). Returns one float32 embedding per
    input: the class token concatenated with the mean of the patch tokens.
    """
    # Fail fast with a clear message instead of a confusing TypeError from
    # torch.stack/model(None) if reconfigure() has not run yet.
    model = self.model
    if model is None:
        raise RuntimeError("Virchow2 model is not loaded; reconfigure() has not run yet.")

    tensors = torch.stack(
        [
            item if isinstance(item, torch.Tensor) else torch.from_numpy(item)
            for item in inputs
        ]
    ).to(self.device)

    device_type = self.device.type
    # fp16 autocast on CUDA; bfloat16 on CPU (fp16 autocast is generally
    # unsupported on CPU — NOTE(review): confirm against cluster hardware).
    autocast_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16

    with (
        torch.inference_mode(),
        torch.autocast(device_type=device_type, dtype=autocast_dtype),
    ):
        output = model(tensors)
        # Token 0 is the class token; indices 1-4 are skipped (presumably
        # register tokens — confirm against the Virchow2 model card) and
        # tokens 5+ are patch tokens.
        class_token = output[:, 0]
        patch_tokens = output[:, 5:]
        embedding = torch.cat([class_token, patch_tokens.mean(dim=1)], dim=-1)

    return list(embedding.float().cpu().numpy())

@fastapi.post("/")
async def root(self, request: Request) -> Response:
data = await asyncio.to_thread(lz4.frame.decompress, await request.body())
image = np.frombuffer(data, dtype=np.uint8).reshape(
self.tile_size, self.tile_size, 3
)

requested_dtype = request.headers.get("x-output-dtype", "float32").lower()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The output dtype should be applied in the predict method to avoid serialization of large arrays


output_dtype = np.dtype(requested_dtype)
tensor = self.transforms(Image.fromarray(image))
result: NDArray[np.float32] = await self.predict(tensor)
output_shape = ",".join(str(d) for d in result.shape)

return Response(
content=lz4.frame.compress(
result.astype(output_dtype, copy=False).tobytes()
),
media_type="application/octet-stream",
headers={
"x-output-shape": output_shape,
},
)


app = Virchow2.bind() # type: ignore[attr-defined]
25 changes: 25 additions & 0 deletions providers/model_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,28 @@ def mlflow(artifact_uri: str) -> str:
import mlflow.artifacts

return mlflow.artifacts.download_artifacts(artifact_uri=artifact_uri)


def huggingface(repo_id: str, filename: str | None = None) -> str:
    """Resolve a Hugging Face repo (or one file from it) to a local path.

    Args:
        repo_id: Hub repository id, e.g. "paige-ai/Virchow2".
        filename: Optional single file to fetch; when None the whole
            snapshot is resolved.

    Returns:
        Absolute path to the downloaded file or snapshot directory.
    """
    import os

    from huggingface_hub import hf_hub_download, snapshot_download

    # If HF_TOKEN is not set, we assume we're running in an offline environment and only allow loading from the local cache.
    offline = os.environ.get("HF_TOKEN") is None

    # Pin HF_HOME so every worker resolves against the same mounted cache volume.
    hf_home = os.environ.get("HF_HOME", "/mnt/huggingface_cache")
    os.makedirs(hf_home, exist_ok=True)
    os.environ["HF_HOME"] = hf_home

    if filename:
        return hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_files_only=offline,
        )
    # No `else` needed after the return above.
    return snapshot_download(
        repo_id=repo_id,
        local_files_only=offline,
    )
12 changes: 12 additions & 0 deletions pvc/huggingface-pvc.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: huggingface-cache-pvc
namespace: rationai-jobs-ns
spec:
accessModes:
- ReadWriteMany
resources:
requests:
storage: 15Gi
storageClassName: nfs-csi
55 changes: 49 additions & 6 deletions ray-service.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ spec:
tile_size: 512
max_batch_size: 16
batch_wait_timeout_s: 0.01
intra_op_num_threads: 4
trt_max_workspace_size: 8589934592 # 8 GiB
trt_cache_path: /mnt/cache/trt_cache
model:
Expand Down Expand Up @@ -67,7 +66,6 @@ spec:
_target_: providers.model_provider:mlflow
artifact_uri: mlflow-artifacts:/10/39f821ed5b964c71a603cc6db196f9fd/artifacts/checkpoints/epoch=19-step=32020/model.onnx/model.onnx
- name: heatmap-builder
import_path: builders.heatmap_builder:app
route_prefix: /heatmap-builder
Expand All @@ -88,6 +86,41 @@ spec:
num_threads: 8
max_concurrent_tasks: 24
- name: virchow2
import_path: models.virchow2:app
route_prefix: /virchow2
runtime_env:
config:
setup_timeout_seconds: 1800
working_dir: https://github.com/RationAI/model-service/archive/refs/heads/main.zip
deployments:
- name: Virchow2
max_ongoing_requests: 200
max_queued_requests: 2048
autoscaling_config:
min_replicas: 0
max_replicas: 1
target_ongoing_requests: 128
ray_actor_options:
num_cpus: 8
num_gpus: 1
memory: 8589934592 # 8 GiB
runtime_env:
env_vars:
HF_HOME: "/mnt/huggingface_cache"
HF_TOKEN:
valueFrom:
secretKeyRef:
name: huggingface-secret
key: token
user_config:
tile_size: 224
max_batch_size: 256
batch_wait_timeout_s: 0.1
model:
_target_: providers.model_provider:huggingface
repo_id: paige-ai/Virchow2
rayClusterConfig:
rayVersion: 2.53.0
enableInTreeAutoscaling: true
Expand Down Expand Up @@ -185,6 +218,8 @@ spec:
mountPath: /mnt/bioptic_tree
- name: trt-cache-volume
mountPath: /mnt/cache
- name: huggingface-cache
mountPath: /mnt/huggingface_cache

volumes:
- name: data
Expand All @@ -202,11 +237,16 @@ spec:
- name: trt-cache-volume
persistentVolumeClaim:
claimName: tensorrt-cache-pvc
- name: huggingface-cache
persistentVolumeClaim:
claimName: huggingface-cache-pvc

- groupName: gpu-workers
replicas: 0
minReplicas: 0
maxReplicas: 2
rayStartParams:
num-gpus: "1"
template:
spec:
securityContext:
Expand All @@ -216,17 +256,15 @@ spec:
runAsUser: 1000
seccompProfile:
type: RuntimeDefault
nodeSelector:
nvidia.com/gpu.product: NVIDIA-A40
containers:
- name: ray-worker
image: cerit.io/rationai/model-service:2.53.0-gpu
image: cerit.io/rationai/model-service:2.54.0-gpu
imagePullPolicy: Always
resources:
limits:
cpu: 8
memory: 24Gi
nvidia.com/gpu: 1
nvidia.com/mig-2g.20gb: 1
requests:
cpu: 8
memory: 24Gi
Expand All @@ -253,6 +291,8 @@ spec:
mountPath: /mnt/bioptic_tree
- name: trt-cache-volume
mountPath: /mnt/cache
- name: huggingface-cache
mountPath: /mnt/huggingface_cache

volumes:
- name: data
Expand All @@ -270,3 +310,6 @@ spec:
- name: trt-cache-volume
persistentVolumeClaim:
claimName: tensorrt-cache-pvc
- name: huggingface-cache
persistentVolumeClaim:
claimName: huggingface-cache-pvc