Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the repository settings.
📝 Walkthrough

Adds a Virchow2 Ray Serve deployment with FastAPI ingress, HuggingFace caching/support (PVC, provider helper, downloader job), and Dockerfile and Ray service updates for HF/GPU integration; also removes …

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant RayServe as Ray Serve (Virchow2)
    participant Provider as providers.huggingface
    participant PVC as HuggingFace PVC
    participant HFHub as HuggingFace Hub
    participant Downloader as Downloader Job
    Client->>RayServe: POST LZ4-compressed image(s)
    RayServe->>Provider: request local path for repo_id (HF_HOME)
    Provider-->>RayServe: return local cache path (exists?)
    alt cache miss
        RayServe->>Downloader: schedule/run downloader job
        Downloader->>HFHub: snapshot_download / hf_hub_download (uses HF_TOKEN)
        HFHub-->>PVC: write model files to cache
        Downloader-->>RayServe: notify completion
    end
    RayServe->>RayServe: preprocess, batch, run timm model (GPU)
    RayServe-->>Client: return embeddings (LZ4-compressed)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ❌ 2 failed (1 warning, 1 inconclusive) | ✅ 1 passed
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the model serving infrastructure by integrating the Virchow2 foundation model and optimizing existing models for GPU performance. It establishes a robust framework for deploying advanced deep learning models, leveraging TensorRT for efficient inference and Hugging Face for model management. The changes also include necessary infrastructure updates for dependency management and persistent caching, ensuring a more efficient and scalable system.
Code Review
This pull request introduces support for the Virchow2 foundation model within the Ray Serve infrastructure, which is a significant addition. The changes include a new Ray Serve deployment for Virchow2, updates to Dockerfiles for GPU support with necessary dependencies, and Kubernetes configurations for model downloading and caching. The refactoring of existing models to leverage GPU and TensorRT is a great performance enhancement. However, the pull request introduces critical security vulnerabilities by including hardcoded secrets in configuration files. These must be addressed by using a secure secret management solution like Kubernetes Secrets. Additionally, there are minor areas for improvement regarding Docker image consistency and file permissions.
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Actionable comments posted: 3
🧹 Nitpick comments (1)
models/virchow2.py (1)
50-55: Clarify provider contract: currently its return value is discarded.

At line 51, `provider(**...)` is called for side effects, but `timm.create_model` at line 53 always uses `repo_id`. This makes the `_target_` abstraction misleading (compare with `models/semantic_segmentation.py:55-69`, where provider output is consumed directly). Consider either:
- explicitly documenting provider as cache warm-up only, or
- consuming provider output as the model source of truth.
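The second option could look like the sketch below. It is illustrative only: the repository's actual provider signature is not shown in this review, so the assumption here is that the provider returns the local snapshot path it populated (or `None` on a cache miss), and the call site prefers that path over the `hf-hub:` identifier.

```python
from typing import Callable, Optional


def resolve_model_source(
    provider: Callable[..., Optional[str]], repo_id: str, **kwargs
) -> str:
    """Prefer the provider's local cache path; fall back to the hf-hub id.

    Assumes the provider returns the local path it warmed up, or None.
    """
    local_path = provider(repo_id=repo_id, **kwargs)
    return local_path if local_path else f"hf-hub:{repo_id}"


# Hypothetical provider that has already cached the snapshot locally.
def fake_provider(repo_id: str) -> str:
    return f"/cache/huggingface/{repo_id}"


source = resolve_model_source(fake_provider, "paige-ai/Virchow2")
# The constructor would then pass `source` to timm.create_model(...)
# instead of hard-coding f"hf-hub:{repo_id}".
```

This makes the provider the single source of truth, so a provider pointing at a different snapshot actually changes what `timm` loads.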
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@models/virchow2.py` around lines 50 - 55, The call to provider(**config["model"]) in the constructor is currently invoked only for side effects and its return value is ignored while timm.create_model always uses repo_id; either consume the provider return as the authoritative model source or explicitly document it as a cache-warmup-only call. Update the code so that provider(...)'s return (e.g., a model path, HF repo override, or config dict) is checked and passed into timm.create_model (replace f"hf-hub:{repo_id}" with the provider-provided identifier) and assign to self.model, or alternatively add a clear comment and/or refactor the _target_ contract and call site to match models/semantic_segmentation.py behavior where the provider output is consumed.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@models/virchow2.py`:
- Line 66: The log message always says "moved to GPU" even when the model is
loaded to CPU; update the logger call that currently reads logger.info("Virchow2
model loaded and moved to GPU.") to report the actual device variable used
(e.g., include the `device` or `device_str` value computed when selecting
CPU/GPU) so the message becomes device-accurate; locate the model load/transfer
code (where `model.to(device)` or device selection is performed) and change the
info log to include that device identifier.
- Around line 46-48: The reconfigure function is mutating config["model"] by
calling .pop("_target_"), which will remove the key and break subsequent calls;
instead, read the target without mutating (e.g., target =
config["model"].get("_target_") or work on a shallow copy of config["model"])
and then split that target into module_path and attr_name, and fall back to a
clear error if the key is missing; update the code references in reconfigure
that currently use module_path, attr_name, provider to use the non-mutating
value so the original user_config remains unchanged.
- Around line 99-107: The code currently lets client-caused errors from
lz4.frame.decompress, np.frombuffer/reshape and np.dtype bubble up as 500s; wrap
the decompression/reshape and dtype parsing in a try/except that catches
RuntimeError, ValueError and TypeError (and optionally OverflowError) and
convert them to a client error (raise an HTTP 4xx, e.g.,
HTTPException(status_code=400) or 422) with a clear message; specifically
protect the lz4.frame.decompress call, the
np.frombuffer(...).reshape(self.tile_size, self.tile_size, 3) sequence, and the
np.dtype(requested_dtype) call (and validate
request.headers.get("x-output-dtype") before parsing) so malformed compressed
payloads, wrong buffer sizes, or invalid dtypes return 4xx instead of 500.
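The header-validation half of that last prompt can be isolated from the web framework. A minimal sketch, assuming the service only accepts `float16`/`float32` (the allowlist and function name are illustrative, not the repository's actual code):

```python
ALLOWED_DTYPES = {"float16", "float32"}  # assumed service contract


def parse_output_dtype(header_value):
    """Validate the x-output-dtype header before any numpy parsing.

    Returns (status_code, payload): (200, normalized dtype name) on
    success, (400, error message) for anything outside the allowlist.
    """
    requested = (header_value or "float32").lower()
    if requested not in ALLOWED_DTYPES:
        return 400, f"Unsupported x-output-dtype: {requested!r}"
    return 200, requested
```

Validating up front like this means `np.dtype(...)` is only ever called on known-good input, and the remaining try/except around decompression only has to handle malformed payloads.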
---
Nitpick comments:
In `@models/virchow2.py`:
- Around line 50-55: The call to provider(**config["model"]) in the constructor
is currently invoked only for side effects and its return value is ignored while
timm.create_model always uses repo_id; either consume the provider return as the
authoritative model source or explicitly document it as a cache-warmup-only
call. Update the code so that provider(...)'s return (e.g., a model path, HF
repo override, or config dict) is checked and passed into timm.create_model
(replace f"hf-hub:{repo_id}" with the provider-provided identifier) and assign
to self.model, or alternatively add a clear comment and/or refactor the _target_
contract and call site to match models/semantic_segmentation.py behavior where
the provider output is consumed.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 85e59ece-bf28-466e-a429-d0cc10ffa6ae
📒 Files selected for processing (2)
- models/virchow2.py
- providers/model_provider.py
🚧 Files skipped from review as they are similar to previous changes (1)
- providers/model_provider.py
Actionable comments posted: 2
♻️ Duplicate comments (2)
ray-service.yaml (2)
257: ⚠️ Potential issue | 🟠 Major — Replace mutable `latest-gpu` image tag with a pinned version (and digest).

Line 257 can drift independently of `rayVersion: 2.53.0`, causing rollout non-determinism.

Suggested change:

```diff
- image: cerit.io/rationai/model-service:latest-gpu
+ image: cerit.io/rationai/model-service:2.53.0-gpu@sha256:<digest>
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@ray-service.yaml` at line 257, The image reference "image: cerit.io/rationai/model-service:latest-gpu" is using a mutable tag; replace it with a specific, immutable image: use the release-aligned semantic version (matching rayVersion: 2.53.0 if applicable) and append the image digest (sha256) so the line becomes a pinned image (e.g., cerit.io/rationai/model-service:<version>@sha256:<digest>); update CI/build manifests that produce the digest or fetch the digest from your registry and ensure the new pinned string replaces the "latest-gpu" token to guarantee deterministic rollouts.
92-95: ⚠️ Potential issue | 🟠 Major — Pin `working_dir` to an immutable commit archive.

Line 95 points to `refs/heads/master.zip`, which is mutable and breaks reproducibility/rollback guarantees on restart.

Suggested change:

```diff
- working_dir: https://github.com/RationAI/model-service/archive/refs/heads/master.zip
+ working_dir: https://github.com/RationAI/model-service/archive/<commit-sha>.zip
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@ray-service.yaml` around lines 92 - 95, The working_dir currently points to a mutable branch archive (https://.../refs/heads/master.zip) which breaks reproducibility; change the runtime_env.working_dir to an immutable commit archive by replacing the refs/heads/master.zip URL with the repository archive URL for a specific commit SHA (e.g. https://github.com/RationAI/model-service/archive/<commit-sha>.zip). Locate the runtime_env block and update working_dir accordingly, ensuring you use a pinned commit SHA (not a branch name) and update any deployment docs to record the chosen SHA for rollbacks.
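Both pinning findings follow the same pattern: tag or branch names are mutable pointers, while a digest or commit SHA is immutable. A tiny helper, purely illustrative (the function names are not part of the repository), showing the two immutable reference forms the review asks for:

```python
def pinned_image(repo: str, version: str, digest: str) -> str:
    """Immutable container image reference: tag for readability, digest for identity."""
    return f"{repo}:{version}@sha256:{digest}"


def pinned_archive(org_repo: str, commit_sha: str) -> str:
    """Immutable GitHub source archive URL for one specific commit."""
    return f"https://github.com/{org_repo}/archive/{commit_sha}.zip"
```

With these forms, restarting a Ray worker or rolling back a deployment always resolves to byte-identical artifacts, which `latest-gpu` and `refs/heads/master.zip` cannot guarantee.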
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@docker/Dockerfile.gpu`:
- Around line 56-64: The Dockerfile's GPU pip installs (the pip install lines
installing onnxruntime-gpu, tensorrt-cu12, torch, torchvision, timm and
huggingface-hub) must use explicit, validated version pins to avoid upstream
drift; update the two pip install invocations to pin each critical package
(e.g., torch==2.4.1+cu121 and torchvision==0.19.1+cu121 or your tested
equivalents, timm==0.9.11+, and fixed onnxruntime-gpu and tensorrt-cu12 versions
that you validated for Python 3.12 + CUDA 12.1) and document the chosen
combinations in the Dockerfile comment so future rebuilds use the same, tested
package matrix.
In `@ray-service.yaml`:
- Around line 270-274: The HF_TOKEN secret is being injected into all GPU worker
pods (env name HF_TOKEN from secret huggingface-secret) even though
providers/model_provider.py already uses local_files_only=True and does not need
the token; remove HF_TOKEN from the generic GPU worker container spec and
instead inject the secret only into the specific pod/container that performs
authenticated Hugging Face calls (the component that imports
providers/model_provider.py or any service that sets local_files_only=False).
Update ray-service.yaml to delete the HF_TOKEN env entry from the generic worker
template and add it to the targeted deployment/container spec (or create a
separate pod template/service account) so only that component receives
huggingface-secret.
---
Duplicate comments:
In `@ray-service.yaml`:
- Line 257: The image reference "image:
cerit.io/rationai/model-service:latest-gpu" is using a mutable tag; replace it
with a specific, immutable image: use the release-aligned semantic version
(matching rayVersion: 2.53.0 if applicable) and append the image digest (sha256)
so the line becomes a pinned image (e.g.,
cerit.io/rationai/model-service:<version>@sha256:<digest>); update CI/build
manifests that produce the digest or fetch the digest from your registry and
ensure the new pinned string replaces the "latest-gpu" token to guarantee
deterministic rollouts.
- Around line 92-95: The working_dir currently points to a mutable branch
archive (https://.../refs/heads/master.zip) which breaks reproducibility; change
the runtime_env.working_dir to an immutable commit archive by replacing the
refs/heads/master.zip URL with the repository archive URL for a specific commit
SHA (e.g. https://github.com/RationAI/model-service/archive/<commit-sha>.zip).
Locate the runtime_env block and update working_dir accordingly, ensuring you
use a pinned commit SHA (not a branch name) and update any deployment docs to
record the chosen SHA for rollbacks.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 1e5753ec-93b8-4bb1-86fb-69a3fdeac1c2
📒 Files selected for processing (2)
- docker/Dockerfile.gpu
- ray-service.yaml
♻️ Duplicate comments (2)
models/virchow2.py (2)
46-51: ⚠️ Potential issue | 🟡 Minor — Avoid mutating deployment config in `reconfigure`.

Line 46 uses `.pop("_target_")`, which mutates `config["model"]`. Reconfigure can be invoked multiple times, so this can make subsequent calls fail when `_target_` is missing.

♻️ Proposed fix:

```diff
- module_path, attr_name = config["model"].pop("_target_").split(":")
+ model_cfg = dict(config["model"])
+ module_path, attr_name = model_cfg["_target_"].split(":")
  provider = getattr(importlib.import_module(module_path), attr_name)
- repo_id = config["model"]["repo_id"]
+ repo_id = model_cfg["repo_id"]
@@
- provider(**config["model"])
+ provider(**{k: v for k, v in model_cfg.items() if k != "_target_"})
```

```bash
#!/bin/bash
# Verify mutation pattern is present and compare with similar deployment code.
rg -nP 'config\["model"\]\.pop\("_target_"\)' models/virchow2.py models/binary_classifier.py -C2
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@models/virchow2.py` around lines 46 - 51, The code currently mutates config["model"] by calling .pop("_target_"), which breaks repeated calls (e.g., reconfigure); instead, read the target string without removing it and pass a non-mutated kwargs dict to the provider: extract target_str = config["model"]["_target_"] then do module_path, attr_name = target_str.split(":"), load provider = getattr(importlib.import_module(module_path), attr_name), build model_kwargs = {k:v for k,v in config["model"].items() if k != "_target_"} (or use config["model"].copy() and del the key on the copy) and call provider(**model_kwargs); keep using repo_id and logger.info as before.
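The mutation hazard is easy to demonstrate in isolation. This sketch mirrors the proposed fix with plain dicts, no Ray required; `os.path:join` merely stands in for the real provider target string:

```python
import importlib


def load_provider_nonmutating(config: dict):
    """Read _target_ without popping it, so repeated reconfigure calls stay safe."""
    model_cfg = dict(config["model"])  # shallow copy; caller's dict is untouched
    module_path, attr_name = model_cfg["_target_"].split(":")
    provider = getattr(importlib.import_module(module_path), attr_name)
    kwargs = {k: v for k, v in model_cfg.items() if k != "_target_"}
    return provider, kwargs


# "os.path:join" stands in for the real provider target; any "module:attr" works.
config = {"model": {"_target_": "os.path:join", "repo_id": "paige-ai/Virchow2"}}
provider, kwargs = load_provider_nonmutating(config)
provider2, _ = load_provider_nonmutating(config)  # second call still finds _target_
```

With `.pop("_target_")` the second call would raise `KeyError`; with the copy, `config` survives any number of reconfigure cycles unchanged.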
97-104: ⚠️ Potential issue | 🟠 Major — Map malformed input and dtype parsing errors to 4xx responses.

Lines 97-104 can raise client-caused exceptions (`lz4.frame.decompress`, `reshape`, `np.dtype`) that currently bubble up as 500s. These should return a 4xx with a clear message.

🛠️ Proposed fix:

```diff
- data = await asyncio.to_thread(lz4.frame.decompress, await request.body())
- image = np.frombuffer(data, dtype=np.uint8).reshape(
-     self.tile_size, self.tile_size, 3
- )
- requested_dtype = request.headers.get("x-output-dtype", "float32").lower()
-
- output_dtype = np.dtype(requested_dtype)
+ if requested_dtype not in {"float16", "float32"}:
+     return Response("Unsupported x-output-dtype", status_code=400)
+
+ try:
+     data = await asyncio.to_thread(lz4.frame.decompress, await request.body())
+     image = np.frombuffer(data, dtype=np.uint8).reshape(
+         self.tile_size, self.tile_size, 3
+     )
+     output_dtype = np.dtype(requested_dtype)
+ except (RuntimeError, ValueError, TypeError, OverflowError):
+     return Response("Malformed request payload", status_code=400)
```

```bash
#!/bin/bash
# Verify risky parsing operations and whether 4xx handling is present around them.
rg -nP 'lz4\.frame\.decompress|np\.frombuffer|reshape\(|np\.dtype\(' models/virchow2.py -C2
rg -nP 'HTTPException|status_code\s*=\s*4[0-9]{2}' models/virchow2.py -C2
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@models/virchow2.py` around lines 97 - 104, The code that calls lz4.frame.decompress, np.frombuffer(...).reshape(...), and np.dtype(...) can raise client-caused errors and should be mapped to 4xx responses: wrap the risky block (the async to_thread call that decompresses request.body(), the np.frombuffer(...).reshape(self.tile_size, self.tile_size, 3) call, and the np.dtype(requested_dtype) parsing) in a try/except that catches lz4.frame.LZ4FrameError (or a broad lz4 error), ValueError and TypeError and then raise fastapi.HTTPException(status_code=400, detail="...") with a short, clear message (include the original exception text) instead of allowing a 500; reference the handler method containing these calls and use HTTPException/status_code=400 for the response.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@models/virchow2.py`:
- Around line 46-51: The code currently mutates config["model"] by calling
.pop("_target_"), which breaks repeated calls (e.g., reconfigure); instead, read
the target string without removing it and pass a non-mutated kwargs dict to the
provider: extract target_str = config["model"]["_target_"] then do module_path,
attr_name = target_str.split(":"), load provider =
getattr(importlib.import_module(module_path), attr_name), build model_kwargs =
{k:v for k,v in config["model"].items() if k != "_target_"} (or use
config["model"].copy() and del the key on the copy) and call
provider(**model_kwargs); keep using repo_id and logger.info as before.
- Around line 97-104: The code that calls lz4.frame.decompress,
np.frombuffer(...).reshape(...), and np.dtype(...) can raise client-caused
errors and should be mapped to 4xx responses: wrap the risky block (the async
to_thread call that decompresses request.body(), the
np.frombuffer(...).reshape(self.tile_size, self.tile_size, 3) call, and the
np.dtype(requested_dtype) parsing) in a try/except that catches
lz4.frame.LZ4FrameError (or a broad lz4 error), ValueError and TypeError and
then raise fastapi.HTTPException(status_code=400, detail="...") with a short,
clear message (include the original exception text) instead of allowing a 500;
reference the handler method containing these calls and use
HTTPException/status_code=400 for the response.
matejpekar left a comment:

This PR is not in line with #5.
```diff
- RUN pip install --no-cache-dir onnxruntime-gpu tensorrt lz4 ratiopath "mlflow<3.0"
+ RUN pip install --no-cache-dir \
+     onnxruntime-gpu tensorrt-cu12 lz4 ratiopath "mlflow<3.0" torch torchvision \
```
Do you really need TensorRT for CUDA 12? By default it uses CUDA 13, same as torch. This way you are using two different CUDA versions.
```python
def __init__(self) -> None:
    self.model: torch.nn.Module | None = None
    self.transforms: Any = None
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.tile_size: int = 0
```

Suggested change:

```diff
- def __init__(self) -> None:
-     self.model: torch.nn.Module | None = None
-     self.transforms: Any = None
-     self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     self.tile_size: int = 0
+ model: torch.nn.Module
+ transforms: Any
+ tile_size: int
+
+ def __init__(self) -> None:
+     self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```
```python
    self.tile_size, self.tile_size, 3
)

requested_dtype = request.headers.get("x-output-dtype", "float32").lower()
```
The output dtype should be applied in the predict method to avoid serialization of large arrays
```python
@serve.batch
async def predict(
    self, inputs: list[torch.Tensor | NDArray[np.float16] | NDArray[np.float32]]
```
Accept only one type. Also check whether Ray can serialize tensors.
```python
model = cast("torch.nn.Module", self.model)

device_type = self.device.type
autocast_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16
```
Why are you switching to bf16 for CPUs? Is fp16 not supported by the cluster CPUs?
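The selection the reviewer is questioning can be factored into a small, framework-free function so the rationale is testable and documented in one place. This is a sketch of the reviewed logic, not the repository's code; the stated rationale (fp16 autocast support on CPU being limited or slow on many targets, while bf16 is the usual CPU autocast dtype in PyTorch) is an assumption that should be verified against the actual cluster hardware:

```python
def pick_autocast_dtype(device_type: str) -> str:
    """Mirror of the reviewed logic: fp16 on CUDA, bf16 elsewhere.

    Assumption: CPU autocast runs in bfloat16 because fp16 autocast on CPU
    is poorly supported/slower on many targets; confirm for the cluster CPUs.
    """
    return "float16" if device_type == "cuda" else "bfloat16"
```

If the cluster CPUs do handle fp16 well, collapsing both branches to fp16 would also make outputs numerically consistent across devices.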
```python
class_token = output[:, 0]
patch_tokens = output[:, 5:]
embedding = torch.cat([class_token, patch_tokens.mean(dim=1)], dim=-1)
```
This should be optional. Some users might want to access the individual patch tokens
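Making the pooling optional could look like the pure-Python sketch below. It is a shape-level illustration only: lists stand in for tensors, and the mode names are hypothetical, not an API the repository defines.

```python
def pool_embedding(class_token, patch_tokens, mode="concat_mean"):
    """Return either the pooled embedding or the raw patch tokens.

    mode="concat_mean" reproduces class_token ++ mean(patch_tokens);
    mode="tokens" exposes the individual patch tokens to the caller.
    """
    if mode == "tokens":
        return patch_tokens
    n = len(patch_tokens)
    dim = len(patch_tokens[0])
    mean = [sum(tok[i] for tok in patch_tokens) / n for i in range(dim)]
    return class_token + mean  # list concatenation, mirroring torch.cat
```

A request header or deployment config flag could then select the mode, so users who need token-level features are not forced through the pooled path.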
```python
provider(**model_config)

self.model = timm.create_model(
    f"hf-hub:{repo_id}",
    pretrained=True,
    num_classes=0,
    mlp_layer=SwiGLUPacked,
    act_layer=torch.nn.SiLU,
)
```
This seems fragile. I guess the provider downloads the model to cache and then timm loads it from the cache based on the environment variables set by the provider. Why do you even need to call the provider? I think timm can handle the downloading and storing to cache itself.
This PR introduces support for the Virchow2 foundation model (paige-ai/Virchow2) within the Ray Serve infrastructure.
New Model Deployment: Added virchow2.py implementing the Virchow2 class as a Ray Serve deployment