-
Notifications
You must be signed in to change notification settings - Fork 0
Virchow2 model #2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
a167a22
1d3310f
4e27a48
fd3154d
eaac807
46fe8b1
f07723e
9d6e265
e7612f9
5945f10
8ef4cc5
a6c427e
27c7801
e5d84cb
e1fcb6c
e7ac073
51f07a4
178f226
5cc123f
964114e
181f79e
156a7d4
57176d6
fe51ee2
9f37d03
210c7e6
bf7cff1
6813264
7510c9f
2eae503
c5095bd
e94baec
7cdd290
8cce2cb
7e329a8
8dfea82
e6f8603
fb646c4
b2d083c
bfd90a9
75ed923
bfcbe5b
1247e72
d1a3d97
3d26432
5ccf860
c62f2ef
cea82e7
5816940
ae96ca1
5cd8344
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -53,4 +53,7 @@ RUN sudo sh -c 'echo "/usr/local/lib" > /etc/ld.so.conf.d/custom-libs.conf' && \ | |
| sudo sh -c 'echo "/home/ray/anaconda3/lib/python3.12/site-packages/nvidia/cudnn/lib" > /etc/ld.so.conf.d/nvidia-libs.conf' && \ | ||
| sudo ldconfig | ||
|
|
||
| RUN pip install --no-cache-dir onnxruntime-gpu tensorrt lz4 ratiopath "mlflow<3.0" | ||
| RUN pip install --no-cache-dir \ | ||
| onnxruntime-gpu tensorrt-cu12 lz4 ratiopath "mlflow<3.0" torch torchvision \ | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you really need TensorRT for cuda 12? By default it uses cuda 13, same as torch. This way you are using two different cuda version |
||
| "timm>=1.0.0" \ | ||
| "huggingface-hub>=0.23.0" | ||
matejpekar marked this conversation as resolved.
Show resolved
Hide resolved
|
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,121 @@ | ||||||||||||||||||||||||
| import asyncio | ||||||||||||||||||||||||
| import logging | ||||||||||||||||||||||||
| from typing import Any, TypedDict, cast | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| import lz4.frame | ||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||
| from fastapi import FastAPI, Request, Response | ||||||||||||||||||||||||
| from numpy.typing import NDArray | ||||||||||||||||||||||||
| from PIL import Image | ||||||||||||||||||||||||
| from ray import serve | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class Config(TypedDict): | ||||||||||||||||||||||||
| tile_size: int | ||||||||||||||||||||||||
| model: dict[str, Any] | ||||||||||||||||||||||||
| max_batch_size: int | ||||||||||||||||||||||||
| batch_wait_timeout_s: float | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| fastapi = FastAPI() | ||||||||||||||||||||||||
| logger = logging.getLogger("ray.serve") | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| @serve.deployment(num_replicas="auto") | ||||||||||||||||||||||||
| @serve.ingress(fastapi) | ||||||||||||||||||||||||
| class Virchow2: | ||||||||||||||||||||||||
| """Virchow2 foundation model for pathology.""" | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def __init__(self) -> None: | ||||||||||||||||||||||||
| self.model: torch.nn.Module | None = None | ||||||||||||||||||||||||
| self.transforms: Any = None | ||||||||||||||||||||||||
| self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||||||||||||||||||||||||
| self.tile_size: int = 0 | ||||||||||||||||||||||||
Jurgee marked this conversation as resolved.
Show resolved
Hide resolved
Comment on lines
+30
to
+34
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def reconfigure(self, config: Config) -> None: | ||||||||||||||||||||||||
| import importlib | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| import timm | ||||||||||||||||||||||||
| from timm.data.config import resolve_data_config | ||||||||||||||||||||||||
| from timm.data.transforms_factory import create_transform | ||||||||||||||||||||||||
| from timm.layers.mlp import SwiGLUPacked | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| self.tile_size = config["tile_size"] | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| model_config = dict(config["model"]) | ||||||||||||||||||||||||
| module_path, attr_name = model_config.pop("_target_").split(":") | ||||||||||||||||||||||||
| provider = getattr(importlib.import_module(module_path), attr_name) | ||||||||||||||||||||||||
| repo_id = model_config["repo_id"] | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| logger.info(f"Loading Virchow2 model from {repo_id}...") | ||||||||||||||||||||||||
| provider(**model_config) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| self.model = timm.create_model( | ||||||||||||||||||||||||
| f"hf-hub:{repo_id}", | ||||||||||||||||||||||||
| pretrained=True, | ||||||||||||||||||||||||
| num_classes=0, | ||||||||||||||||||||||||
| mlp_layer=SwiGLUPacked, | ||||||||||||||||||||||||
| act_layer=torch.nn.SiLU, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
Comment on lines
+52
to
+60
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems to be fragile. I guess the provider downloads the model to cache and then timm loads it from the cache based on the environmental variables set by the provider. Why do you even need to call the provider? I think timm can handle the downloading and storing to cache? |
||||||||||||||||||||||||
| self.model = self.model.to(self.device).eval() | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| self.transforms = create_transform( | ||||||||||||||||||||||||
| **resolve_data_config(self.model.pretrained_cfg, model=self.model) | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| self.predict.set_max_batch_size(config["max_batch_size"]) # type: ignore[attr-defined] | ||||||||||||||||||||||||
| self.predict.set_batch_wait_timeout_s(config["batch_wait_timeout_s"]) # type: ignore[attr-defined] | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| @serve.batch | ||||||||||||||||||||||||
| async def predict( | ||||||||||||||||||||||||
| self, inputs: list[torch.Tensor | NDArray[np.float16] | NDArray[np.float32]] | ||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Accept only one type. Also check if ray can serialize tensors |
||||||||||||||||||||||||
| ) -> list[NDArray[np.float32]]: | ||||||||||||||||||||||||
| tensors = torch.stack( | ||||||||||||||||||||||||
| [ | ||||||||||||||||||||||||
| item if isinstance(item, torch.Tensor) else torch.from_numpy(item) | ||||||||||||||||||||||||
| for item in inputs | ||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||
| ).to(self.device) | ||||||||||||||||||||||||
| model = cast("torch.nn.Module", self.model) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| device_type = self.device.type | ||||||||||||||||||||||||
| autocast_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16 | ||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are you switching to bfp16 for CPUs? Is fp16 not supported by the cluster CPUs? |
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| with ( | ||||||||||||||||||||||||
| torch.inference_mode(), | ||||||||||||||||||||||||
| torch.autocast(device_type=device_type, dtype=autocast_dtype), | ||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||
| output = model(tensors) | ||||||||||||||||||||||||
| class_token = output[:, 0] | ||||||||||||||||||||||||
| patch_tokens = output[:, 5:] | ||||||||||||||||||||||||
| embedding = torch.cat([class_token, patch_tokens.mean(dim=1)], dim=-1) | ||||||||||||||||||||||||
|
Comment on lines
+90
to
+92
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be optional. Some users might want to access the individual patch tokens |
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| return list(embedding.float().cpu().numpy()) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| @fastapi.post("/") | ||||||||||||||||||||||||
| async def root(self, request: Request) -> Response: | ||||||||||||||||||||||||
| data = await asyncio.to_thread(lz4.frame.decompress, await request.body()) | ||||||||||||||||||||||||
| image = np.frombuffer(data, dtype=np.uint8).reshape( | ||||||||||||||||||||||||
| self.tile_size, self.tile_size, 3 | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| requested_dtype = request.headers.get("x-output-dtype", "float32").lower() | ||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The output dtype should be applied in the predict method to avoid serialization of large arrays |
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| output_dtype = np.dtype(requested_dtype) | ||||||||||||||||||||||||
| tensor = self.transforms(Image.fromarray(image)) | ||||||||||||||||||||||||
Jurgee marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||
| result: NDArray[np.float32] = await self.predict(tensor) | ||||||||||||||||||||||||
| output_shape = ",".join(str(d) for d in result.shape) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| return Response( | ||||||||||||||||||||||||
| content=lz4.frame.compress( | ||||||||||||||||||||||||
| result.astype(output_dtype, copy=False).tobytes() | ||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||
| media_type="application/octet-stream", | ||||||||||||||||||||||||
| headers={ | ||||||||||||||||||||||||
| "x-output-shape": output_shape, | ||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| app = Virchow2.bind() # type: ignore[attr-defined] | ||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| apiVersion: v1 | ||
| kind: PersistentVolumeClaim | ||
| metadata: | ||
| name: huggingface-cache-pvc | ||
| namespace: rationai-jobs-ns | ||
Jurgee marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| spec: | ||
| accessModes: | ||
| - ReadWriteMany | ||
| resources: | ||
| requests: | ||
| storage: 15Gi | ||
| storageClassName: nfs-csi | ||
Uh oh!
There was an error while loading. Please reload this page.