From f053a2bb06877edbdc419849d23c53056a9d2ca3 Mon Sep 17 00:00:00 2001 From: vidyuthdev Date: Sun, 12 Apr 2026 12:46:23 -0400 Subject: [PATCH 1/3] Add end-to-end API smoke test and document usage in README - Add scripts/smoke_test.sh: plan creation, job polling, artifact download - Document smoke test steps after docker compose in README Made-with: Cursor --- README.md | 15 ++++++ scripts/smoke_test.sh | 121 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+) create mode 100755 scripts/smoke_test.sh diff --git a/README.md b/README.md index d4bf578..3c3330e 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,21 @@ docker compose up -d # Starts all 5 services curl http://localhost:8000/api/v1/info ``` +### Smoke test (recommended) + +After the stack is up, run an end-to-end API smoke test (plan → job → artifacts): + +```bash +chmod +x scripts/smoke_test.sh +./scripts/smoke_test.sh +``` + +To target a different API base URL: + +```bash +API_BASE_URL="http://127.0.0.1:8000/api/v1" ./scripts/smoke_test.sh +``` + | Service | Port | Description | |---|---|---| | `api` | 8000 | Radiarch FastAPI server | diff --git a/scripts/smoke_test.sh b/scripts/smoke_test.sh new file mode 100755 index 0000000..58fb964 --- /dev/null +++ b/scripts/smoke_test.sh @@ -0,0 +1,121 @@ +#!/usr/bin/env bash +set -euo pipefail + +API_BASE_URL="${API_BASE_URL:-http://127.0.0.1:8000/api/v1}" + +echo "Radiarch smoke test" +echo "API_BASE_URL=${API_BASE_URL}" +echo + +echo "1) GET /info" +curl -fsS "${API_BASE_URL}/info" | python3 -m json.tool >/dev/null +echo " OK" +echo + +echo "2) POST /plans (create a minimal plan)" +CREATE_RESP="$( + curl -fsS "${API_BASE_URL}/plans" \ + -H "Content-Type: application/json" \ + -d '{ + "workflow_id": "proton-impt-basic", + "study_instance_uid": "1.2.826.0.1.3680043.8.1055.1", + "prescription_gy": 2.0, + "beam_count": 3, + "fraction_count": 1, + "notes": "smoke-test" + }' +)" +export CREATE_RESP +PLAN_ID="$(python3 - <<'PY' +import json, os, sys +data = json.loads(os.environ["CREATE_RESP"]) +print(data["id"]) +PY +)" +JOB_ID="$(python3 - <<'PY' +import json, os, sys +data = json.loads(os.environ["CREATE_RESP"]) +print(data.get("job_id") or "") +PY +)" +echo " plan_id=${PLAN_ID}" +echo " job_id=${JOB_ID}" +echo + +if [[ -z "${JOB_ID}" ]]; then + echo "ERROR: plan response did not include job_id" + exit 1 +fi + +echo "3) Poll /jobs/${JOB_ID} until terminal state" +terminal="succeeded failed cancelled" +state="" +for i in $(seq 1 180); do + JOB_JSON="$(curl -fsS "${API_BASE_URL}/jobs/${JOB_ID}")" + export JOB_JSON + state="$(python3 - <<'PY' +import json, os +print(json.loads(os.environ["JOB_JSON"])["state"]) +PY +)" + progress="$(python3 - <<'PY' +import json, os +print(json.loads(os.environ["JOB_JSON"]).get("progress", 0.0)) +PY +)" + printf " [%3s/180] state=%s progress=%s\r" "$i" "$state" "$progress" + if [[ " ${terminal} " == *" ${state} "* ]]; then + echo + break + fi + sleep 1 +done +echo " final state=${state}" +echo + +if [[ "${state}" != "succeeded" ]]; then + echo "ERROR: job did not succeed (state=${state})" + echo "Job payload:" + echo "${JOB_JSON}" | python3 -m json.tool || true + exit 1 +fi + +echo "4) GET /plans/${PLAN_ID} and check artifacts + qa_summary" +PLAN_JSON="$(curl -fsS "${API_BASE_URL}/plans/${PLAN_ID}")" +export PLAN_JSON +ARTIFACT_COUNT="$(python3 - <<'PY' +import json, os +data = json.loads(os.environ["PLAN_JSON"]) +print(len(data.get("artifact_ids") or [])) +PY +)" +QA_ENGINE="$(python3 - <<'PY' +import json, os +data = json.loads(os.environ["PLAN_JSON"]) +qa = data.get("qa_summary") or {} +print(qa.get("engine") or "") +PY +)" +echo " artifacts=${ARTIFACT_COUNT}" +echo " qa_engine=${QA_ENGINE:-}" +echo + +if [[ "${ARTIFACT_COUNT}" -lt 1 ]]; then + echo "ERROR: expected at least 1 artifact id on plan" + echo "${PLAN_JSON}" | python3 -m json.tool || true + exit 1 +fi + +ARTIFACT_ID="$(python3 - <<'PY' +import json, os +data = json.loads(os.environ["PLAN_JSON"]) +print((data.get("artifact_ids") or [""])[0]) +PY +)" +echo "5) GET /artifacts/${ARTIFACT_ID} (download first artifact)" +TMP_OUT="$(mktemp -t radiarch_artifact.XXXXXX)" +curl -fsS "${API_BASE_URL}/artifacts/${ARTIFACT_ID}" -o "${TMP_OUT}" +echo " downloaded=$(wc -c < "${TMP_OUT}") bytes to ${TMP_OUT}" +echo + +echo "Smoke test succeeded." From c4bb802e02afd552f4b39999c111dc56f2c2c818 Mon Sep 17 00:00:00 2001 From: vidyuthdev Date: Mon, 20 Apr 2026 10:00:49 -0400 Subject: [PATCH 2/3] Add geometry service pipeline with DICOM fetch, voxel build, and API routes. Implement Service 1 end-to-end with HU-to-density conversion models, axis-aligned resampling, contour rasterization, on-disk geometry persistence/cache, Orthanc adapter retrieval support, and /api/v1/geometry endpoints plus comprehensive unit and API tests. Made-with: Cursor --- src/radiarch/adapters/orthanc.py | 132 ++++++++- src/radiarch/api/routes/geometry.py | 92 ++++++ src/radiarch/app.py | 3 +- src/radiarch/models/geometry.py | 238 +++++++++++++++ src/radiarch/services/__init__.py | 10 + src/radiarch/services/dicom_fetcher.py | 227 ++++++++++++++ src/radiarch/services/geometry.py | 373 ++++++++++++++++++++++++ src/radiarch/services/hu_density.py | 188 ++++++++++++ src/radiarch/services/persistence.py | 241 +++++++++++++++ src/radiarch/services/rasterization.py | 210 +++++++++++++ src/radiarch/services/resampling.py | 157 ++++++++++ tests/services/__init__.py | 0 tests/services/test_dicom_fetcher.py | 325 +++++++++++++++++++++ tests/services/test_geometry_models.py | 133 +++++++++ tests/services/test_geometry_service.py | 317 ++++++++++++++++++++ tests/services/test_hu_density.py | 125 ++++++++ tests/services/test_persistence.py | 190 ++++++++++++ tests/services/test_rasterization.py | 215 ++++++++++++++ tests/services/test_resampling.py | 145 +++++++++ tests/test_api_geometry.py | 165 +++++++++++ 20 files changed, 3484 insertions(+), 2 deletions(-) create mode 100644 src/radiarch/api/routes/geometry.py create mode 100644 src/radiarch/models/geometry.py create mode 100644 src/radiarch/services/__init__.py create mode 100644 src/radiarch/services/dicom_fetcher.py create mode 100644 src/radiarch/services/geometry.py create mode 100644 src/radiarch/services/hu_density.py create mode 100644 src/radiarch/services/persistence.py create mode 100644 src/radiarch/services/rasterization.py create mode 100644 src/radiarch/services/resampling.py create mode 100644 tests/services/__init__.py create mode 100644 tests/services/test_dicom_fetcher.py create mode 100644 tests/services/test_geometry_models.py create mode 100644 tests/services/test_geometry_service.py create mode 100644 tests/services/test_hu_density.py create mode 100644 tests/services/test_persistence.py create mode 100644 tests/services/test_rasterization.py create mode 100644 tests/services/test_resampling.py create mode 100644 tests/test_api_geometry.py diff --git a/src/radiarch/adapters/orthanc.py b/src/radiarch/adapters/orthanc.py index 1ce53f7..3a358d9 100644 --- a/src/radiarch/adapters/orthanc.py +++ b/src/radiarch/adapters/orthanc.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Any, Dict, Iterable, List, Optional from loguru import logger @@ -25,6 +25,8 @@ class StudyMetadata: class OrthancAdapterBase: + # ---- Metadata / control-plane methods ----------------------------- + def get_study(self, study_instance_uid: str) -> Optional[StudyMetadata]: raise NotImplementedError @@ -34,6 +36,38 @@ def get_segmentation(self, sop_instance_uid: str) -> Optional[Dict[str, Any]]: def store_artifact(self, dataset_bytes: bytes, content_type: str = "application/dicom") -> str: raise NotImplementedError + # ---- Data-plane methods (Geometry Service path) ------------------- + + def can_retrieve_instances(self) -> bool: + """Does this adapter support downloading DICOM instance bytes? + + Returns False for metadata-only fake adapters — callers can use + this to decide whether to fall back to a local data root. + """ + return False + + def search_for_series( + self, + study_instance_uid: str, + modality: Optional[str] = None, + ) -> List[Dict[str, Any]]: + """Return series-level metadata for ``study_instance_uid``. + + Each entry should at minimum expose ``SeriesInstanceUID`` and + ``Modality``. Implementations may also include + ``NumberOfSeriesRelatedInstances`` so the Geometry Service can + pick the largest CT series when no explicit UID is provided. + """ + raise NotImplementedError + + def retrieve_series( + self, + study_instance_uid: str, + series_instance_uid: str, + ) -> Iterable[Any]: + """Yield ``pydicom.Dataset`` objects for every instance in a series.""" + raise NotImplementedError + class FakeOrthancAdapter(OrthancAdapterBase): def get_study(self, study_instance_uid: str) -> Optional[StudyMetadata]: @@ -49,6 +83,31 @@ def store_artifact(self, dataset_bytes: bytes, content_type: str = "application/ logger.debug("Storing artifact in fake adapter (content type %s, %s bytes)", content_type, len(dataset_bytes)) return "mock-artifact-uid" + # The fake adapter has no DICOM bytes — expose metadata only so the + # Geometry Service can cleanly detect "mock mode" and fall back to + # the local data root. + def can_retrieve_instances(self) -> bool: + return False + + def search_for_series( + self, + study_instance_uid: str, + modality: Optional[str] = None, + ) -> List[Dict[str, Any]]: + study = sample_data.SAMPLE_STUDIES.get(study_instance_uid) + if not study: + return [] + series = study.get("series", []) + if modality: + series = [s for s in series if s.get("Modality") == modality] + return list(series) + + def retrieve_series(self, study_instance_uid, series_instance_uid): + raise OrthancAdapterError( + "FakeOrthancAdapter does not serve DICOM bytes; " + "set RADIARCH_ORTHANC_USE_MOCK=false or use a real adapter." + ) + class OrthancAdapter(OrthancAdapterBase): def __init__(self, settings: Settings): @@ -89,6 +148,77 @@ def store_artifact(self, dataset_bytes: bytes, content_type: str = "application/ raise OrthancAdapterError(f"Failed to store artifact: {exc}") from exc return result[0]["ID"] if result else "" + def can_retrieve_instances(self) -> bool: + return True + + def search_for_series( + self, + study_instance_uid: str, + modality: Optional[str] = None, + ) -> List[Dict[str, Any]]: + filters: Dict[str, Any] = {"StudyInstanceUID": study_instance_uid} + if modality: + filters["Modality"] = modality + try: + results = self.client.search_for_series(search_filters=filters) + except Exception as exc: # pragma: no cover — network paths + raise OrthancAdapterError( + f"search_for_series failed for study {study_instance_uid}: {exc}" + ) from exc + return self._flatten_series_metadata(results) + + def retrieve_series( + self, + study_instance_uid: str, + series_instance_uid: str, + ) -> Iterable[Any]: + try: + yield from self.client.retrieve_series( + study_instance_uid=study_instance_uid, + series_instance_uid=series_instance_uid, + ) + except Exception as exc: # pragma: no cover — network paths + raise OrthancAdapterError( + f"retrieve_series failed for {series_instance_uid}: {exc}" + ) from exc + + @staticmethod + def _flatten_series_metadata(results: Iterable[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Normalize DICOMwebClient search responses. + + dicomweb-client returns tag-addressed dicts (``{"0020000E": + {"Value": ["..."]}}``) which are miserable downstream. Flatten + the handful of tags we actually use into plain strings so callers + can treat the result as a regular dict. + """ + flat: List[Dict[str, Any]] = [] + for entry in results: + if not isinstance(entry, dict): + continue + series_uid = _tag_value(entry, "0020000E") or entry.get("SeriesInstanceUID") + modality = _tag_value(entry, "00080060") or entry.get("Modality") + count = _tag_value(entry, "00201209") or entry.get("NumberOfSeriesRelatedInstances") + flat.append( + { + "SeriesInstanceUID": series_uid, + "Modality": modality, + "NumberOfSeriesRelatedInstances": int(count) if count is not None else None, + "_raw": entry, + } + ) + return flat + + +def _tag_value(entry: Dict[str, Any], tag: str) -> Any: + """Pull the first Value out of a DICOMweb tag-addressed dict, if present.""" + node = entry.get(tag) + if not isinstance(node, dict): + return None + values = node.get("Value") + if not values: + return None + return values[0] + def build_orthanc_adapter(settings: Settings | None = None) -> OrthancAdapterBase: settings = settings or get_settings() diff --git a/src/radiarch/api/routes/geometry.py b/src/radiarch/api/routes/geometry.py new file mode 100644 index 0000000..bc0cfc1 --- /dev/null +++ b/src/radiarch/api/routes/geometry.py @@ -0,0 +1,92 @@ +"""FastAPI routes for the Geometry Service (Service 1). + +v1 is synchronous — ``POST /geometry/build`` runs the build in-process +and returns the :class:`GeometryResult` directly. The ``/jobs`` variant +listed in the spec comes with the async-mode PR (Celery + DB); in the +meantime, we expose a placeholder so existing clients can evolve. + +Endpoints +--------- +``POST /api/v1/geometry/build`` — build (or reuse cached) geometry. +``GET /api/v1/geometry/{id}`` — retrieve cached geometry metadata. +``GET /api/v1/geometry/{id}/density``— stream density NIfTI. +``GET /api/v1/geometry/{id}/masks`` — stream multi-label mask NIfTI. +""" + +from __future__ import annotations + +import os +from functools import lru_cache + +from fastapi import APIRouter, HTTPException +from fastapi.responses import FileResponse + +from ...models.geometry import GeometryBuildRequest, GeometryResult +from ...services.geometry import GeometryService +from ...services.persistence import DENSITY_FILENAME, MASKS_FILENAME + +router = APIRouter(prefix="/geometry", tags=["geometry"]) + + +@lru_cache(maxsize=1) +def _service() -> GeometryService: + """Singleton service instance. Cached for the process lifetime so + every request reuses the same on-disk store + cache index.""" + return GeometryService() + + +@router.post( + "/build", + response_model=GeometryResult, + summary="Build (or reuse cached) geometry from a DICOM study.", +) +async def build_geometry(request: GeometryBuildRequest) -> GeometryResult: + try: + return _service().build(request) + except FileNotFoundError as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc + except ValueError as exc: + # Request-level validation problems — e.g., rotated affine, + # underspecified grid. Bubble up as 422 so clients can act. + raise HTTPException(status_code=422, detail=str(exc)) from exc + + +@router.get( + "/{geometry_id}", + response_model=GeometryResult, + summary="Retrieve completed geometry metadata.", +) +async def get_geometry(geometry_id: str) -> GeometryResult: + result = _service().store.get_by_id(geometry_id) + if result is None: + raise HTTPException(status_code=404, detail=f"Geometry not found: {geometry_id}") + return result + + +@router.get( + "/{geometry_id}/density", + summary="Stream the density NIfTI volume.", + response_class=FileResponse, +) +async def get_density(geometry_id: str): + return _stream_geometry_file(geometry_id, DENSITY_FILENAME, "application/gzip") + + +@router.get( + "/{geometry_id}/masks", + summary="Stream the multi-label mask NIfTI volume.", + response_class=FileResponse, +) +async def get_masks(geometry_id: str): + return _stream_geometry_file(geometry_id, MASKS_FILENAME, "application/gzip") + + +def _stream_geometry_file(geometry_id: str, filename: str, media_type: str): + result = _service().store.get_by_id(geometry_id) + if result is None: + raise HTTPException(status_code=404, detail=f"Geometry not found: {geometry_id}") + + base = _service().store.base_dir / geometry_id / filename + if not os.path.isfile(base): + raise HTTPException(status_code=410, detail=f"{filename} no longer on disk") + return FileResponse(path=str(base), media_type=media_type, filename=filename) diff --git a/src/radiarch/app.py b/src/radiarch/app.py index 2a04a8e..6f6d4ea 100644 --- a/src/radiarch/app.py +++ b/src/radiarch/app.py @@ -4,7 +4,7 @@ from fastapi.middleware.cors import CORSMiddleware from .config import get_settings -from .api.routes import info, plans, jobs, artifacts, workflows, sessions, simulations +from .api.routes import info, plans, jobs, artifacts, workflows, sessions, simulations, geometry from .adapters import build_orthanc_adapter from .core.database import init_db @@ -41,6 +41,7 @@ def create_app() -> FastAPI: app.include_router(artifacts.router, prefix=settings.api_prefix) app.include_router(sessions.router, prefix=settings.api_prefix) app.include_router(simulations.router, prefix=settings.api_prefix) + app.include_router(geometry.router, prefix=settings.api_prefix) @app.get("/") async def root(): diff --git a/src/radiarch/models/geometry.py b/src/radiarch/models/geometry.py new file mode 100644 index 0000000..44c5948 --- /dev/null +++ b/src/radiarch/models/geometry.py @@ -0,0 +1,238 @@ +"""Pydantic I/O models for the Geometry Service. + +Service 1 converts raw clinical DICOM (CT + RTSTRUCT) into a +computation-ready voxel model consumable by downstream dose / optimization +services. This module defines the public request / response schemas. + +See ``docs/tps_services_implementation_plan.md`` (Service 1) for the full +specification. Highlights: + + - ``GeometryBuildRequest`` — inputs: patient_ref, target GridSpec, + HU→density model choice, optional structure-name aliasing. + - ``GeometryResult`` — outputs: URIs to the density grid and multi-label + mask volume, a structure_index name→label map, the realized GridSpec + with its 4×4 affine, and a content-addressable ``cache_key``. + +The cache_key is deliberately stable across runs so identical requests can +short-circuit to a cached geometry without rebuilding. +""" + +from __future__ import annotations + +import hashlib +import json +from datetime import datetime +from enum import Enum +from typing import Dict, List, Optional, Tuple + +import numpy as np +from pydantic import BaseModel, Field, field_validator, model_validator + + +# --------------------------------------------------------------------------- +# Enums +# --------------------------------------------------------------------------- + +class HUDensityModel(str, Enum): + """Selectable HU → mass density conversion models. + + SCHNEIDER — Piecewise-linear Schneider-2000 calibration. + STOICHIOMETRIC — Vendored MCsquare CT calibration (mass density from + tissue-composition stoichiometry). Most accurate for + proton dose calculation. + LINEAR — Simple ρ = max(0, 1 + HU/1000). Fast, for tests and + synthetic mode. + """ + + schneider = "SCHNEIDER" + stoichiometric = "STOICHIOMETRIC" + linear = "LINEAR" + + +# --------------------------------------------------------------------------- +# GridSpec +# --------------------------------------------------------------------------- + +class GridSpec(BaseModel): + """Axis-aligned voxel grid in patient coordinates (LPS). + + ``spacing_mm`` is always required. Leaving ``origin_mm`` or ``size`` + null in an input *request* means "inherit from the source CT"; the + service will fill them in before returning a GridSpec in a + ``GeometryResult``. + + Axes correspond to the stored array layout (i, j, k). The ``affine`` + is derived from spacing + origin on demand — no rotational component + is supported in v1 (rows/cols strictly aligned to patient axes). + """ + + spacing_mm: Tuple[float, float, float] = Field( + ..., description="Voxel spacing in mm along (i, j, k)." + ) + origin_mm: Optional[Tuple[float, float, float]] = Field( + default=None, + description="Origin (mm) of voxel (0,0,0) in patient LPS coordinates.", + ) + size: Optional[Tuple[int, int, int]] = Field( + default=None, description="Number of voxels along (i, j, k)." + ) + + # Populated on output only — derived from spacing + origin. + affine: Optional[List[List[float]]] = Field( + default=None, + description="4×4 voxel-index → patient-LPS affine. Derived; set on output.", + ) + + @field_validator("spacing_mm") + @classmethod + def _spacing_positive(cls, v: Tuple[float, float, float]) -> Tuple[float, float, float]: + if any(s <= 0 for s in v): + raise ValueError(f"spacing_mm must be strictly positive, got {v}") + return v + + @field_validator("size") + @classmethod + def _size_positive(cls, v: Optional[Tuple[int, int, int]]) -> Optional[Tuple[int, int, int]]: + if v is not None and any(s <= 0 for s in v): + raise ValueError(f"size entries must be positive, got {v}") + return v + + # ---- Helpers (not part of the public schema) -------------------------- + + def compute_affine(self) -> List[List[float]]: + """Build a 4×4 voxel-index → patient-LPS affine from spacing + origin. + + Requires both ``origin_mm`` and ``spacing_mm`` to be populated. + """ + if self.origin_mm is None: + raise ValueError("compute_affine requires origin_mm to be set") + sx, sy, sz = self.spacing_mm + ox, oy, oz = self.origin_mm + return [ + [sx, 0.0, 0.0, ox], + [0.0, sy, 0.0, oy], + [0.0, 0.0, sz, oz], + [0.0, 0.0, 0.0, 1.0], + ] + + def to_numpy_affine(self) -> np.ndarray: + return np.asarray(self.compute_affine(), dtype=np.float64) + + def is_fully_specified(self) -> bool: + return self.origin_mm is not None and self.size is not None + + +# --------------------------------------------------------------------------- +# Patient reference +# --------------------------------------------------------------------------- + +class PatientRef(BaseModel): + """Points the Geometry Service at a specific CT + RTSTRUCT pair.""" + + dicom_study_uid: str = Field(..., description="DICOM Study Instance UID") + ct_series_uid: Optional[str] = Field( + default=None, + description="CT Series Instance UID. Null = auto-detect the primary CT.", + ) + rtstruct_uid: Optional[str] = Field( + default=None, + description="RTSTRUCT Series Instance UID. Null = auto-detect.", + ) + + +# --------------------------------------------------------------------------- +# Request / response +# --------------------------------------------------------------------------- + +class GeometryBuildRequest(BaseModel): + """Input payload for POST /api/v1/geometry/build.""" + + patient_ref: PatientRef + grid_spec: Optional[GridSpec] = Field( + default=None, + description="Target grid. Null = match the source CT grid exactly (fast path).", + ) + hu_to_density_model: HUDensityModel = Field( + default=HUDensityModel.stoichiometric, + description="Which HU→density conversion to use.", + ) + structure_name_map: Optional[Dict[str, List[str]]] = Field( + default=None, + description=( + "Canonical-name → list-of-aliases mapping. Case-insensitive matching. " + 'e.g. {"PTV": ["PTV_60", "PTV60"], "SpinalCord": ["Cord"]}' + ), + ) + data_root_override: Optional[str] = Field( + default=None, + description="Override RADIARCH_OPENTPS_DATA_ROOT for this build (dev/testing only).", + ) + + # ---- Cache key ------------------------------------------------------- + + def compute_cache_key(self) -> str: + """Deterministic sha256 over the inputs that affect the output. + + Excludes ``data_root_override`` (a developer convenience — the + actual CT/RTSTRUCT UIDs fully determine the geometry content). + """ + payload = { + "study": self.patient_ref.dicom_study_uid, + "ct": self.patient_ref.ct_series_uid, + "rts": self.patient_ref.rtstruct_uid, + "grid": self.grid_spec.model_dump() if self.grid_spec else None, + "hu_model": self.hu_to_density_model.value, + "name_map": self._normalized_name_map(), + } + blob = json.dumps(payload, sort_keys=True, separators=(",", ":")) + return hashlib.sha256(blob.encode("utf-8")).hexdigest() + + def _normalized_name_map(self) -> Optional[Dict[str, List[str]]]: + """Lowercase canonical name → sorted-lowercase alias list. + + Normalization makes the cache key invariant to stylistic + differences that don't affect the output. + """ + if not self.structure_name_map: + return None + return { + k.lower(): sorted(a.lower() for a in v) + for k, v in self.structure_name_map.items() + } + + +class CTMetadata(BaseModel): + """Minimal CT provenance carried through to downstream services.""" + + patient_name: str = Field(default="ANONYMOUS") + modality: str = Field(default="CT") + num_slices: int = Field(..., ge=1) + study_instance_uid: Optional[str] = None + series_instance_uid: Optional[str] = None + + +class GeometryResult(BaseModel): + """Output of a completed geometry build.""" + + geometry_id: str + density_grid_uri: str = Field(..., description="URI/path to the density volume (NIfTI).") + structure_masks_uri: str = Field( + ..., description="URI/path to the multi-label mask volume (NIfTI, uint16)." + ) + structure_index: Dict[str, int] = Field( + ..., description='Canonical name → integer label in the mask volume. 0 = background.' + ) + grid_spec: GridSpec = Field(..., description="Grid the outputs were written on.") + frame_of_reference_uid: str + ct_metadata: CTMetadata + cache_key: str + + created_at: Optional[datetime] = None + + @model_validator(mode="after") + def _check_structure_index(self) -> "GeometryResult": + if 0 in self.structure_index.values(): + raise ValueError("structure_index labels must be >= 1 (0 is reserved for background)") + if len(set(self.structure_index.values())) != len(self.structure_index): + raise ValueError("structure_index labels must be unique") + return self diff --git a/src/radiarch/services/__init__.py b/src/radiarch/services/__init__.py new file mode 100644 index 0000000..3d98366 --- /dev/null +++ b/src/radiarch/services/__init__.py @@ -0,0 +1,10 @@ +"""Radiarch service layer. + +Each service converts one well-defined input to one well-defined output +and is fully usable in isolation (no implicit coupling to plans or jobs). + +Services: + geometry.GeometryService — Raw DICOM (CT + RTSTRUCT) → voxel model. + +Future services (dose, optimization, robustness, simulation) land here. +""" diff --git a/src/radiarch/services/dicom_fetcher.py b/src/radiarch/services/dicom_fetcher.py new file mode 100644 index 0000000..7fa3812 --- /dev/null +++ b/src/radiarch/services/dicom_fetcher.py @@ -0,0 +1,227 @@ +"""Download a DICOM study from a PACS and stage it to a temp directory. + +The Geometry Service talks to PACS through the ``OrthancAdapterBase`` +abstraction. This module sits between that adapter and OpenTPS's +``dataLoader.readData`` (which only knows how to read from a directory +on disk). Flow: + +1. Ask the adapter which series live in the study. +2. Pick the target CT series — explicit UID wins; otherwise "largest" + (the series with the most instances, which in practice is the + planning CT). Deterministic. +3. Pick the RTSTRUCT series — explicit UID wins; otherwise the first + RTSTRUCT series the adapter returns. +4. Stream every instance of both series through the adapter and write + ``.dcm`` files into an mkdtemp'd directory. +5. Yield a :class:`StagedDicom` whose ``__exit__`` wipes the directory. + +Only axis-aligned CT is supported in v1 (see the resampling module); +we don't enforce that here — invalid affines are caught downstream +with a clear error. + +Adapters that can't serve bytes (``FakeOrthancAdapter``, etc.) raise +via ``retrieve_series``; callers should check +``adapter.can_retrieve_instances()`` first and fall back to local disk +loading if it returns False. +""" + +from __future__ import annotations + +import shutil +import tempfile +from contextlib import AbstractContextManager +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional + +from loguru import logger + +from ..adapters.orthanc import OrthancAdapterBase +from ..models.geometry import PatientRef + + +class DicomFetcherError(RuntimeError): + """The adapter couldn't produce a usable study for this request.""" + + +@dataclass +class StagedDicom(AbstractContextManager): + """A temp directory populated with one CT + (optional) RTSTRUCT series. + + Use as a context manager so the directory is cleaned up reliably even + on exceptions — OpenTPS's ``dataLoader`` reads everything eagerly, so + we can delete the files as soon as the caller hands us back the + loaded ``CTImage`` / ``RTStruct`` objects. + """ + + directory: Path + ct_series_uid: str + rtstruct_series_uid: Optional[str] = None + files_written: int = 0 + _cleaned: bool = field(default=False, repr=False) + + def __enter__(self) -> "StagedDicom": + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + self.cleanup() + + def cleanup(self) -> None: + if self._cleaned: + return + shutil.rmtree(self.directory, ignore_errors=True) + self._cleaned = True + + +# --------------------------------------------------------------------------- +# DicomFetcher +# --------------------------------------------------------------------------- + +class DicomFetcher: + """Adapter-driven DICOM study download.""" + + def __init__(self, adapter: OrthancAdapterBase) -> None: + self.adapter = adapter + + # ---- Capability probe -------------------------------------------- + + @property + def can_fetch(self) -> bool: + """True iff the adapter is wired to deliver DICOM bytes.""" + return bool(self.adapter.can_retrieve_instances()) + + # ---- Main entry point -------------------------------------------- + + def fetch(self, patient_ref: PatientRef) -> StagedDicom: + """Download the study and return a populated :class:`StagedDicom`. + + On any failure we clean up the temp dir before raising — callers + should *not* have to worry about partial state. + """ + if not self.can_fetch: + raise DicomFetcherError( + "Adapter does not support retrieving DICOM instances; " + "caller should fall back to data_root loading." + ) + + ct_series_uid = self._pick_ct_series( + patient_ref.dicom_study_uid, patient_ref.ct_series_uid + ) + rtstruct_series_uid = self._pick_rtstruct_series( + patient_ref.dicom_study_uid, patient_ref.rtstruct_uid + ) + + tmpdir = Path(tempfile.mkdtemp(prefix="radiarch_dicom_")) + try: + ct_count = self._stage_series( + tmpdir, patient_ref.dicom_study_uid, ct_series_uid, "CT" + ) + rt_count = 0 + if rtstruct_series_uid: + rt_count = self._stage_series( + tmpdir, patient_ref.dicom_study_uid, rtstruct_series_uid, "RT" + ) + logger.info( + "Staged study %s → %s (CT=%d instances, RT=%d instances)", + patient_ref.dicom_study_uid, + tmpdir, + ct_count, + rt_count, + ) + return StagedDicom( + directory=tmpdir, + ct_series_uid=ct_series_uid, + rtstruct_series_uid=rtstruct_series_uid, + files_written=ct_count + rt_count, + ) + except Exception: + shutil.rmtree(tmpdir, ignore_errors=True) + raise + + # ---- Series selection -------------------------------------------- + + def _pick_ct_series(self, study_uid: str, explicit_uid: Optional[str]) -> str: + """Resolve the target CT series UID. + + Explicit UIDs short-circuit any search — if the caller said + ``ct_series_uid="1.2.3"``, that's what we use, and we only hit + ``search_for_series`` to confirm it exists. When the caller + leaves it null, we list all CT series in the study and pick the + one with the most instances (the planning CT, in practice). + """ + if explicit_uid: + return explicit_uid + + series = self.adapter.search_for_series(study_uid, modality="CT") + if not series: + raise DicomFetcherError( + f"No CT series found in study {study_uid}" + ) + + # "Largest" = most instances. Deterministic tiebreak on UID so + # two equally-large series always resolve the same way across + # machines. + def _sort_key(s: Dict[str, Any]): + count = s.get("NumberOfSeriesRelatedInstances") or 0 + return (-int(count), str(s.get("SeriesInstanceUID") or "")) + + winner = sorted(series, key=_sort_key)[0] + uid = winner.get("SeriesInstanceUID") + if not uid: + raise DicomFetcherError( + f"Study {study_uid} has a CT series missing SeriesInstanceUID" + ) + return uid + + def _pick_rtstruct_series( + self, study_uid: str, explicit_uid: Optional[str] + ) -> Optional[str]: + """Resolve the RTSTRUCT series UID. ``None`` means "no RTSTRUCT".""" + if explicit_uid: + return explicit_uid + + try: + series = self.adapter.search_for_series(study_uid, modality="RTSTRUCT") + except Exception as exc: + logger.warning( + "RTSTRUCT series lookup failed for %s: %s — continuing without it", + study_uid, + exc, + ) + return None + + if not series: + logger.info("No RTSTRUCT series found in study %s", study_uid) + return None + + # For now: take the first one. When we encounter studies with + # multiple RTSTRUCT series this'll need policy (pick by Label, + # by most recent, etc.) — revisit when the real data forces it. + return series[0].get("SeriesInstanceUID") + + # ---- Download ---------------------------------------------------- + + def _stage_series( + self, + tmpdir: Path, + study_uid: str, + series_uid: str, + kind: str, + ) -> int: + """Pull every instance of ``series_uid`` into ``tmpdir``. Returns count.""" + written = 0 + for dataset in self.adapter.retrieve_series(study_uid, series_uid): + sop_uid = getattr(dataset, "SOPInstanceUID", None) or f"{kind}_{written:04d}" + # Filename prefix keeps CTs and RT structs visually grouped in + # the temp dir — helps when debugging a failed build by hand. + dest = tmpdir / f"{kind}_{sop_uid}.dcm" + dataset.save_as(str(dest)) + written += 1 + if written == 0: + raise DicomFetcherError( + f"Series {series_uid} produced zero instances" + ) + return written + + +__all__ = ["DicomFetcher", "DicomFetcherError", "StagedDicom"] diff --git a/src/radiarch/services/geometry.py b/src/radiarch/services/geometry.py new file mode 100644 index 0000000..6fa20c9 --- /dev/null +++ b/src/radiarch/services/geometry.py @@ -0,0 +1,373 @@ +"""``GeometryService`` — DICOM (CT + RTSTRUCT) → voxel model. + +This is the public entry point for Service 1. One method, one contract: +``build(request) -> GeometryResult``. Everything else is private. + +Pipeline +-------- +1. Compute ``cache_key`` and short-circuit to the cached geometry if present. +2. Load the CT + patient (delegates to ``_helpers.load_ct_and_patient``; + respects ``data_root_override``). +3. Convert CT HU → mass density on the *native* CT grid, using the + requested :class:`HUDensityModel`. This ordering matters: converting + HU before resampling preserves tissue boundaries better than + resampling HU and then converting (which smears piecewise-linear + models across interpolation boundaries). +4. Determine the target :class:`GridSpec` — either the user's explicit + grid or the CT's own grid (the fast path, no resampling). +5. If the target ≠ native, resample the density with trilinear + interpolation. +6. Rasterize contours on the target grid (respecting ``structure_name_map`` + aliases). Contours go straight to the target grid to avoid a + second resampling step that would corrupt label boundaries. +7. Persist density + masks + metadata atomically; update cache index. +8. Return the :class:`GeometryResult`. + +Testability +----------- +The service exposes a ``_load`` → ``_process`` seam. ``_load`` is the +OpenTPS-dependent step (DICOM I/O), which tests stub out with a +synthetic CT + fake-contour patient. ``_process`` does the math and +persistence; that's where the interesting invariants live. +""" + +from __future__ import annotations + +import uuid +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Optional, Tuple + +import numpy as np +from loguru import logger + +from ..config import get_settings +from ..models.geometry import ( + CTMetadata, + GeometryBuildRequest, + GeometryResult, + GridSpec, + HUDensityModel, +) +from .hu_density import get_model as get_hu_density_model +from .persistence import ( + DENSITY_FILENAME, + GeometryPaths, + GeometryStore, + MASKS_FILENAME, +) +from .rasterization import rasterize_contours +from .resampling import identity_grid_from_affine, resample_to_grid + + +@dataclass +class _LoadedCT: + """Internal bundle returned by ``_load`` — keeps the public API narrow.""" + + ct: Any # OpenTPS CTImage (or test double) + patient: Any # OpenTPS Patient (or test double) + contours: list # Flattened list of ROI contours across all RTStructs + + +class GeometryService: + """Stateless service (one instance can serve many requests). + + The only persistent state lives on disk, mediated by + :class:`GeometryStore`. Instantiate directly for tests, or with a + custom base_dir; the default uses ``{settings.artifact_dir}/geometry``. + + The optional ``adapter`` argument lets tests inject a fake + ``OrthancAdapterBase``. In production it's left None and the service + constructs one lazily via ``build_orthanc_adapter()`` the first time + ``_load`` needs to reach PACS. + """ + + def __init__( + self, + base_dir: Optional[str | Path] = None, + adapter: Optional[Any] = None, + ) -> None: + if base_dir is None: + settings = get_settings() + base_dir = Path(settings.artifact_dir) / "geometry" + self.store = GeometryStore(base_dir) + self._adapter = adapter # None = lazy + + def _get_adapter(self): + if self._adapter is None: + from ..adapters import build_orthanc_adapter # lazy import + self._adapter = build_orthanc_adapter() + return self._adapter + + # ----------------------------------------------------------------- + # Public entry point + # ----------------------------------------------------------------- + + def build(self, request: GeometryBuildRequest) -> GeometryResult: + cache_key = request.compute_cache_key() + + cached = self.store.lookup_by_cache_key(cache_key) + if cached is not None: + logger.info("Geometry cache hit for key %s → %s", cache_key[:10], cached.geometry_id) + return cached + + logger.info("Building geometry (cache miss, key %s)", cache_key[:10]) + loaded = self._load(request) + return self._process(request, loaded, cache_key) + + # ----------------------------------------------------------------- + # DICOM loading. + # + # Two paths, chosen at request time: + # + # 1. PACS path — if the adapter exposes ``can_retrieve_instances()`` + # (i.e. a real Orthanc / DICOMweb backend), we fetch the study + # into a temp dir and point OpenTPS at it. + # 2. Disk path — fall back to the legacy ``load_ct_and_patient`` + # helper which reads from ``opentps_data_root`` (or the request's + # ``data_root_override``). This keeps the existing dev/test + # flows working when Orthanc is mocked or unreachable. + # + # This is also the seam that tests stub out (monkeypatch ``_load`` + # to return a synthetic CT + fake contours). + # ----------------------------------------------------------------- + + def _load(self, request: GeometryBuildRequest) -> _LoadedCT: + # If the caller forced a data_root, honor it — useful for tests + # and one-off debugging against local fixtures even when Orthanc + # is reachable. + if request.data_root_override: + return self._load_from_disk(request.data_root_override) + + # Prefer PACS when available. + try: + adapter = self._get_adapter() + except Exception as exc: # pragma: no cover — defensive + logger.warning( + "Failed to build Orthanc adapter, falling back to disk: %s", exc + ) + return self._load_from_disk(None) + + from .dicom_fetcher import DicomFetcher # lazy + + fetcher = DicomFetcher(adapter) + if not fetcher.can_fetch: + logger.info( + "Orthanc adapter is metadata-only (mock mode); " + "falling back to opentps_data_root." + ) + return self._load_from_disk(None) + + return self._load_from_pacs(fetcher, request) + + # ---- Disk path ---------------------------------------------------- + + @staticmethod + def _load_from_disk(data_root: Optional[str]) -> _LoadedCT: + from ..core.workflows._helpers import load_ct_and_patient # lazy + + ct, patient, rt_structs = load_ct_and_patient(data_root=data_root) + contours: list = [] + for rt in patient.rtStructs if patient and patient.rtStructs else rt_structs: + contours.extend(rt.contours) + return _LoadedCT(ct=ct, patient=patient, contours=contours) + + # ---- PACS path ---------------------------------------------------- + + @staticmethod + def _load_from_pacs(fetcher, request: GeometryBuildRequest) -> _LoadedCT: + """Download the study via ``fetcher`` and hand OpenTPS the temp dir. + + The temp dir is cleaned up before we return — OpenTPS's + ``readData`` reads everything eagerly into memory, so we don't + need the files on disk afterwards. + """ + from opentps.core.io import dataLoader + from opentps.core.data.images import CTImage + from opentps.core.data._rtStruct import RTStruct + + with fetcher.fetch(request.patient_ref) as staged: + data_list = dataLoader.readData(str(staged.directory)) + if not data_list: + raise ValueError( + f"OpenTPS found no readable DICOM in staged dir {staged.directory}" + ) + + ct = None + patient = None + found_rt_structs: list = [] + for item in data_list: + if isinstance(item, CTImage): + ct = item + patient = item.patient + elif isinstance(item, RTStruct): + found_rt_structs.append(item) + + if not ct: + raise ValueError( + "No CTImage was produced from the staged DICOM — " + "the series UID may point at something other than CT." + ) + if not patient: + from opentps.core.data import Patient + patient = Patient(name="Unknown") + + for rt in found_rt_structs: + if rt not in patient.rtStructs: + patient.appendPatientData(rt) + + contours: list = [] + for rt in patient.rtStructs if patient.rtStructs else found_rt_structs: + contours.extend(rt.contours) + return _LoadedCT(ct=ct, patient=patient, contours=contours) + + # ----------------------------------------------------------------- + # Core processing pipeline — tested with stubs. + # ----------------------------------------------------------------- + + def _process( + self, + request: GeometryBuildRequest, + loaded: _LoadedCT, + cache_key: str, + ) -> GeometryResult: + ct = loaded.ct + ct_array = np.asarray(ct.imageArray) + if ct_array.ndim != 3: + raise ValueError(f"CT imageArray must be 3D, got shape {ct_array.shape}") + + src_spacing = tuple(float(s) for s in ct.spacing) + src_origin = tuple(float(o) for o in ct.origin) + src_size = tuple(int(s) for s in ct_array.shape) + + source_grid = GridSpec( + spacing_mm=src_spacing, + origin_mm=src_origin, + size=src_size, + ) + source_grid.affine = source_grid.compute_affine() + src_affine = source_grid.to_numpy_affine() + + # 1. HU → density on the NATIVE grid. + hu_model = get_hu_density_model(request.hu_to_density_model) + density_native = hu_model.convert(ct_array) + + # 2. Pick the target grid. + target_grid = self._resolve_target_grid(request, source_grid) + + # 3. Resample density if target ≠ source. + density_final = self._maybe_resample( + density_native, src_affine, source_grid, target_grid + ) + + # 4. Rasterize contours directly on the target grid. + masks, structure_index = rasterize_contours( + loaded.contours, + target_grid, + structure_name_map=request.structure_name_map, + ) + + # 5. Persist + build the result. + geometry_id = str(uuid.uuid4()) + paths = GeometryPaths.for_id(self.store.base_dir, geometry_id) + ct_meta = self._ct_metadata(ct) + + result = GeometryResult( + geometry_id=geometry_id, + density_grid_uri=str(paths.density), + structure_masks_uri=str(paths.masks), + structure_index=structure_index, + grid_spec=target_grid, + frame_of_reference_uid=self._frame_of_reference(ct), + ct_metadata=ct_meta, + cache_key=cache_key, + ) + + self.store.save( + geometry_id=geometry_id, + cache_key=cache_key, + density=density_final, + masks=masks, + result=result, + ) + logger.info( + "Geometry %s built: grid=%s structures=%s", + geometry_id, + target_grid.size, + list(structure_index.keys()), + ) + return result + + # ----------------------------------------------------------------- + # Helpers + # ----------------------------------------------------------------- + + @staticmethod + def _resolve_target_grid( + request: GeometryBuildRequest, + source: GridSpec, + ) -> GridSpec: + """Fill in any missing target-grid fields from the source grid.""" + if request.grid_spec is None: + # Fast path: inherit everything from the source. + inherited = GridSpec( + spacing_mm=source.spacing_mm, + origin_mm=source.origin_mm, + size=source.size, + ) + inherited.affine = inherited.compute_affine() + return inherited + + # Partial inheritance: user gave spacing but left origin / size + # null → adopt them from the source grid. + spec = request.grid_spec + origin = spec.origin_mm if spec.origin_mm is not None else source.origin_mm + size = spec.size if spec.size is not None else source.size + + completed = GridSpec( + spacing_mm=spec.spacing_mm, + origin_mm=origin, + size=size, + ) + completed.affine = completed.compute_affine() + return completed + + @staticmethod + def _maybe_resample( + density: np.ndarray, + src_affine: np.ndarray, + source: GridSpec, + target: GridSpec, + ) -> np.ndarray: + """Skip resampling if target grid == source grid (the fast path).""" + if ( + source.spacing_mm == target.spacing_mm + and source.origin_mm == target.origin_mm + and source.size == target.size + ): + return density.astype(np.float32, copy=False) + return resample_to_grid(density, src_affine, target, order=1, cval=0.0) + + @staticmethod + def _ct_metadata(ct: Any) -> CTMetadata: + patient = getattr(ct, "patient", None) + patient_name = getattr(patient, "name", None) or getattr(patient, "patientName", None) or "ANONYMOUS" + return CTMetadata( + patient_name=str(patient_name), + modality="CT", + num_slices=int(np.asarray(ct.imageArray).shape[2]), + study_instance_uid=getattr(ct, "studyInstanceUID", None), + series_instance_uid=getattr(ct, "seriesInstanceUID", None), + ) + + @staticmethod + def _frame_of_reference(ct: Any) -> str: + for_uid = ( + getattr(ct, "frameOfReferenceUID", None) + or getattr(ct, "FrameOfReferenceUID", None) + or "" + ) + return str(for_uid) + + +__all__ = ["GeometryService"] diff --git a/src/radiarch/services/hu_density.py b/src/radiarch/services/hu_density.py new file mode 100644 index 0000000..fffe5a9 --- /dev/null +++ b/src/radiarch/services/hu_density.py @@ -0,0 +1,188 @@ +"""Pluggable Hounsfield-unit → mass-density models. + +Three models behind a common ``HUToDensity`` interface: + +* :class:`SchneiderModel` — piecewise-linear Schneider-2000 calibration, + the published reference for proton therapy. Four-segment piecewise-linear + HU → ρ (g/cc): + + (-1000, 0.00121) air + ( -98, 0.93) fat + ( 14, 1.03) soft tissue + ( 23, 1.03) muscle + ( 2000, 2.88) cortical bone (extrapolated) + + Values are interpolated linearly between breakpoints and clamped at + the endpoints. + +* :class:`StoichiometricModel` — wraps the vendored MCsquare CT calibration + (``opentps.core.data.CTCalibrations.MCsquareCalibration``). The MCsquare + calibration is the gold-standard for proton dose in this project; this + model delegates to it for density lookup. + +* :class:`LinearModel` — ρ = max(0, 1 + HU/1000). Used for synthetic / test + modes where we just need a smooth density surrogate. + +Units: HU is dimensionless; density is g/cm³ (a.k.a. g/cc). Downstream dose +engines expect float32. +""" + +from __future__ import annotations + +import abc +from typing import Optional + +import numpy as np + +from ..models.geometry import HUDensityModel + + +class HUToDensity(abc.ABC): + """Abstract base: HU array → density array (g/cc, float32).""" + + #: Symbolic name matching the :class:`HUDensityModel` enum value. + name: str + + @abc.abstractmethod + def convert(self, hu: np.ndarray) -> np.ndarray: + """Return a float32 density array the same shape as ``hu``.""" + ... + + def __call__(self, hu: np.ndarray) -> np.ndarray: + return self.convert(hu) + + +# --------------------------------------------------------------------------- +# Schneider (2000) +# --------------------------------------------------------------------------- + +# Published breakpoints — (HU, density g/cc). Extended at both ends so we +# can clamp without a separate branch. +_SCHNEIDER_BREAKPOINTS = np.array( + [ + [-1000.0, 0.00121], # air + [-98.0, 0.93], # adipose / fat + [14.0, 1.03], # soft tissue + [23.0, 1.03], # muscle plateau + [100.0, 1.065], # connective tissue + [400.0, 1.21], # trabecular bone + [1000.0, 1.60], # low-density cortical bone + [2000.0, 2.88], # dense cortical bone (extrapolation cap) + ], + dtype=np.float64, +) + + +class SchneiderModel(HUToDensity): + """Piecewise-linear HU → density per Schneider et al. 2000.""" + + name = HUDensityModel.schneider.value + + def __init__(self) -> None: + self._hu = _SCHNEIDER_BREAKPOINTS[:, 0] + self._rho = _SCHNEIDER_BREAKPOINTS[:, 1] + + def convert(self, hu: np.ndarray) -> np.ndarray: + rho = np.interp(hu, self._hu, self._rho).astype(np.float32, copy=False) + # np.interp already clamps at the endpoints, but guard against + # negative densities from upstream HU corruption. + np.clip(rho, 0.0, None, out=rho) + return rho + + +# --------------------------------------------------------------------------- +# Stoichiometric (MCsquare-backed) +# --------------------------------------------------------------------------- + +class StoichiometricModel(HUToDensity): + """HU → density via the vendored MCsquare CT calibration. + + The MCsquare calibration stores an HU→density table derived from tissue + stoichiometry. We defer the actual math to OpenTPS's + ``MCsquareCTCalibration.convertHU2MassDensity``; this class is a thin, + cache-friendly adapter so the rest of the service stays decoupled from + OpenTPS internals. + + Constructed lazily: loading the calibration pulls in the vendored + MCsquare package, which we don't want to import at module-load time + when callers only need the Linear/Schneider variants. + """ + + name = HUDensityModel.stoichiometric.value + + def __init__(self, calibration: Optional[object] = None) -> None: + # ``calibration`` is an MCsquareCTCalibration; typed loosely to + # avoid importing opentps here. Loaded on demand if not supplied. + self._calibration = calibration + + def _ensure_calibration(self): + if self._calibration is None: + # Deferred import — keeps module-load cheap and lets us unit-test + # Schneider/Linear without the opentps tree on sys.path. + from ..core.workflows._helpers import setup_calibration + + calibration, _ = setup_calibration() + self._calibration = calibration + return self._calibration + + def convert(self, hu: np.ndarray) -> np.ndarray: + cal = self._ensure_calibration() + # MCsquareCTCalibration.convertHU2MassDensity accepts either a + # scalar or an array; returns the same shape. + rho = np.asarray(cal.convertHU2MassDensity(hu), dtype=np.float32) + np.clip(rho, 0.0, None, out=rho) + return rho + + +# --------------------------------------------------------------------------- +# Linear (trivial surrogate) +# --------------------------------------------------------------------------- + +class LinearModel(HUToDensity): + """ρ = max(0, 1 + HU/1000). Smooth, monotone, zero deps.""" + + name = HUDensityModel.linear.value + + def convert(self, hu: np.ndarray) -> np.ndarray: + hu_arr = np.asarray(hu, dtype=np.float32) + rho = 1.0 + hu_arr / 1000.0 + np.clip(rho, 0.0, None, out=rho) + return rho + + +# --------------------------------------------------------------------------- +# Factory +# --------------------------------------------------------------------------- + +def get_model(model: HUDensityModel | str) -> HUToDensity: + """Return a fresh instance of the requested HU→density model. + + Accepts the enum value or the string form for convenience. + """ + if isinstance(model, str): + try: + model = HUDensityModel(model) + except ValueError as exc: + raise ValueError( + f"Unknown HU density model {model!r}. " + f"Valid: {[m.value for m in HUDensityModel]}" + ) from exc + + if model is HUDensityModel.schneider: + return SchneiderModel() + if model is HUDensityModel.stoichiometric: + return StoichiometricModel() + if model is HUDensityModel.linear: + return LinearModel() + + # Exhaustive; mypy would flag if a new enum value is added. + raise ValueError(f"Unhandled HU density model: {model!r}") + + +__all__ = [ + "HUToDensity", + "SchneiderModel", + "StoichiometricModel", + "LinearModel", + "get_model", +] diff --git a/src/radiarch/services/persistence.py b/src/radiarch/services/persistence.py new file mode 100644 index 0000000..625e488 --- /dev/null +++ b/src/radiarch/services/persistence.py @@ -0,0 +1,241 @@ +"""On-disk persistence for Geometry Service outputs. + +Layout on disk under ``{artifact_dir}/geometry/``:: + + geometry/ + _index.json # cache_key → geometry_id lookup + {geometry_id}/ + density.nii.gz + masks.nii.gz + meta.json # full GeometryResult minus the URIs + +NIfTI I/O goes through SimpleITK (already a project dep). We write +density as float32 and masks as uint16; the target GridSpec's +axis-aligned affine is mapped to ITK's ``SetSpacing`` + ``SetOrigin`` + +``SetDirection`` (identity direction only in v1 — rotational affines +are rejected upstream in the resampler). + +Atomicity +--------- +Outputs are written under ``{geometry_id}/.tmp/`` first, then renamed to +their final names in a single step. A crash mid-write leaves the .tmp +directory behind (cleaned up on the next cache-miss for the same key) +but never poisons the cache with half-written files. + +The cache index is a plain JSON file. For v1 (synchronous mode) this is +simpler than introducing an Alembic migration and a Geometry table; the +async-mode PR will migrate reads to the DB. +""" + +from __future__ import annotations + +import json +import os +import shutil +import tempfile +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Optional, Tuple + +import numpy as np + +from ..models.geometry import GeometryResult, GridSpec + + +DENSITY_FILENAME = "density.nii.gz" +MASKS_FILENAME = "masks.nii.gz" +META_FILENAME = "meta.json" +INDEX_FILENAME = "_index.json" + + +# --------------------------------------------------------------------------- +# GridSpec ↔ SimpleITK +# --------------------------------------------------------------------------- + +def _apply_gridspec_to_itk(image, grid: GridSpec) -> None: + """Stamp an ``sitk.Image`` with spacing/origin/direction from ``grid``. + + NB: v1 is axis-aligned, so the direction cosine matrix is the 3×3 + identity. Rotational affines are rejected in ``resample_to_grid``. + """ + image.SetSpacing(tuple(float(s) for s in grid.spacing_mm)) + if grid.origin_mm is not None: + image.SetOrigin(tuple(float(o) for o in grid.origin_mm)) + image.SetDirection((1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)) + + +def _write_nifti(volume: np.ndarray, grid: GridSpec, path: Path) -> None: + """Write a 3D numpy array as a NIfTI file at ``path``. + + SimpleITK expects arrays in (z, y, x) order; we store volumes in + (i, j, k) = (x, y, z) order (matches OpenTPS and DICOM), so we + transpose on the way out and on the way back in (see ``_read_nifti``). + """ + import SimpleITK as sitk # local import → keeps module-load light + + # Transpose (i,j,k)=(x,y,z) → (z,y,x) for SimpleITK. + arr = np.asarray(volume) + image = sitk.GetImageFromArray(np.transpose(arr, (2, 1, 0))) + _apply_gridspec_to_itk(image, grid) + # SimpleITK handles .nii.gz transparently based on extension. + sitk.WriteImage(image, str(path)) + + +def _read_nifti(path: Path) -> Tuple[np.ndarray, GridSpec]: + """Read a NIfTI file; return ``(array_in_ijk_order, GridSpec)``. + + The inverse of :func:`_write_nifti`. Primarily used in tests / + downstream services that want to round-trip the geometry outputs. + """ + import SimpleITK as sitk + + image = sitk.ReadImage(str(path)) + arr = sitk.GetArrayFromImage(image) # (z, y, x) + arr = np.transpose(arr, (2, 1, 0)) # → (x, y, z) + + spacing = tuple(float(s) for s in image.GetSpacing()) + origin = tuple(float(o) for o in image.GetOrigin()) + size = tuple(int(s) for s in arr.shape) + + spec = GridSpec(spacing_mm=spacing, origin_mm=origin, size=size) + spec.affine = spec.compute_affine() + return arr, spec + + +# --------------------------------------------------------------------------- +# On-disk geometry store (synchronous-mode cache) +# --------------------------------------------------------------------------- + +@dataclass +class GeometryPaths: + """Convenience bundle of the final on-disk paths for one geometry.""" + + root: Path + density: Path + masks: Path + meta: Path + + @classmethod + def for_id(cls, base_dir: Path, geometry_id: str) -> "GeometryPaths": + root = base_dir / geometry_id + return cls( + root=root, + density=root / DENSITY_FILENAME, + masks=root / MASKS_FILENAME, + meta=root / META_FILENAME, + ) + + +class GeometryStore: + """File-backed geometry persistence with a JSON cache index. + + Not thread-safe across processes (fine for synchronous mode — the + async-mode store will layer a DB row with a unique index on cache_key + to serialize concurrent builds). + """ + + def __init__(self, base_dir: str | os.PathLike[str]) -> None: + self.base_dir = Path(base_dir).resolve() + self.base_dir.mkdir(parents=True, exist_ok=True) + + # ---- cache index -------------------------------------------------- + + @property + def _index_path(self) -> Path: + return self.base_dir / INDEX_FILENAME + + def _load_index(self) -> Dict[str, str]: + if not self._index_path.exists(): + return {} + try: + return json.loads(self._index_path.read_text()) + except (OSError, json.JSONDecodeError): + # A corrupt index shouldn't wedge the whole service; treat + # it as empty and overwrite on the next successful write. + return {} + + def _save_index(self, index: Dict[str, str]) -> None: + tmp = self._index_path.with_suffix(".json.tmp") + tmp.write_text(json.dumps(index, indent=2, sort_keys=True)) + os.replace(tmp, self._index_path) + + def lookup_by_cache_key(self, cache_key: str) -> Optional[GeometryResult]: + """Return the cached ``GeometryResult`` for ``cache_key``, or None.""" + index = self._load_index() + geometry_id = index.get(cache_key) + if not geometry_id: + return None + return self.get_by_id(geometry_id) + + def get_by_id(self, geometry_id: str) -> Optional[GeometryResult]: + paths = GeometryPaths.for_id(self.base_dir, geometry_id) + if not paths.meta.exists(): + return None + try: + data = json.loads(paths.meta.read_text()) + except (OSError, json.JSONDecodeError): + return None + return GeometryResult.model_validate(data) + + # ---- writes ------------------------------------------------------- + + def save( + self, + *, + geometry_id: str, + cache_key: str, + density: np.ndarray, + masks: np.ndarray, + result: GeometryResult, + ) -> GeometryPaths: + """Write density + masks + meta atomically, then update the cache index.""" + paths = GeometryPaths.for_id(self.base_dir, geometry_id) + # Use a sibling .tmp dir so the final rename is on the same fs. + with tempfile.TemporaryDirectory(dir=self.base_dir, prefix=f".{geometry_id}.tmp.") as tmp: + tmp_path = Path(tmp) + tmp_density = tmp_path / DENSITY_FILENAME + tmp_masks = tmp_path / MASKS_FILENAME + tmp_meta = tmp_path / META_FILENAME + + _write_nifti(density.astype(np.float32, copy=False), result.grid_spec, tmp_density) + _write_nifti(masks.astype(np.uint16, copy=False), result.grid_spec, tmp_masks) + tmp_meta.write_text(result.model_dump_json(indent=2)) + + # Atomic replace: if paths.root already exists (retry of a + # failed build), nuke it first. + if paths.root.exists(): + shutil.rmtree(paths.root) + os.replace(tmp_path, paths.root) + # ``tmp_path`` has moved to ``paths.root``; replace the context + # manager's now-invalid reference with a fresh tempdir so the + # __exit__ cleanup becomes a no-op. + os.makedirs(tmp_path, exist_ok=True) + + # Update the cache index last — a crash between file-write and + # index-update leaves orphan geometry dirs (harmless) instead of + # dangling index entries (harmful). + index = self._load_index() + index[cache_key] = geometry_id + self._save_index(index) + return paths + + # ---- debugging helpers ------------------------------------------- + + def list_ids(self) -> list[str]: + if not self.base_dir.exists(): + return [] + return sorted( + p.name + for p in self.base_dir.iterdir() + if p.is_dir() and not p.name.startswith(".") and (p / META_FILENAME).exists() + ) + + +__all__ = [ + "DENSITY_FILENAME", + "MASKS_FILENAME", + "META_FILENAME", + "INDEX_FILENAME", + "GeometryPaths", + "GeometryStore", +] diff --git a/src/radiarch/services/rasterization.py b/src/radiarch/services/rasterization.py new file mode 100644 index 0000000..b877ede --- /dev/null +++ b/src/radiarch/services/rasterization.py @@ -0,0 +1,210 @@ +"""Contour → multi-label voxel mask rasterization. + +Given a list of RTSTRUCT contours (anything with a ``.name`` and a +``.getBinaryMask(origin, gridSize, spacing)`` method — the OpenTPS +``ROIContour`` duck-types cleanly), this module produces a single +uint16 multi-label volume on a user-supplied :class:`GridSpec`. + +Design +------ +* **One label per canonical name.** Aliases are collapsed first, so if + the user passes ``structure_name_map={"PTV": ["PTV_60", "PTV60"]}`` + and the RTSTRUCT contains both ``PTV_60`` and ``PTV60``, we pick the + first contour that hits and discard the rest with a logger.debug + breadcrumb — rather than silently over-painting. +* **Deterministic label ordering.** Targets (PTV/GTV/CTV) always receive + the lowest label indices. Everything else sorts alphabetically. This + keeps mask files reproducible across builds and across machines. +* **First-match-wins overlap policy.** When two canonical contours + overlap in a voxel, the lower-label one (earlier in the ordering) + keeps the voxel. OAR rules generally want targets to dominate + overlapping OARs in the mask, and our ordering guarantees that. +* **0 is reserved for background.** Labels start at 1. +""" + +from __future__ import annotations + +from typing import Dict, Iterable, Mapping, Optional, Protocol, Tuple + +import numpy as np +from loguru import logger + +from ..models.geometry import GridSpec + + +# --------------------------------------------------------------------------- +# Contour protocol — any object matching this shape works. +# --------------------------------------------------------------------------- + +class ContourLike(Protocol): + """Minimum shape we need from an RT contour. + + OpenTPS's ``ROIContour`` satisfies this; fakes in tests can too. + """ + + name: str + + def getBinaryMask( + self, + origin: Tuple[float, float, float], + gridSize: Tuple[int, int, int], + spacing: Tuple[float, float, float], + ): # -> ROIMask with .imageArray, but we only call .imageArray on it. + ... + + +# --------------------------------------------------------------------------- +# Alias resolution +# --------------------------------------------------------------------------- + +#: Canonical target-volume names. Case-insensitive; any contour matching +#: one of these (directly or via alias) sorts to the front of the label +#: ordering. +TARGET_CANONICAL_NAMES = ("PTV", "GTV", "CTV") + + +def _normalize(name: str) -> str: + """Lowercase + strip — the invariant we compare on.""" + return name.strip().lower() + + +def _build_alias_lookup( + name_map: Optional[Mapping[str, Iterable[str]]], +) -> Dict[str, str]: + """Flatten ``{canonical: [aliases]}`` → ``{lowercase_name: canonical}``. + + Canonical names are returned in their original case so the output + ``structure_index`` reads naturally. Both the canonical name itself + and each alias become keys (lowercased). + """ + lookup: Dict[str, str] = {} + if not name_map: + return lookup + for canonical, aliases in name_map.items(): + lookup[_normalize(canonical)] = canonical + for alias in aliases: + lookup[_normalize(alias)] = canonical + return lookup + + +def _resolve_canonical(contour_name: str, lookup: Dict[str, str]) -> str: + """Alias lookup; fall back to the raw name if no mapping applies.""" + return lookup.get(_normalize(contour_name), contour_name) + + +# --------------------------------------------------------------------------- +# Label ordering +# --------------------------------------------------------------------------- + +def _target_priority(name: str) -> int: + """PTV=0, GTV=1, CTV=2, everything else = 100. Lower sorts first.""" + upper = name.strip().upper() + for idx, target in enumerate(TARGET_CANONICAL_NAMES): + # Match either exact target or "TARGET_xxx" prefix so PTV_60 still + # reads as a target even if the alias map didn't collapse it. + if upper == target or upper.startswith(f"{target}_") or upper.startswith(f"{target}-"): + return idx + return 100 + + +def _order_canonicals(names: Iterable[str]) -> list[str]: + """Targets first (by canonical order), then alphabetical.""" + return sorted(set(names), key=lambda n: (_target_priority(n), n.lower())) + + +# --------------------------------------------------------------------------- +# Rasterization +# --------------------------------------------------------------------------- + +def rasterize_contours( + contours: Iterable[ContourLike], + target: GridSpec, + *, + structure_name_map: Optional[Mapping[str, Iterable[str]]] = None, +) -> Tuple[np.ndarray, Dict[str, int]]: + """Pack ``contours`` into a single uint16 multi-label volume. + + Parameters + ---------- + contours + Iterable of RTSTRUCT contours. Each must have ``.name`` and + ``.getBinaryMask(origin, gridSize, spacing)``. + target + Fully-specified :class:`GridSpec` (``origin_mm`` and ``size`` set). + The mask volume is rendered on this grid directly — we do not + resample after rasterization, to avoid label corruption. + structure_name_map + Optional alias map: ``{canonical: [aliases]}``. + + Returns + ------- + (mask_volume, structure_index) + ``mask_volume`` is a uint16 array of shape ``target.size`` with + 0 as background. ``structure_index`` is ``{canonical_name: label}``. + """ + if not target.is_fully_specified(): + raise ValueError( + "rasterize_contours requires a fully specified target GridSpec" + ) + + origin = tuple(float(x) for x in target.origin_mm) # type: ignore[arg-type] + grid_size = tuple(int(x) for x in target.size) # type: ignore[arg-type] + spacing = tuple(float(x) for x in target.spacing_mm) + + alias_lookup = _build_alias_lookup(structure_name_map) + + # First pass: group contours by canonical name (first-match-wins + # within a canonical group). ``masks_by_canonical`` maps canonical + # name → the *first* rasterized boolean mask we accept for it. + masks_by_canonical: Dict[str, np.ndarray] = {} + for contour in contours: + canonical = _resolve_canonical(contour.name, alias_lookup) + if canonical in masks_by_canonical: + logger.debug( + "Skipping duplicate contour %r → canonical %r (already rasterized)", + contour.name, + canonical, + ) + continue + try: + mask_obj = contour.getBinaryMask(origin, grid_size, spacing) + except Exception as exc: # pragma: no cover — logged, not fatal. + logger.warning( + "Rasterization failed for contour %r (canonical %r): %s", + contour.name, + canonical, + exc, + ) + continue + mask = np.asarray(mask_obj.imageArray, dtype=bool) + if mask.shape != grid_size: + logger.warning( + "Contour %r mask shape %s does not match target %s — skipping", + contour.name, + mask.shape, + grid_size, + ) + continue + masks_by_canonical[canonical] = mask + + # Second pass: assign labels deterministically and pack. + ordered = _order_canonicals(masks_by_canonical) + mask_volume = np.zeros(grid_size, dtype=np.uint16) + structure_index: Dict[str, int] = {} + + for label, canonical in enumerate(ordered, start=1): + m = masks_by_canonical[canonical] + # First-match-wins at voxel level: only paint where current + # label is still 0 (unclaimed). + paint = m & (mask_volume == 0) + mask_volume[paint] = label + structure_index[canonical] = label + + return mask_volume, structure_index + + +__all__ = [ + "ContourLike", + "TARGET_CANONICAL_NAMES", + "rasterize_contours", +] diff --git a/src/radiarch/services/resampling.py b/src/radiarch/services/resampling.py new file mode 100644 index 0000000..d6e6203 --- /dev/null +++ b/src/radiarch/services/resampling.py @@ -0,0 +1,157 @@ +"""Grid resampling between axis-aligned voxel grids. + +The Geometry Service frequently needs to map a source volume (defined on +the native CT grid with some affine ``A_src``) onto a user-specified +target grid (defined by a :class:`GridSpec`). This module provides one +function that does exactly that, correctly and in a single pass. + +Design notes +------------ +* **Axis-aligned only.** v1 assumes the CT's patient-LPS axes are aligned + with the target grid's axes. The affines are diagonal apart from the + translation component. We deliberately reject rotated affines here + rather than silently resampling them incorrectly — rotational + resampling is a separate, more expensive operation we'll add when we + actually encounter a rotated scanner. +* **Interpolation order.** Order 1 (trilinear) is correct for continuous + fields (density, dose). Order 0 (nearest-neighbor) is correct for + label volumes (masks, segmentations); anything higher than 0 corrupts + integer labels. +* **Out-of-bounds.** Fill with a caller-supplied constant (0.0 by default + for densities, 0 for masks — 0 = "outside / background" for both). +""" + +from __future__ import annotations + +from typing import Tuple + +import numpy as np +from scipy.ndimage import map_coordinates + +from ..models.geometry import GridSpec + + +def _extract_diagonal_affine(affine: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Return (spacing, origin) from a 4×4 affine; raise if not axis-aligned. + + Tolerates tiny floating-point noise in the off-diagonal entries + (<= 1e-9) — DICOM ImageOrientationPatient is often not bit-exactly + identity even on unrotated scans. + """ + affine = np.asarray(affine, dtype=np.float64) + if affine.shape != (4, 4): + raise ValueError(f"Expected 4×4 affine, got shape {affine.shape}") + + rot = affine[:3, :3] + off_diag = rot - np.diag(np.diagonal(rot)) + if np.any(np.abs(off_diag) > 1e-6): + raise NotImplementedError( + "Rotated affines are not supported in v1 of the Geometry Service. " + f"Off-diagonal magnitude: {np.abs(off_diag).max():.3g}" + ) + + spacing = np.diagonal(rot).copy() + if np.any(spacing <= 0): + raise ValueError(f"Non-positive spacing on affine diagonal: {spacing}") + origin = affine[:3, 3].copy() + return spacing, origin + + +def resample_to_grid( + volume: np.ndarray, + src_affine: np.ndarray, + target: GridSpec, + *, + order: int = 1, + cval: float = 0.0, +) -> np.ndarray: + """Resample ``volume`` onto ``target`` and return the resampled array. + + Parameters + ---------- + volume + Source volume, shape ``(ni, nj, nk)``. Dtype is preserved for + integer order=0 (labels) and cast to float32 for order >= 1 + (continuous fields). + src_affine + 4×4 voxel-index → patient-LPS affine of ``volume``. Must be + axis-aligned; rotational affines raise. + target + Fully-specified :class:`GridSpec` (must have non-None ``origin_mm`` + and ``size``). + order + Interpolation order. ``1`` (trilinear) for density/dose, ``0`` + (nearest-neighbor) for masks/labels. Values > 1 are valid but + rarely useful here. + cval + Fill value for out-of-bounds voxels (outside the source extent). + + Returns + ------- + numpy.ndarray + Volume on the target grid, shape == ``target.size``. + """ + if not target.is_fully_specified(): + raise ValueError( + "resample_to_grid requires a fully specified target GridSpec " + "(origin_mm and size must both be set)" + ) + + volume = np.asarray(volume) + if volume.ndim != 3: + raise ValueError(f"volume must be 3D, got shape {volume.shape}") + + src_spacing, src_origin = _extract_diagonal_affine(np.asarray(src_affine)) + tgt_spacing = np.asarray(target.spacing_mm, dtype=np.float64) + tgt_origin = np.asarray(target.origin_mm, dtype=np.float64) + tgt_size = np.asarray(target.size, dtype=np.int64) + + # Build the target-voxel-index grid (i, j, k). + ii, jj, kk = np.meshgrid( + np.arange(tgt_size[0]), + np.arange(tgt_size[1]), + np.arange(tgt_size[2]), + indexing="ij", + ) + + # Target voxel index → patient-LPS mm → source voxel index. + # Because both affines are diagonal, this collapses to per-axis + # arithmetic: src_idx = (tgt_origin + tgt_idx * tgt_spacing - src_origin) / src_spacing + src_i = (tgt_origin[0] + ii * tgt_spacing[0] - src_origin[0]) / src_spacing[0] + src_j = (tgt_origin[1] + jj * tgt_spacing[1] - src_origin[1]) / src_spacing[1] + src_k = (tgt_origin[2] + kk * tgt_spacing[2] - src_origin[2]) / src_spacing[2] + + coords = np.stack([src_i, src_j, src_k], axis=0) + + # Preserve integer dtype for labels; cast continuous fields to float32. + if order == 0 and np.issubdtype(volume.dtype, np.integer): + work = volume + out_dtype = volume.dtype + else: + work = volume.astype(np.float32, copy=False) + out_dtype = np.float32 + + resampled = map_coordinates( + work, coords, order=order, mode="constant", cval=cval, prefilter=(order > 1) + ) + return resampled.astype(out_dtype, copy=False) + + +def identity_grid_from_affine(affine: np.ndarray, size: Tuple[int, int, int]) -> GridSpec: + """Build the :class:`GridSpec` that matches ``affine`` + ``size`` exactly. + + Convenience: lets callers reuse the source CT grid as the target + GridSpec when no explicit grid is requested (the "null grid_spec" + fast path in ``GeometryBuildRequest``). + """ + spacing, origin = _extract_diagonal_affine(np.asarray(affine)) + spec = GridSpec( + spacing_mm=tuple(float(s) for s in spacing), + origin_mm=tuple(float(o) for o in origin), + size=tuple(int(s) for s in size), + ) + spec.affine = spec.compute_affine() + return spec + + +__all__ = ["resample_to_grid", "identity_grid_from_affine"] diff --git a/tests/services/__init__.py b/tests/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/services/test_dicom_fetcher.py b/tests/services/test_dicom_fetcher.py new file mode 100644 index 0000000..d1e7c8e --- /dev/null +++ b/tests/services/test_dicom_fetcher.py @@ -0,0 +1,325 @@ +"""Unit tests for radiarch.services.dicom_fetcher. + +We build a fake ``OrthancAdapterBase`` that yields synthesized pydicom +Datasets — enough structure to serialize to disk and be re-read, but no +pixel data. That keeps the tests fast and self-contained. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional + +import pydicom +import pytest +from pydicom.dataset import Dataset, FileDataset +from pydicom.uid import ExplicitVRLittleEndian, generate_uid + +from radiarch.adapters.orthanc import OrthancAdapterBase +from radiarch.models.geometry import PatientRef +from radiarch.services.dicom_fetcher import ( + DicomFetcher, + DicomFetcherError, + StagedDicom, +) + + +# --------------------------------------------------------------------------- +# Minimal synthetic DICOM instance +# --------------------------------------------------------------------------- + +def _make_instance( + study_uid: str, + series_uid: str, + modality: str, + sop_uid: Optional[str] = None, +) -> Dataset: + """Build a minimally-valid pydicom Dataset that can be written to disk.""" + sop_uid = sop_uid or generate_uid() + file_meta = Dataset() + file_meta.MediaStorageSOPClassUID = "1.2.840.10008.5.1.4.1.1.2" # CT Image Storage + file_meta.MediaStorageSOPInstanceUID = sop_uid + file_meta.TransferSyntaxUID = ExplicitVRLittleEndian + + ds = FileDataset("", {}, file_meta=file_meta, preamble=b"\x00" * 128) + ds.StudyInstanceUID = study_uid + ds.SeriesInstanceUID = series_uid + ds.SOPInstanceUID = sop_uid + ds.Modality = modality + ds.PatientName = "TEST^PATIENT" + ds.PatientID = "PID_001" + ds.is_little_endian = True + ds.is_implicit_VR = False + return ds + + +# --------------------------------------------------------------------------- +# Fake adapter +# --------------------------------------------------------------------------- + +class FakePacsAdapter(OrthancAdapterBase): + """Serves in-memory synthetic DICOM datasets keyed by study/series.""" + + def __init__( + self, + series_metadata: List[Dict[str, Any]], + series_instances: Dict[str, List[Dataset]], + *, + can_retrieve: bool = True, + ) -> None: + self._series_metadata = series_metadata + self._series_instances = series_instances + self._can_retrieve = can_retrieve + self.retrieve_calls: List[str] = [] + + def can_retrieve_instances(self) -> bool: + return self._can_retrieve + + def search_for_series( + self, study_instance_uid: str, modality: Optional[str] = None + ) -> List[Dict[str, Any]]: + out = [s for s in self._series_metadata if s["_study"] == study_instance_uid] + if modality: + out = [s for s in out if s["Modality"] == modality] + return out + + def retrieve_series( + self, study_instance_uid: str, series_instance_uid: str + ) -> Iterable[Dataset]: + self.retrieve_calls.append(series_instance_uid) + return iter(self._series_instances.get(series_instance_uid, [])) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +STUDY_UID = "1.2.840.radiarch.test.study.1" +CT_A_UID = "1.2.840.radiarch.test.series.ct_a" # 10 instances +CT_B_UID = "1.2.840.radiarch.test.series.ct_b" # 3 instances (smaller) +RT_UID = "1.2.840.radiarch.test.series.rt" + +@pytest.fixture +def fake_adapter() -> FakePacsAdapter: + # Two CT series (to exercise "largest wins") + one RTSTRUCT. + metadata = [ + { + "_study": STUDY_UID, + "SeriesInstanceUID": CT_A_UID, + "Modality": "CT", + "NumberOfSeriesRelatedInstances": 10, + }, + { + "_study": STUDY_UID, + "SeriesInstanceUID": CT_B_UID, + "Modality": "CT", + "NumberOfSeriesRelatedInstances": 3, + }, + { + "_study": STUDY_UID, + "SeriesInstanceUID": RT_UID, + "Modality": "RTSTRUCT", + "NumberOfSeriesRelatedInstances": 1, + }, + ] + instances = { + CT_A_UID: [_make_instance(STUDY_UID, CT_A_UID, "CT") for _ in range(10)], + CT_B_UID: [_make_instance(STUDY_UID, CT_B_UID, "CT") for _ in range(3)], + RT_UID: [_make_instance(STUDY_UID, RT_UID, "RTSTRUCT")], + } + return FakePacsAdapter(metadata, instances) + + +# --------------------------------------------------------------------------- +# Capability probe +# --------------------------------------------------------------------------- + +class TestCapability: + def test_can_fetch_passes_through_from_adapter(self) -> None: + adapter = FakePacsAdapter([], {}, can_retrieve=True) + assert DicomFetcher(adapter).can_fetch is True + + def test_cannot_fetch_when_adapter_is_metadata_only(self) -> None: + adapter = FakePacsAdapter([], {}, can_retrieve=False) + assert DicomFetcher(adapter).can_fetch is False + + def test_fetch_raises_when_adapter_cannot_serve_bytes(self) -> None: + adapter = FakePacsAdapter([], {}, can_retrieve=False) + fetcher = DicomFetcher(adapter) + with pytest.raises(DicomFetcherError, match="does not support"): + fetcher.fetch(PatientRef(dicom_study_uid=STUDY_UID)) + + +# --------------------------------------------------------------------------- +# Series selection +# --------------------------------------------------------------------------- + +class TestSeriesSelection: + def test_picks_largest_ct_series_when_uid_not_given( + self, fake_adapter: FakePacsAdapter + ) -> None: + fetcher = DicomFetcher(fake_adapter) + with fetcher.fetch(PatientRef(dicom_study_uid=STUDY_UID)) as staged: + # CT_A has 10 instances → beats CT_B's 3. + assert staged.ct_series_uid == CT_A_UID + # And the fetcher actually downloaded from CT_A, not CT_B. + assert CT_A_UID in fake_adapter.retrieve_calls + assert CT_B_UID not in fake_adapter.retrieve_calls + + def test_honors_explicit_ct_series_uid( + self, fake_adapter: FakePacsAdapter + ) -> None: + fetcher = DicomFetcher(fake_adapter) + # Force the smaller series even though the largest-wins heuristic + # would have picked CT_A. + with fetcher.fetch( + PatientRef(dicom_study_uid=STUDY_UID, ct_series_uid=CT_B_UID) + ) as staged: + assert staged.ct_series_uid == CT_B_UID + assert CT_B_UID in fake_adapter.retrieve_calls + + def test_honors_explicit_rtstruct_uid( + self, fake_adapter: FakePacsAdapter + ) -> None: + fetcher = DicomFetcher(fake_adapter) + with fetcher.fetch( + PatientRef(dicom_study_uid=STUDY_UID, rtstruct_uid=RT_UID) + ) as staged: + assert staged.rtstruct_series_uid == RT_UID + + def test_auto_detects_rtstruct_when_uid_not_given( + self, fake_adapter: FakePacsAdapter + ) -> None: + fetcher = DicomFetcher(fake_adapter) + with fetcher.fetch(PatientRef(dicom_study_uid=STUDY_UID)) as staged: + assert staged.rtstruct_series_uid == RT_UID + + def test_rtstruct_is_none_when_study_has_no_rtstruct(self) -> None: + # Only a CT series, no RTSTRUCT. + metadata = [ + {"_study": STUDY_UID, "SeriesInstanceUID": CT_A_UID, "Modality": "CT", + "NumberOfSeriesRelatedInstances": 2}, + ] + instances = {CT_A_UID: [_make_instance(STUDY_UID, CT_A_UID, "CT") for _ in range(2)]} + adapter = FakePacsAdapter(metadata, instances) + with DicomFetcher(adapter).fetch(PatientRef(dicom_study_uid=STUDY_UID)) as staged: + assert staged.rtstruct_series_uid is None + + def test_deterministic_tiebreak_on_equal_size(self) -> None: + # Two CT series with the same instance count — sort must pick + # the same one deterministically across runs. + metadata = [ + {"_study": STUDY_UID, "SeriesInstanceUID": "uid.zzz", "Modality": "CT", + "NumberOfSeriesRelatedInstances": 5}, + {"_study": STUDY_UID, "SeriesInstanceUID": "uid.aaa", "Modality": "CT", + "NumberOfSeriesRelatedInstances": 5}, + ] + instances = { + "uid.zzz": [_make_instance(STUDY_UID, "uid.zzz", "CT") for _ in range(5)], + "uid.aaa": [_make_instance(STUDY_UID, "uid.aaa", "CT") for _ in range(5)], + } + adapter = FakePacsAdapter(metadata, instances) + fetcher = DicomFetcher(adapter) + # Run twice; both runs must pick the same UID (lexicographically + # smaller as tiebreak). + with fetcher.fetch(PatientRef(dicom_study_uid=STUDY_UID)) as s1: + first = s1.ct_series_uid + with fetcher.fetch(PatientRef(dicom_study_uid=STUDY_UID)) as s2: + assert s2.ct_series_uid == first == "uid.aaa" + + +# --------------------------------------------------------------------------- +# Staging + cleanup +# --------------------------------------------------------------------------- + +class TestStaging: + def test_writes_all_instances_to_temp_dir( + self, fake_adapter: FakePacsAdapter + ) -> None: + fetcher = DicomFetcher(fake_adapter) + with fetcher.fetch(PatientRef(dicom_study_uid=STUDY_UID)) as staged: + files = sorted(staged.directory.iterdir()) + # 10 CTs + 1 RTSTRUCT + assert len(files) == 11 + assert staged.files_written == 11 + # Files are readable as DICOM. + for f in files: + ds = pydicom.dcmread(str(f)) + assert ds.StudyInstanceUID == STUDY_UID + + def test_context_manager_cleans_up_on_success( + self, fake_adapter: FakePacsAdapter + ) -> None: + fetcher = DicomFetcher(fake_adapter) + with fetcher.fetch(PatientRef(dicom_study_uid=STUDY_UID)) as staged: + directory = staged.directory + assert directory.exists() + assert not directory.exists(), "temp dir must be wiped on __exit__" + + def test_context_manager_cleans_up_on_exception( + self, fake_adapter: FakePacsAdapter + ) -> None: + fetcher = DicomFetcher(fake_adapter) + captured: Optional[Path] = None + with pytest.raises(RuntimeError, match="boom"): + with fetcher.fetch(PatientRef(dicom_study_uid=STUDY_UID)) as staged: + captured = staged.directory + raise RuntimeError("boom") + assert captured is not None + assert not captured.exists() + + def test_explicit_cleanup_is_idempotent( + self, fake_adapter: FakePacsAdapter + ) -> None: + fetcher = DicomFetcher(fake_adapter) + staged = fetcher.fetch(PatientRef(dicom_study_uid=STUDY_UID)) + staged.cleanup() + staged.cleanup() # second call must not raise + assert not staged.directory.exists() + + +# --------------------------------------------------------------------------- +# Error surfaces +# --------------------------------------------------------------------------- + +class TestErrors: + def test_raises_when_study_has_no_ct_series(self) -> None: + # RTSTRUCT only, no CT → clear error, not silent fallback. + metadata = [ + {"_study": STUDY_UID, "SeriesInstanceUID": RT_UID, "Modality": "RTSTRUCT", + "NumberOfSeriesRelatedInstances": 1}, + ] + adapter = FakePacsAdapter(metadata, {}) + with pytest.raises(DicomFetcherError, match="No CT series"): + DicomFetcher(adapter).fetch(PatientRef(dicom_study_uid=STUDY_UID)) + + def test_raises_when_series_produces_zero_instances(self) -> None: + # Metadata says there's a CT series, but retrieve_series yields nothing. + metadata = [ + {"_study": STUDY_UID, "SeriesInstanceUID": CT_A_UID, "Modality": "CT", + "NumberOfSeriesRelatedInstances": 5}, + ] + adapter = FakePacsAdapter(metadata, {CT_A_UID: []}) + with pytest.raises(DicomFetcherError, match="zero instances"): + DicomFetcher(adapter).fetch(PatientRef(dicom_study_uid=STUDY_UID)) + + def test_tempdir_removed_after_failed_download(self) -> None: + # Adapter that claims metadata exists but raises during download — + # tempdir must still be wiped. + class BrokenAdapter(FakePacsAdapter): + def retrieve_series(self, study_uid, series_uid): + raise RuntimeError("network blew up") + + metadata = [ + {"_study": STUDY_UID, "SeriesInstanceUID": CT_A_UID, "Modality": "CT", + "NumberOfSeriesRelatedInstances": 5}, + ] + adapter = BrokenAdapter(metadata, {}) + with pytest.raises(RuntimeError, match="network blew up"): + DicomFetcher(adapter).fetch(PatientRef(dicom_study_uid=STUDY_UID)) + + # The prefix is namespaced, so it's easy to assert no leftovers. + import tempfile + tmp_root = Path(tempfile.gettempdir()) + leftovers = list(tmp_root.glob("radiarch_dicom_*")) + assert not leftovers, f"Temp dir leak: {leftovers}" diff --git a/tests/services/test_geometry_models.py b/tests/services/test_geometry_models.py new file mode 100644 index 0000000..33e5404 --- /dev/null +++ b/tests/services/test_geometry_models.py @@ -0,0 +1,133 @@ +"""Unit tests for the Pydantic schemas in radiarch.models.geometry.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from radiarch.models.geometry import ( + CTMetadata, + GeometryBuildRequest, + GeometryResult, + GridSpec, + HUDensityModel, + PatientRef, +) + + +class TestGridSpec: + def test_rejects_non_positive_spacing(self) -> None: + with pytest.raises(ValueError, match="strictly positive"): + GridSpec(spacing_mm=(1.0, 0.0, 1.0)) + + def test_rejects_non_positive_size(self) -> None: + with pytest.raises(ValueError, match="size entries"): + GridSpec(spacing_mm=(1, 1, 1), origin_mm=(0, 0, 0), size=(0, 1, 1)) + + def test_compute_affine(self) -> None: + spec = GridSpec( + spacing_mm=(2.0, 2.0, 3.0), + origin_mm=(10.0, -5.0, 0.0), + size=(4, 4, 4), + ) + expected = np.array( + [ + [2.0, 0.0, 0.0, 10.0], + [0.0, 2.0, 0.0, -5.0], + [0.0, 0.0, 3.0, 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + ) + np.testing.assert_allclose(spec.to_numpy_affine(), expected) + + def test_affine_requires_origin(self) -> None: + spec = GridSpec(spacing_mm=(1, 1, 1)) + with pytest.raises(ValueError, match="origin_mm"): + spec.compute_affine() + + +class TestGeometryBuildRequest: + def _req(self, **overrides) -> GeometryBuildRequest: + base = dict( + patient_ref=PatientRef( + dicom_study_uid="1.2.3", + ct_series_uid="1.2.3.4", + rtstruct_uid="1.2.3.5", + ), + grid_spec=GridSpec( + spacing_mm=(1, 1, 1), origin_mm=(0, 0, 0), size=(64, 64, 64) + ), + hu_to_density_model=HUDensityModel.schneider, + structure_name_map={"PTV": ["PTV_60", "PTV60"]}, + ) + base.update(overrides) + return GeometryBuildRequest(**base) + + def test_cache_key_is_deterministic(self) -> None: + k1 = self._req().compute_cache_key() + k2 = self._req().compute_cache_key() + assert k1 == k2 + + def test_cache_key_changes_with_hu_model(self) -> None: + a = self._req(hu_to_density_model=HUDensityModel.schneider).compute_cache_key() + b = self._req(hu_to_density_model=HUDensityModel.linear).compute_cache_key() + assert a != b + + def test_cache_key_invariant_to_alias_case_and_order(self) -> None: + a = self._req(structure_name_map={"PTV": ["PTV_60", "PTV60"]}).compute_cache_key() + b = self._req(structure_name_map={"ptv": ["ptv60", "ptv_60"]}).compute_cache_key() + assert a == b + + def test_cache_key_ignores_data_root_override(self) -> None: + a = self._req(data_root_override="/tmp/a").compute_cache_key() + b = self._req(data_root_override="/tmp/b").compute_cache_key() + assert a == b + + +class TestGeometryResult: + def _spec(self) -> GridSpec: + return GridSpec( + spacing_mm=(1, 1, 1), origin_mm=(0, 0, 0), size=(10, 10, 10) + ) + + def _ct_meta(self) -> CTMetadata: + return CTMetadata(patient_name="TEST", modality="CT", num_slices=10) + + def test_rejects_label_zero(self) -> None: + with pytest.raises(ValueError, match=">= 1"): + GeometryResult( + geometry_id="g1", + density_grid_uri="/tmp/d.nii.gz", + structure_masks_uri="/tmp/m.nii.gz", + structure_index={"PTV": 0, "Cord": 1}, + grid_spec=self._spec(), + frame_of_reference_uid="1.2.3.9", + ct_metadata=self._ct_meta(), + cache_key="abc", + ) + + def test_rejects_duplicate_labels(self) -> None: + with pytest.raises(ValueError, match="unique"): + GeometryResult( + geometry_id="g1", + density_grid_uri="/tmp/d.nii.gz", + structure_masks_uri="/tmp/m.nii.gz", + structure_index={"PTV": 1, "Cord": 1}, + grid_spec=self._spec(), + frame_of_reference_uid="1.2.3.9", + ct_metadata=self._ct_meta(), + cache_key="abc", + ) + + def test_happy_path(self) -> None: + r = GeometryResult( + geometry_id="g1", + density_grid_uri="/tmp/d.nii.gz", + structure_masks_uri="/tmp/m.nii.gz", + structure_index={"PTV": 1, "Cord": 2}, + grid_spec=self._spec(), + frame_of_reference_uid="1.2.3.9", + ct_metadata=self._ct_meta(), + cache_key="abc", + ) + assert r.structure_index == {"PTV": 1, "Cord": 2} diff --git a/tests/services/test_geometry_service.py b/tests/services/test_geometry_service.py new file mode 100644 index 0000000..996e787 --- /dev/null +++ b/tests/services/test_geometry_service.py @@ -0,0 +1,317 @@ +"""Unit tests for radiarch.services.geometry.GeometryService. + +We stub out ``_load`` so these tests never touch OpenTPS or real DICOM; +they exercise the processing pipeline (HU conversion → optional resample +→ rasterization → persistence) on synthetic CT + fake contours. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from pathlib import Path +from typing import List, Tuple + +import numpy as np +import pytest + +from radiarch.models.geometry import ( + GeometryBuildRequest, + GridSpec, + HUDensityModel, + PatientRef, +) +from radiarch.services.geometry import GeometryService, _LoadedCT +from radiarch.services.persistence import _read_nifti + + +# --------------------------------------------------------------------------- +# Synthetic CT + fake contours +# --------------------------------------------------------------------------- + +@dataclass +class _FakePatient: + name: str = "TEST" + rtStructs: list = field(default_factory=list) + + +@dataclass +class _FakeCT: + imageArray: np.ndarray + origin: Tuple[float, float, float] + spacing: Tuple[float, float, float] + patient: _FakePatient + seriesInstanceUID: str = "1.2.3.4" + studyInstanceUID: str = "1.2.3" + frameOfReferenceUID: str = "1.2.3.9" + + +@dataclass +class _FakeMask: + imageArray: np.ndarray + + +@dataclass +class _FakeContour: + name: str + mask: np.ndarray + + def getBinaryMask(self, origin, gridSize, spacing): + # Ignore geometry args — tests build masks sized to the target grid. + return _FakeMask(imageArray=self.mask.astype(bool)) + + +def _water_ct(size=(10, 10, 10)) -> np.ndarray: + """HU=0 everywhere (water) with a HU=50 soft-tissue 'target' blob.""" + arr = np.zeros(size, dtype=np.int16) + arr[3:7, 3:7, 3:7] = 50 + return arr + + +def _build_loaded_ct( + ct_array: np.ndarray, + origin=(0.0, 0.0, 0.0), + spacing=(1.0, 1.0, 1.0), + contours: List[_FakeContour] | None = None, +) -> _LoadedCT: + ct = _FakeCT( + imageArray=ct_array, + origin=origin, + spacing=spacing, + patient=_FakePatient(), + ) + return _LoadedCT(ct=ct, patient=ct.patient, contours=contours or []) + + +def _simple_request(**overrides) -> GeometryBuildRequest: + base = dict( + patient_ref=PatientRef( + dicom_study_uid="1.2.3", + ct_series_uid="1.2.3.4", + rtstruct_uid="1.2.3.5", + ), + grid_spec=None, # inherit from CT + hu_to_density_model=HUDensityModel.linear, + structure_name_map=None, + ) + base.update(overrides) + return GeometryBuildRequest(**base) + + +# --------------------------------------------------------------------------- +# Happy path — identity grid, no resample +# --------------------------------------------------------------------------- + +class TestHappyPath: + def test_builds_end_to_end_with_identity_grid(self, tmp_path: Path, monkeypatch) -> None: + ct_array = _water_ct() + ptv = np.zeros(ct_array.shape, dtype=bool) + ptv[3:7, 3:7, 3:7] = True + loaded = _build_loaded_ct(ct_array, contours=[_FakeContour("PTV", ptv)]) + + service = GeometryService(base_dir=tmp_path) + monkeypatch.setattr(service, "_load", lambda _req: loaded) + + result = service.build(_simple_request()) + + # Sanity: grid matches the CT, structure_index has PTV=1. + assert result.grid_spec.size == ct_array.shape + assert result.grid_spec.spacing_mm == (1.0, 1.0, 1.0) + assert result.structure_index == {"PTV": 1} + assert result.frame_of_reference_uid == "1.2.3.9" + + # Files on disk. + density, spec = _read_nifti(Path(result.density_grid_uri)) + masks, _ = _read_nifti(Path(result.structure_masks_uri)) + assert density.shape == ct_array.shape + assert masks.shape == ct_array.shape + + # Linear HU model: ρ = 1 + HU/1000 = 1 at HU=0 (clamped to 0 floor). + np.testing.assert_allclose(density[0, 0, 0], 1.0, atol=1e-5) + # Inside the blob (HU=50) → ρ = 1.05. + np.testing.assert_allclose(density[5, 5, 5], 1.05, atol=1e-5) + + def test_ct_metadata_populated(self, tmp_path: Path, monkeypatch) -> None: + loaded = _build_loaded_ct(_water_ct()) + service = GeometryService(base_dir=tmp_path) + monkeypatch.setattr(service, "_load", lambda _req: loaded) + + result = service.build(_simple_request()) + assert result.ct_metadata.patient_name == "TEST" + assert result.ct_metadata.modality == "CT" + assert result.ct_metadata.num_slices == 10 + assert result.ct_metadata.series_instance_uid == "1.2.3.4" + assert result.ct_metadata.study_instance_uid == "1.2.3" + + +# --------------------------------------------------------------------------- +# Caching +# --------------------------------------------------------------------------- + +class TestCaching: + def test_second_build_hits_cache(self, tmp_path: Path, monkeypatch) -> None: + loaded = _build_loaded_ct(_water_ct()) + service = GeometryService(base_dir=tmp_path) + + load_calls = {"n": 0} + + def _counting_load(_req): + load_calls["n"] += 1 + return loaded + + monkeypatch.setattr(service, "_load", _counting_load) + + req = _simple_request() + r1 = service.build(req) + r2 = service.build(req) + + assert r1.geometry_id == r2.geometry_id + assert load_calls["n"] == 1, "cache hit must skip the DICOM load" + + def test_different_hu_model_misses_cache(self, tmp_path: Path, monkeypatch) -> None: + loaded = _build_loaded_ct(_water_ct()) + service = GeometryService(base_dir=tmp_path) + monkeypatch.setattr(service, "_load", lambda _req: loaded) + + r1 = service.build(_simple_request(hu_to_density_model=HUDensityModel.linear)) + r2 = service.build(_simple_request(hu_to_density_model=HUDensityModel.schneider)) + + assert r1.geometry_id != r2.geometry_id + assert r1.cache_key != r2.cache_key + + +# --------------------------------------------------------------------------- +# Grid resampling path +# --------------------------------------------------------------------------- + +class TestCustomGrid: + def test_target_grid_spacing_triggers_resample(self, tmp_path: Path, monkeypatch) -> None: + ct_array = _water_ct() + loaded = _build_loaded_ct(ct_array, spacing=(2.0, 2.0, 2.0)) + service = GeometryService(base_dir=tmp_path) + monkeypatch.setattr(service, "_load", lambda _req: loaded) + + # Request a 1mm grid → density must be resampled (2× upsample per axis). + target = GridSpec( + spacing_mm=(1.0, 1.0, 1.0), + origin_mm=(0.0, 0.0, 0.0), + size=(19, 19, 19), + ) + # No contours so the rasterizer doesn't need mask fakes sized to the target. + req = _simple_request(grid_spec=target) + result = service.build(req) + + assert result.grid_spec.spacing_mm == (1.0, 1.0, 1.0) + assert result.grid_spec.size == (19, 19, 19) + density, _ = _read_nifti(Path(result.density_grid_uri)) + assert density.shape == (19, 19, 19) + + def test_partial_grid_inherits_missing_fields(self, tmp_path: Path, monkeypatch) -> None: + """User supplies spacing only → origin + size adopted from CT.""" + ct_array = _water_ct() + loaded = _build_loaded_ct(ct_array, spacing=(2.0, 2.0, 2.0), origin=(5.0, 5.0, 5.0)) + service = GeometryService(base_dir=tmp_path) + monkeypatch.setattr(service, "_load", lambda _req: loaded) + + partial = GridSpec(spacing_mm=(2.0, 2.0, 2.0)) + result = service.build(_simple_request(grid_spec=partial)) + + # Should fall back to CT origin + size → identity path (no resample). + assert result.grid_spec.origin_mm == (5.0, 5.0, 5.0) + assert result.grid_spec.size == ct_array.shape + + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- + +class TestValidation: + def test_rejects_2d_ct(self, tmp_path: Path, monkeypatch) -> None: + ct_array = np.zeros((4, 4), dtype=np.int16) # 2D + loaded = _build_loaded_ct(ct_array) + service = GeometryService(base_dir=tmp_path) + monkeypatch.setattr(service, "_load", lambda _req: loaded) + + with pytest.raises(ValueError, match="must be 3D"): + service.build(_simple_request()) + + +# --------------------------------------------------------------------------- +# Load-path selection +# --------------------------------------------------------------------------- + +class _MockAdapter: + """Adapter stand-in used to probe _load's branching logic.""" + + def __init__(self, *, can_retrieve: bool) -> None: + self._can = can_retrieve + + def can_retrieve_instances(self) -> bool: + return self._can + + +class TestLoadPathSelection: + def test_falls_back_to_disk_when_adapter_cannot_retrieve( + self, tmp_path: Path, monkeypatch + ) -> None: + """Mock-mode adapter → disk loader, not DicomFetcher.""" + service = GeometryService(base_dir=tmp_path, adapter=_MockAdapter(can_retrieve=False)) + calls = {"disk": 0, "pacs": 0} + + def fake_disk(_root): + calls["disk"] += 1 + return _build_loaded_ct(_water_ct()) + + def fake_pacs(_fetcher, _req): # pragma: no cover — must not be called + calls["pacs"] += 1 + raise AssertionError("_load_from_pacs must not be invoked in mock mode") + + monkeypatch.setattr(service, "_load_from_disk", fake_disk) + monkeypatch.setattr(service, "_load_from_pacs", fake_pacs) + + service.build(_simple_request()) + assert calls == {"disk": 1, "pacs": 0} + + def test_uses_pacs_when_adapter_can_retrieve( + self, tmp_path: Path, monkeypatch + ) -> None: + service = GeometryService(base_dir=tmp_path, adapter=_MockAdapter(can_retrieve=True)) + calls = {"disk": 0, "pacs": 0} + + def fake_disk(_root): # pragma: no cover — must not be called + calls["disk"] += 1 + raise AssertionError("_load_from_disk must not be invoked when PACS is available") + + def fake_pacs(_fetcher, _req): + calls["pacs"] += 1 + return _build_loaded_ct(_water_ct()) + + monkeypatch.setattr(service, "_load_from_disk", fake_disk) + monkeypatch.setattr(service, "_load_from_pacs", fake_pacs) + + service.build(_simple_request()) + assert calls == {"disk": 0, "pacs": 1} + + def test_data_root_override_forces_disk_even_with_pacs( + self, tmp_path: Path, monkeypatch + ) -> None: + """data_root_override is a dev/debug escape hatch — it beats PACS.""" + service = GeometryService(base_dir=tmp_path, adapter=_MockAdapter(can_retrieve=True)) + seen_root = {"value": None} + + def fake_disk(root): + seen_root["value"] = root + return _build_loaded_ct(_water_ct()) + + monkeypatch.setattr(service, "_load_from_disk", fake_disk) + # _load_from_pacs should never be reached — if it is, the test + # fails because the returned _LoadedCT won't exist. + monkeypatch.setattr( + service, + "_load_from_pacs", + lambda *_args, **_kw: (_ for _ in ()).throw( + AssertionError("PACS path must be skipped when data_root_override is set") + ), + ) + + service.build(_simple_request(data_root_override="/tmp/fixtures")) + assert seen_root["value"] == "/tmp/fixtures" diff --git a/tests/services/test_hu_density.py b/tests/services/test_hu_density.py new file mode 100644 index 0000000..d0a7b5b --- /dev/null +++ b/tests/services/test_hu_density.py @@ -0,0 +1,125 @@ +"""Unit tests for the HU → density module. + +We deliberately avoid importing OpenTPS here — the Stoichiometric model is +exercised in the OpenTPS integration suite, not this fast-tier test file. +Schneider and Linear have no heavy deps and must stay fast. +""" + +from __future__ import annotations + +import numpy as np +import pytest + +from radiarch.models.geometry import HUDensityModel +from radiarch.services.hu_density import ( + LinearModel, + SchneiderModel, + get_model, +) + + +# --------------------------------------------------------------------------- +# Schneider +# --------------------------------------------------------------------------- + +class TestSchneider: + def setup_method(self) -> None: + self.model = SchneiderModel() + + def test_air_is_near_zero(self) -> None: + rho = self.model.convert(np.array([-1000.0])) + assert rho.dtype == np.float32 + assert float(rho[0]) == pytest.approx(0.00121, abs=1e-4) + + def test_water_is_unity(self) -> None: + # HU=0 lies between -98 (fat, ρ=0.93) and 14 (soft tissue, ρ=1.03); + # at HU=0 the interpolation gives ρ ≈ 0.93 + (0 - -98)/(14 - -98) * (1.03 - 0.93) + # = 0.93 + 98/112 * 0.10 ≈ 1.0175. + rho = self.model.convert(np.array([0.0])) + assert 0.99 < float(rho[0]) < 1.03 + + def test_soft_tissue_is_close_to_unity(self) -> None: + rho = self.model.convert(np.array([14.0, 23.0])) + assert np.allclose(rho, 1.03, atol=1e-3) + + def test_bone_is_denser_than_water(self) -> None: + rho = self.model.convert(np.array([1000.0])) + assert float(rho[0]) > 1.4 + + def test_clamps_below_floor(self) -> None: + rho = self.model.convert(np.array([-5000.0])) + # np.interp clamps at left endpoint; must not go negative. + assert float(rho[0]) == pytest.approx(0.00121, abs=1e-4) + + def test_clamps_above_cap(self) -> None: + rho = self.model.convert(np.array([10_000.0])) + assert float(rho[0]) == pytest.approx(2.88, abs=1e-3) + + def test_monotone_non_decreasing_on_dense_range(self) -> None: + hu = np.linspace(-1000, 2000, 200) + rho = self.model.convert(hu) + # Allow equality (the HU=14..23 plateau is intentional). + assert np.all(np.diff(rho) >= -1e-5), "Schneider curve must be non-decreasing" + + def test_preserves_shape(self) -> None: + hu = np.arange(24, dtype=np.int16).reshape(2, 3, 4) + rho = self.model.convert(hu.astype(np.float32)) + assert rho.shape == hu.shape + assert rho.dtype == np.float32 + + +# --------------------------------------------------------------------------- +# Linear +# --------------------------------------------------------------------------- + +class TestLinear: + def setup_method(self) -> None: + self.model = LinearModel() + + def test_water_is_exactly_one(self) -> None: + rho = self.model.convert(np.array([0.0])) + assert float(rho[0]) == pytest.approx(1.0, abs=1e-7) + + def test_air_is_zero_not_negative(self) -> None: + rho = self.model.convert(np.array([-1000.0])) + assert float(rho[0]) == pytest.approx(0.0, abs=1e-7) + + def test_extreme_negative_clamped_to_zero(self) -> None: + rho = self.model.convert(np.array([-5000.0])) + # 1 + (-5000)/1000 = -4 → clamped to 0 + assert float(rho[0]) == 0.0 + + def test_bone_density(self) -> None: + rho = self.model.convert(np.array([1000.0])) + assert float(rho[0]) == pytest.approx(2.0, abs=1e-6) + + def test_vectorized_shape(self) -> None: + hu = np.random.RandomState(0).uniform(-1000, 2000, size=(5, 5, 5)) + rho = self.model.convert(hu) + assert rho.shape == (5, 5, 5) + assert rho.dtype == np.float32 + + +# --------------------------------------------------------------------------- +# Factory +# --------------------------------------------------------------------------- + +class TestFactory: + def test_enum_dispatch(self) -> None: + assert isinstance(get_model(HUDensityModel.schneider), SchneiderModel) + assert isinstance(get_model(HUDensityModel.linear), LinearModel) + + def test_string_dispatch(self) -> None: + assert isinstance(get_model("SCHNEIDER"), SchneiderModel) + assert isinstance(get_model("LINEAR"), LinearModel) + + def test_rejects_unknown_string(self) -> None: + with pytest.raises(ValueError, match="Unknown HU density model"): + get_model("PLASTICINE") + + def test_callable_interface(self) -> None: + # get_model(...) should return an object you can call like a function. + model = get_model(HUDensityModel.linear) + rho_call = model(np.array([0.0])) + rho_convert = model.convert(np.array([0.0])) + np.testing.assert_array_equal(rho_call, rho_convert) diff --git a/tests/services/test_persistence.py b/tests/services/test_persistence.py new file mode 100644 index 0000000..c720c47 --- /dev/null +++ b/tests/services/test_persistence.py @@ -0,0 +1,190 @@ +"""Unit tests for radiarch.services.persistence.""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pytest + +from radiarch.models.geometry import ( + CTMetadata, + GeometryResult, + GridSpec, +) +from radiarch.services.persistence import ( + DENSITY_FILENAME, + GeometryStore, + MASKS_FILENAME, + META_FILENAME, + _read_nifti, + _write_nifti, +) + + +def _make_spec(size=(4, 4, 4), spacing=(2.0, 2.0, 3.0), origin=(1.0, -2.0, 0.5)) -> GridSpec: + spec = GridSpec(spacing_mm=spacing, origin_mm=origin, size=size) + spec.affine = spec.compute_affine() + return spec + + +def _make_result(density_path: str, masks_path: str, spec: GridSpec, cache_key: str = "abc") -> GeometryResult: + return GeometryResult( + geometry_id="g1", + density_grid_uri=density_path, + structure_masks_uri=masks_path, + structure_index={"PTV": 1, "Cord": 2}, + grid_spec=spec, + frame_of_reference_uid="1.2.3.9", + ct_metadata=CTMetadata(patient_name="TEST", modality="CT", num_slices=4), + cache_key=cache_key, + ) + + +# --------------------------------------------------------------------------- +# NIfTI round-trip +# --------------------------------------------------------------------------- + +class TestNiftiRoundtrip: + def test_float_density_roundtrip_preserves_values_and_affine(self, tmp_path: Path) -> None: + rng = np.random.RandomState(0) + vol = rng.uniform(0.0, 2.0, size=(4, 4, 4)).astype(np.float32) + spec = _make_spec() + + path = tmp_path / "density.nii.gz" + _write_nifti(vol, spec, path) + loaded, loaded_spec = _read_nifti(path) + + assert loaded.dtype == np.float32 + np.testing.assert_allclose(loaded, vol, rtol=1e-5) + assert loaded_spec.spacing_mm == spec.spacing_mm + assert loaded_spec.origin_mm == spec.origin_mm + assert loaded_spec.size == spec.size + + def test_uint16_labels_preserve_dtype_and_values(self, tmp_path: Path) -> None: + labels = np.zeros((4, 4, 4), dtype=np.uint16) + labels[0:2, 0:2, 0:2] = 1 + labels[2:4, 2:4, 2:4] = 2 + + path = tmp_path / "masks.nii.gz" + _write_nifti(labels, _make_spec(), path) + loaded, _ = _read_nifti(path) + + # SimpleITK may promote uint16 → int16 on the round-trip for + # some encoders; assert value preservation, not dtype identity, + # but verify the stored values are lossless. + np.testing.assert_array_equal(loaded.astype(np.int32), labels.astype(np.int32)) + + def test_ijk_ordering_preserved(self, tmp_path: Path) -> None: + """Write a volume with distinct axis lengths, then verify the shape + survives the SimpleITK (z,y,x) transpose round-trip.""" + vol = np.arange(2 * 3 * 5, dtype=np.float32).reshape((2, 3, 5)) + spec = GridSpec(spacing_mm=(1, 1, 1), origin_mm=(0, 0, 0), size=(2, 3, 5)) + path = tmp_path / "v.nii.gz" + _write_nifti(vol, spec, path) + loaded, loaded_spec = _read_nifti(path) + + assert loaded.shape == (2, 3, 5) + assert loaded_spec.size == (2, 3, 5) + np.testing.assert_allclose(loaded, vol) + + +# --------------------------------------------------------------------------- +# GeometryStore +# --------------------------------------------------------------------------- + +class TestGeometryStore: + def test_save_writes_expected_files(self, tmp_path: Path) -> None: + store = GeometryStore(tmp_path) + spec = _make_spec() + density = np.ones(spec.size, dtype=np.float32) + masks = np.zeros(spec.size, dtype=np.uint16) + masks[0, 0, 0] = 1 + density_path = str(tmp_path / "g1" / DENSITY_FILENAME) + masks_path = str(tmp_path / "g1" / MASKS_FILENAME) + result = _make_result(density_path, masks_path, spec) + + store.save( + geometry_id="g1", + cache_key="abc", + density=density, + masks=masks, + result=result, + ) + + assert (tmp_path / "g1" / DENSITY_FILENAME).exists() + assert (tmp_path / "g1" / MASKS_FILENAME).exists() + assert (tmp_path / "g1" / META_FILENAME).exists() + + def test_cache_lookup_roundtrip(self, tmp_path: Path) -> None: + store = GeometryStore(tmp_path) + spec = _make_spec() + result = _make_result( + str(tmp_path / "g1" / DENSITY_FILENAME), + str(tmp_path / "g1" / MASKS_FILENAME), + spec, + cache_key="deadbeef", + ) + store.save( + geometry_id="g1", + cache_key="deadbeef", + density=np.zeros(spec.size, dtype=np.float32), + masks=np.zeros(spec.size, dtype=np.uint16), + result=result, + ) + + hit = store.lookup_by_cache_key("deadbeef") + assert hit is not None + assert hit.geometry_id == "g1" + assert hit.structure_index == {"PTV": 1, "Cord": 2} + + def test_cache_miss_returns_none(self, tmp_path: Path) -> None: + store = GeometryStore(tmp_path) + assert store.lookup_by_cache_key("nope") is None + assert store.get_by_id("nope") is None + + def test_save_is_atomic_on_retry(self, tmp_path: Path) -> None: + """A second save with the same geometry_id overwrites cleanly.""" + store = GeometryStore(tmp_path) + spec = _make_spec() + density_a = np.zeros(spec.size, dtype=np.float32) + masks_a = np.zeros(spec.size, dtype=np.uint16) + density_b = np.ones(spec.size, dtype=np.float32) + masks_b = np.ones(spec.size, dtype=np.uint16) + result = _make_result( + str(tmp_path / "g1" / DENSITY_FILENAME), + str(tmp_path / "g1" / MASKS_FILENAME), + spec, + ) + store.save( + geometry_id="g1", cache_key="k", density=density_a, masks=masks_a, result=result + ) + store.save( + geometry_id="g1", cache_key="k", density=density_b, masks=masks_b, result=result + ) + + loaded, _ = _read_nifti(tmp_path / "g1" / DENSITY_FILENAME) + np.testing.assert_allclose(loaded, density_b) + + def test_list_ids_excludes_tmp_dirs(self, tmp_path: Path) -> None: + store = GeometryStore(tmp_path) + # Simulate a partial write leftover — store.save() shouldn't have + # created these, but a crash might have. + (tmp_path / ".g2.tmp.xyz").mkdir() + (tmp_path / "g3").mkdir() # no meta.json → ignored + assert store.list_ids() == [] + + spec = _make_spec() + result = _make_result( + str(tmp_path / "g1" / DENSITY_FILENAME), + str(tmp_path / "g1" / MASKS_FILENAME), + spec, + ) + store.save( + geometry_id="g1", + cache_key="k", + density=np.zeros(spec.size, dtype=np.float32), + masks=np.zeros(spec.size, dtype=np.uint16), + result=result, + ) + assert store.list_ids() == ["g1"] diff --git a/tests/services/test_rasterization.py b/tests/services/test_rasterization.py new file mode 100644 index 0000000..506b46b --- /dev/null +++ b/tests/services/test_rasterization.py @@ -0,0 +1,215 @@ +"""Unit tests for radiarch.services.rasterization.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Tuple + +import numpy as np +import pytest + +from radiarch.models.geometry import GridSpec +from radiarch.services.rasterization import ( + TARGET_CANONICAL_NAMES, + rasterize_contours, +) + + +# --------------------------------------------------------------------------- +# Test doubles +# --------------------------------------------------------------------------- + +@dataclass +class _FakeMask: + imageArray: np.ndarray + + +@dataclass +class _FakeContour: + """Duck-types the OpenTPS ROIContour shape we use — just .name and + getBinaryMask. The mask is prebaked at construction; the grid args + passed to ``getBinaryMask`` are validated in a couple of tests.""" + + name: str + mask: np.ndarray + last_origin: Tuple[float, float, float] | None = None + last_grid_size: Tuple[int, int, int] | None = None + last_spacing: Tuple[float, float, float] | None = None + + def getBinaryMask(self, origin, gridSize, spacing): + self.last_origin = origin + self.last_grid_size = gridSize + self.last_spacing = spacing + return _FakeMask(imageArray=self.mask.astype(bool)) + + +def _grid(size=(4, 4, 4)) -> GridSpec: + return GridSpec(spacing_mm=(1.0, 1.0, 1.0), origin_mm=(0.0, 0.0, 0.0), size=size) + + +def _disjoint_masks() -> tuple[np.ndarray, np.ndarray]: + ptv = np.zeros((4, 4, 4), dtype=bool) + ptv[0:2, 0:2, 0:2] = True + cord = np.zeros((4, 4, 4), dtype=bool) + cord[2:4, 2:4, 2:4] = True + return ptv, cord + + +# --------------------------------------------------------------------------- +# Basic packing +# --------------------------------------------------------------------------- + +class TestBasicRasterization: + def test_two_disjoint_contours_get_labels_1_and_2(self) -> None: + ptv, cord = _disjoint_masks() + contours = [_FakeContour("SpinalCord", cord), _FakeContour("PTV", ptv)] + mask, index = rasterize_contours(contours, _grid()) + + # PTV is a target, so it always sorts first and gets label 1. + assert index == {"PTV": 1, "SpinalCord": 2} + assert int(mask[0, 0, 0]) == 1 + assert int(mask[3, 3, 3]) == 2 + # Background is exactly zero everywhere else. + claimed = np.zeros_like(mask, dtype=bool) + claimed[0:2, 0:2, 0:2] = True + claimed[2:4, 2:4, 2:4] = True + assert np.all(mask[~claimed] == 0) + + def test_returns_uint16(self) -> None: + ptv, _ = _disjoint_masks() + mask, _ = rasterize_contours([_FakeContour("PTV", ptv)], _grid()) + assert mask.dtype == np.uint16 + + def test_passes_correct_grid_args_to_contour(self) -> None: + ptv, _ = _disjoint_masks() + contour = _FakeContour("PTV", ptv) + grid = GridSpec(spacing_mm=(2.0, 2.0, 3.0), origin_mm=(10.0, 20.0, -5.0), size=(4, 4, 4)) + rasterize_contours([contour], grid) + + assert contour.last_origin == (10.0, 20.0, -5.0) + assert contour.last_grid_size == (4, 4, 4) + assert contour.last_spacing == (2.0, 2.0, 3.0) + + +# --------------------------------------------------------------------------- +# Alias resolution +# --------------------------------------------------------------------------- + +class TestAliasResolution: + def test_alias_collapses_to_canonical(self) -> None: + ptv, _ = _disjoint_masks() + contours = [_FakeContour("PTV_60", ptv)] + name_map = {"PTV": ["PTV_60", "PTV60"]} + _, index = rasterize_contours(contours, _grid(), structure_name_map=name_map) + assert index == {"PTV": 1} + + def test_alias_match_is_case_insensitive(self) -> None: + ptv, _ = _disjoint_masks() + contours = [_FakeContour("ptv_60", ptv)] + name_map = {"PTV": ["PTV_60"]} + _, index = rasterize_contours(contours, _grid(), structure_name_map=name_map) + assert "PTV" in index + + def test_no_name_map_uses_contour_name(self) -> None: + ptv, _ = _disjoint_masks() + contours = [_FakeContour("BrainStem", ptv)] + _, index = rasterize_contours(contours, _grid()) + assert index == {"BrainStem": 1} + + def test_first_contour_wins_when_aliases_collide(self) -> None: + m1 = np.zeros((4, 4, 4), dtype=bool) + m1[0, 0, 0] = True + m2 = np.zeros((4, 4, 4), dtype=bool) + m2[3, 3, 3] = True + contours = [_FakeContour("PTV_60", m1), _FakeContour("PTV60", m2)] + mask, index = rasterize_contours( + contours, _grid(), structure_name_map={"PTV": ["PTV_60", "PTV60"]} + ) + # Both alias to PTV → only the first one is rasterized. + assert index == {"PTV": 1} + assert int(mask[0, 0, 0]) == 1 + assert int(mask[3, 3, 3]) == 0 + + +# --------------------------------------------------------------------------- +# Label ordering +# --------------------------------------------------------------------------- + +class TestLabelOrdering: + def test_targets_get_lowest_labels(self) -> None: + """Given a mixed bag, PTV/GTV/CTV always sort before OARs.""" + masks = {name: np.zeros((4, 4, 4), dtype=bool) for name in ("PTV", "GTV", "CTV", "BrainStem", "Parotid")} + # Mark distinct voxels so we can tell labels apart. + masks["PTV"][0, 0, 0] = True + masks["GTV"][0, 0, 1] = True + masks["CTV"][0, 0, 2] = True + masks["BrainStem"][0, 1, 0] = True + masks["Parotid"][0, 1, 1] = True + + # Deliberately shuffle input order — ordering must still be + # (PTV, GTV, CTV, BrainStem, Parotid). + contours = [_FakeContour(n, m) for n, m in masks.items()] + contours = contours[::-1] + _, index = rasterize_contours(contours, _grid()) + + ordered = [name for name, _ in sorted(index.items(), key=lambda kv: kv[1])] + assert ordered == ["PTV", "GTV", "CTV", "BrainStem", "Parotid"] + + def test_target_prefix_still_counts_as_target(self) -> None: + """PTV_Boost (without a name map) still sorts with the targets.""" + ptv_boost = np.zeros((4, 4, 4), dtype=bool); ptv_boost[0, 0, 0] = True + oar = np.zeros((4, 4, 4), dtype=bool); oar[1, 1, 1] = True + contours = [_FakeContour("Parotid", oar), _FakeContour("PTV_Boost", ptv_boost)] + _, index = rasterize_contours(contours, _grid()) + assert index["PTV_Boost"] == 1 + assert index["Parotid"] == 2 + + +# --------------------------------------------------------------------------- +# Overlap policy +# --------------------------------------------------------------------------- + +class TestOverlap: + def test_first_match_wins_at_voxel_level(self) -> None: + """PTV and Brain overlap on (1,1,1). PTV is the target → it wins.""" + ptv = np.zeros((4, 4, 4), dtype=bool) + ptv[0:2, 0:2, 0:2] = True + brain = np.zeros((4, 4, 4), dtype=bool) + brain[1:3, 1:3, 1:3] = True + contours = [_FakeContour("Brain", brain), _FakeContour("PTV", ptv)] + mask, index = rasterize_contours(contours, _grid()) + + assert index == {"PTV": 1, "Brain": 2} + # Overlap voxel (1,1,1): PTV claims it. + assert int(mask[1, 1, 1]) == 1 + # Brain-only voxel (2,2,2): Brain still wins. + assert int(mask[2, 2, 2]) == 2 + + +# --------------------------------------------------------------------------- +# Robustness / error surfaces +# --------------------------------------------------------------------------- + +class TestRobustness: + def test_rejects_underspecified_grid(self) -> None: + partial = GridSpec(spacing_mm=(1, 1, 1)) + with pytest.raises(ValueError, match="fully specified"): + rasterize_contours([], partial) + + def test_empty_contours_produces_empty_index(self) -> None: + mask, index = rasterize_contours([], _grid()) + assert index == {} + assert mask.shape == (4, 4, 4) + assert np.all(mask == 0) + + def test_shape_mismatch_is_skipped_not_raised(self) -> None: + wrong_shape = np.zeros((3, 3, 3), dtype=bool) + contours = [_FakeContour("BadMask", wrong_shape)] + mask, index = rasterize_contours(contours, _grid()) + # Skipped with a warning; final volume is still empty. + assert index == {} + assert np.all(mask == 0) + + def test_target_canonical_names_export(self) -> None: + """Guard the public constant used by tests and downstream callers.""" + assert TARGET_CANONICAL_NAMES == ("PTV", "GTV", "CTV") diff --git a/tests/services/test_resampling.py b/tests/services/test_resampling.py new file mode 100644 index 0000000..0f0c922 --- /dev/null +++ b/tests/services/test_resampling.py @@ -0,0 +1,145 @@ +"""Unit tests for the grid resampling module.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from radiarch.models.geometry import GridSpec +from radiarch.services.resampling import ( + identity_grid_from_affine, + resample_to_grid, +) + + +def _identity_affine(spacing=(1.0, 1.0, 1.0), origin=(0.0, 0.0, 0.0)) -> np.ndarray: + sx, sy, sz = spacing + ox, oy, oz = origin + return np.array( + [ + [sx, 0.0, 0.0, ox], + [0.0, sy, 0.0, oy], + [0.0, 0.0, sz, oz], + [0.0, 0.0, 0.0, 1.0], + ], + dtype=np.float64, + ) + + +class TestIdentityResample: + def test_same_grid_is_a_noop(self) -> None: + rng = np.random.RandomState(0) + vol = rng.uniform(-1000, 1000, size=(8, 8, 8)).astype(np.float32) + affine = _identity_affine(spacing=(2.0, 2.0, 3.0), origin=(10.0, -20.0, 0.0)) + target = identity_grid_from_affine(affine, size=vol.shape) + + out = resample_to_grid(vol, affine, target, order=1) + np.testing.assert_allclose(out, vol, atol=1e-5) + + def test_identity_grid_from_affine_roundtrip(self) -> None: + affine = _identity_affine(spacing=(1.5, 1.5, 2.0), origin=(5.0, 5.0, 5.0)) + spec = identity_grid_from_affine(affine, size=(4, 4, 4)) + assert spec.spacing_mm == (1.5, 1.5, 2.0) + assert spec.origin_mm == (5.0, 5.0, 5.0) + assert spec.size == (4, 4, 4) + np.testing.assert_allclose(spec.to_numpy_affine(), affine) + + +class TestContinuousResample: + def test_upsample_halving_spacing_preserves_smooth_field(self) -> None: + # Build a low-frequency linear ramp on a coarse grid, then + # upsample 2× — trilinear must reproduce the analytic field. + x = np.arange(10, dtype=np.float32) + y = np.arange(10, dtype=np.float32) + z = np.arange(10, dtype=np.float32) + X, Y, Z = np.meshgrid(x, y, z, indexing="ij") + ramp = (X + 2 * Y + 3 * Z).astype(np.float32) + src_affine = _identity_affine(spacing=(2.0, 2.0, 2.0), origin=(0.0, 0.0, 0.0)) + + # Target: same extent, half the spacing, so 2× more voxels per axis. + target = GridSpec( + spacing_mm=(1.0, 1.0, 1.0), + origin_mm=(0.0, 0.0, 0.0), + size=(19, 19, 19), + ) + out = resample_to_grid(ramp, src_affine, target, order=1) + + # Analytic comparison: at target index (i,j,k), mm coords are + # (i*1, j*1, k*1); in source index that's (i/2, j/2, k/2); the + # ramp value there is i/2 + 2*(j/2) + 3*(k/2) = 0.5i + j + 1.5k. + ii, jj, kk = np.meshgrid( + np.arange(19), np.arange(19), np.arange(19), indexing="ij" + ) + expected = (0.5 * ii + jj + 1.5 * kk).astype(np.float32) + np.testing.assert_allclose(out, expected, atol=1e-4) + + def test_out_of_bounds_filled_with_cval(self) -> None: + vol = np.ones((4, 4, 4), dtype=np.float32) * 7.0 + src_affine = _identity_affine(spacing=(1.0, 1.0, 1.0), origin=(0.0, 0.0, 0.0)) + # Target extends past the source → out-of-bounds voxels filled with cval. + target = GridSpec( + spacing_mm=(1.0, 1.0, 1.0), + origin_mm=(0.0, 0.0, 0.0), + size=(8, 8, 8), + ) + out = resample_to_grid(vol, src_affine, target, order=1, cval=-999.0) + + # The inner 4×4×4 block is still 7.0; the rest is -999. + assert out[0, 0, 0] == pytest.approx(7.0) + assert out[7, 7, 7] == pytest.approx(-999.0) + + +class TestLabelResample: + def test_nearest_neighbor_preserves_integer_labels(self) -> None: + labels = np.zeros((4, 4, 4), dtype=np.uint16) + labels[1:3, 1:3, 1:3] = 5 # cube of label 5 + src_affine = _identity_affine(spacing=(1.0, 1.0, 1.0), origin=(0.0, 0.0, 0.0)) + # Shift + upsample the target grid, forcing real resampling. + target = GridSpec( + spacing_mm=(0.5, 0.5, 0.5), + origin_mm=(0.0, 0.0, 0.0), + size=(7, 7, 7), + ) + out = resample_to_grid(labels, src_affine, target, order=0, cval=0) + + # No fractional labels were invented. + unique = set(np.unique(out).tolist()) + assert unique <= {0, 5} + assert out.dtype == labels.dtype + + def test_label_resample_integer_cval(self) -> None: + labels = np.ones((3, 3, 3), dtype=np.uint16) * 9 + src_affine = _identity_affine() + target = GridSpec( + spacing_mm=(1.0, 1.0, 1.0), + origin_mm=(-5.0, -5.0, -5.0), + size=(3, 3, 3), + ) + out = resample_to_grid(labels, src_affine, target, order=0, cval=0) + # Completely outside source → all background. + assert np.all(out == 0) + + +class TestAffineValidation: + def test_rejects_rotated_affine(self) -> None: + rotated = _identity_affine() + # Plant a small rotation into the x/y plane. + rotated[0, 1] = 0.1 + rotated[1, 0] = -0.1 + target = GridSpec(spacing_mm=(1, 1, 1), origin_mm=(0, 0, 0), size=(2, 2, 2)) + vol = np.zeros((2, 2, 2), dtype=np.float32) + with pytest.raises(NotImplementedError, match="Rotated affines"): + resample_to_grid(vol, rotated, target, order=1) + + def test_rejects_zero_spacing(self) -> None: + bad = _identity_affine(spacing=(1.0, 0.0, 1.0)) + target = GridSpec(spacing_mm=(1, 1, 1), origin_mm=(0, 0, 0), size=(2, 2, 2)) + vol = np.zeros((2, 2, 2), dtype=np.float32) + with pytest.raises(ValueError, match="Non-positive spacing"): + resample_to_grid(vol, bad, target, order=1) + + def test_requires_fully_specified_target(self) -> None: + partial = GridSpec(spacing_mm=(1, 1, 1)) # origin/size omitted + vol = np.zeros((2, 2, 2), dtype=np.float32) + with pytest.raises(ValueError, match="fully specified"): + resample_to_grid(vol, _identity_affine(), partial, order=1) diff --git a/tests/test_api_geometry.py b/tests/test_api_geometry.py new file mode 100644 index 0000000..f9f0d23 --- /dev/null +++ b/tests/test_api_geometry.py @@ -0,0 +1,165 @@ +"""End-to-end tests for the /geometry/* routes. + +Uses FastAPI's TestClient and a monkey-patched GeometryService so we can +drive the pipeline against synthetic CT + contours — no OpenTPS, no +real DICOM, no Celery. +""" + +from __future__ import annotations + +import tempfile +from dataclasses import dataclass, field +from pathlib import Path +from typing import List, Tuple + +import numpy as np +import pytest +from fastapi.testclient import TestClient + +from radiarch import app as radiarch_app +from radiarch.api.routes import geometry as geometry_route +from radiarch.app import create_app +from radiarch.services.geometry import GeometryService, _LoadedCT + + +# --------------------------------------------------------------------------- +# Fakes (trimmed copies of test_geometry_service.py) +# --------------------------------------------------------------------------- + +@dataclass +class _FakePatient: + name: str = "API_TEST" + rtStructs: list = field(default_factory=list) + + +@dataclass +class _FakeCT: + imageArray: np.ndarray + origin: Tuple[float, float, float] + spacing: Tuple[float, float, float] + patient: _FakePatient + seriesInstanceUID: str = "1.2.3.4" + studyInstanceUID: str = "1.2.3" + frameOfReferenceUID: str = "1.2.3.9" + + +@dataclass +class _FakeMask: + imageArray: np.ndarray + + +@dataclass +class _FakeContour: + name: str + mask: np.ndarray + + def getBinaryMask(self, origin, gridSize, spacing): + return _FakeMask(imageArray=self.mask.astype(bool)) + + +def _build_loaded_ct() -> _LoadedCT: + ct_array = np.zeros((8, 8, 8), dtype=np.int16) + ct_array[2:6, 2:6, 2:6] = 50 + ptv = np.zeros(ct_array.shape, dtype=bool) + ptv[2:6, 2:6, 2:6] = True + ct = _FakeCT(imageArray=ct_array, origin=(0, 0, 0), spacing=(1, 1, 1), patient=_FakePatient()) + return _LoadedCT(ct=ct, patient=ct.patient, contours=[_FakeContour("PTV", ptv)]) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def client(monkeypatch): + """FastAPI client with a sandboxed, stubbed GeometryService singleton. + + We intercept two things that would otherwise break an offline pytest run: + * ``init_db`` — the real implementation tries to connect to Postgres, + which isn't available outside docker-compose. Stub to a no-op. + * ``geometry_route._service`` — swap the lru_cached factory for a + lambda returning our stubbed service. Clear the lru cache *before* + monkeypatching so any prior test's cached instance doesn't leak. + """ + tmp = tempfile.TemporaryDirectory() + svc = GeometryService(base_dir=tmp.name) + monkeypatch.setattr(svc, "_load", lambda _req: _build_loaded_ct()) + + # Clear the real lru_cache before we replace the function — otherwise + # the *previous* test's cached service would leak via the module global + # once monkeypatch reverts at teardown. + geometry_route._service.cache_clear() + + # No-op init_db so the app lifespan doesn't try to reach Postgres. + monkeypatch.setattr(radiarch_app, "init_db", lambda: None) + monkeypatch.setattr(geometry_route, "_service", lambda: svc) + + app = create_app() + with TestClient(app) as c: + yield c + tmp.cleanup() + # NOTE: don't call geometry_route._service.cache_clear() here — at this + # point it's still the monkeypatched lambda (which has no cache_clear). + # monkeypatch teardown restores the real lru_cached function on exit. + + +def _sample_payload(grid_spec=None) -> dict: + return { + "patient_ref": { + "dicom_study_uid": "1.2.3", + "ct_series_uid": "1.2.3.4", + "rtstruct_uid": "1.2.3.5", + }, + "grid_spec": grid_spec, + "hu_to_density_model": "LINEAR", + } + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestBuildEndpoint: + def test_happy_path_returns_geometry_result(self, client: TestClient) -> None: + r = client.post("/api/v1/geometry/build", json=_sample_payload()) + assert r.status_code == 200, r.text + body = r.json() + + assert body["geometry_id"] + assert body["structure_index"] == {"PTV": 1} + assert body["frame_of_reference_uid"] == "1.2.3.9" + assert body["ct_metadata"]["num_slices"] == 8 + assert body["grid_spec"]["size"] == [8, 8, 8] + assert body["cache_key"] + + def test_cached_second_build_returns_same_geometry_id(self, client: TestClient) -> None: + first = client.post("/api/v1/geometry/build", json=_sample_payload()).json() + second = client.post("/api/v1/geometry/build", json=_sample_payload()).json() + assert first["geometry_id"] == second["geometry_id"] + + +class TestGetEndpoint: + def test_roundtrip_build_then_fetch(self, client: TestClient) -> None: + built = client.post("/api/v1/geometry/build", json=_sample_payload()).json() + fetched = client.get(f"/api/v1/geometry/{built['geometry_id']}").json() + assert fetched["geometry_id"] == built["geometry_id"] + assert fetched["cache_key"] == built["cache_key"] + + def test_unknown_id_is_404(self, client: TestClient) -> None: + r = client.get("/api/v1/geometry/does-not-exist") + assert r.status_code == 404 + + +class TestVolumeStreaming: + def test_density_stream_returns_nifti_bytes(self, client: TestClient) -> None: + built = client.post("/api/v1/geometry/build", json=_sample_payload()).json() + r = client.get(f"/api/v1/geometry/{built['geometry_id']}/density") + assert r.status_code == 200 + # gzipped NIfTI starts with the gzip magic 0x1f 0x8b. + assert r.content[:2] == b"\x1f\x8b" + + def test_masks_stream_returns_nifti_bytes(self, client: TestClient) -> None: + built = client.post("/api/v1/geometry/build", json=_sample_payload()).json() + r = client.get(f"/api/v1/geometry/{built['geometry_id']}/masks") + assert r.status_code == 200 + assert r.content[:2] == b"\x1f\x8b" From fa427def96e7471b8de7c002c4de0525dc83cdb8 Mon Sep 17 00:00:00 2001 From: vidyuthdev Date: Sat, 16 May 2026 12:59:03 -0400 Subject: [PATCH 3/3] Add Service 2 (Beam Model), async geometry jobs, and DICOM upload MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Service 1 — Geometry async + DICOM upload: - Celery task for async geometry builds + jobs endpoint + Postgres-backed job rows - POST /api/v1/uploads/dicom accepts ZIP, returns upload_id - PatientRef.upload_id wired into GeometryService._load - Validated end-to-end on TCIA LCTSC clinical patient Service 2 — Beam Model: - Pydantic models, machine_model loader, on-disk content-addressable cache - Proton spot generator + photon beamlet generator - BeamModelService orchestrator with 6-stage progress - Celery task + jobs table + Alembic migration - 5-endpoint REST surface (build / job / get / artifact / delete) Refactor: - _helpers.load_bdl / setup_calibration delegate to ProtonMachineModel Tests: 271 passing (259 prior + 12 new for uploads) Demos: demo/show_geometry.py (--upload flag) and demo/show_beam_model.py Docs: demo/README_DICOM.md (TCIA download walkthrough) --- demo/README_DICOM.md | 109 ++++++ demo/show_beam_model.py | 273 ++++++++++++++ demo/show_geometry.py | 260 +++++++++++++ package-lock.json | 6 + .../versions/b1e7c1f3a4f2_geometry_jobs.py | 49 +++ .../versions/c2f8d2a4b5e3_beam_model_jobs.py | 49 +++ src/radiarch/api/routes/beam_model.py | 171 +++++++++ src/radiarch/api/routes/geometry.py | 129 ++++++- src/radiarch/api/routes/uploads.py | 250 +++++++++++++ src/radiarch/app.py | 4 +- src/radiarch/config.py | 7 + src/radiarch/core/db_models.py | 47 +++ src/radiarch/core/store.py | 326 ++++++++++++++++ src/radiarch/core/workflows/_helpers.py | 38 +- src/radiarch/models/beam_model.py | 303 +++++++++++++++ src/radiarch/models/geometry.py | 81 +++- src/radiarch/services/beam_model.py | 276 ++++++++++++++ src/radiarch/services/beam_persistence.py | 209 +++++++++++ src/radiarch/services/geometry.py | 82 ++++- src/radiarch/services/machine_model.py | 273 ++++++++++++++ src/radiarch/services/persistence.py | 35 ++ src/radiarch/services/photon_beamlets.py | 152 ++++++++ src/radiarch/services/proton_spots.py | 211 +++++++++++ src/radiarch/tasks/beam_model_tasks.py | 91 +++++ src/radiarch/tasks/celery_app.py | 6 +- src/radiarch/tasks/geometry_tasks.py | 118 ++++++ tests/opentps/core/test_mcsquare_interface.py | 10 + tests/opentps/core/test_opentps_core.py | 13 + tests/services/test_beam_model_models.py | 221 +++++++++++ tests/services/test_beam_model_service.py | 229 ++++++++++++ tests/services/test_beam_persistence.py | 175 +++++++++ tests/services/test_machine_model.py | 69 ++++ tests/services/test_photon_beamlets.py | 181 +++++++++ tests/services/test_proton_spots.py | 215 +++++++++++ tests/test_api_beam_model.py | 280 ++++++++++++++ tests/test_api_geometry.py | 212 ++++++++--- tests/test_api_uploads.py | 347 ++++++++++++++++++ tests/test_opentps_integration.py | 4 + 38 files changed, 5412 insertions(+), 99 deletions(-) create mode 100644 demo/README_DICOM.md create mode 100644 demo/show_beam_model.py create mode 100644 demo/show_geometry.py create mode 100644 package-lock.json create mode 100644 src/migrations/versions/b1e7c1f3a4f2_geometry_jobs.py create mode 100644 src/migrations/versions/c2f8d2a4b5e3_beam_model_jobs.py create mode 100644 src/radiarch/api/routes/beam_model.py create mode 100644 src/radiarch/api/routes/uploads.py create mode 100644 src/radiarch/models/beam_model.py create mode 100644 src/radiarch/services/beam_model.py create mode 100644 src/radiarch/services/beam_persistence.py create mode 100644 src/radiarch/services/machine_model.py create mode 100644 src/radiarch/services/photon_beamlets.py create mode 100644 src/radiarch/services/proton_spots.py create mode 100644 src/radiarch/tasks/beam_model_tasks.py create mode 100644 src/radiarch/tasks/geometry_tasks.py create mode 100644 tests/services/test_beam_model_models.py create mode 100644 tests/services/test_beam_model_service.py create mode 100644 tests/services/test_beam_persistence.py create mode 100644 tests/services/test_machine_model.py create mode 100644 tests/services/test_photon_beamlets.py create mode 100644 tests/services/test_proton_spots.py create mode 100644 tests/test_api_beam_model.py create mode 100644 tests/test_api_uploads.py diff --git a/demo/README_DICOM.md b/demo/README_DICOM.md new file mode 100644 index 0000000..800d920 --- /dev/null +++ b/demo/README_DICOM.md @@ -0,0 +1,109 @@ +# Where to get a real DICOM study for the demo + +Radiarch's Geometry Service needs **a CT series + an RTSTRUCT file** for the same patient. This document lists three sources, fastest to slowest. Pick one, end up with a single `.zip` of the patient folder, then run: + +```bash +python demo/show_geometry.py --upload /path/to/your/study.zip +``` + +--- + +## Option 1 — The Cancer Imaging Archive (TCIA) — recommended + +The cleanest free source of anonymized CT + RTSTRUCT data. + +**Collection: LCTSC** (Lung CT Segmentation Challenge 2017). Small studies, well-curated, every patient has both a CT series and an RTSTRUCT with named OARs (Esophagus, Heart, Lung_L, Lung_R, SpinalCord). + +1. Open . +2. Scroll to the **Data Access** table → click the **Download** button next to "Images (DICOM, 4.3 GB)". TCIA will hand you a tiny `.tcia` manifest file (a few KB). +3. Install the **NBIA Data Retriever**: . macOS, Windows, and Linux builds available. +4. Open the `.tcia` manifest in NBIA Data Retriever. In the patient list, **uncheck "Select All"** and tick a single patient (e.g. `LCTSC-Test-S1-101`). Click **Start**. +5. After it finishes you'll have a folder structure like: + ``` + LCTSC/ + LCTSC-Test-S1-101/ + / + / ← ~100 .dcm slices + / ← 1 .dcm + ``` +6. Zip the patient folder: + ```bash + cd LCTSC + zip -r LCTSC-Test-S1-101.zip LCTSC-Test-S1-101 + ``` +7. Run the demo: + ```bash + python demo/show_geometry.py --upload LCTSC-Test-S1-101.zip --show + ``` + +Total time: 10–15 minutes including download. + +**Alternative TCIA collections** if LCTSC is overkill: `Head-Neck-PET-CT` has CT+RTSTRUCT, `QIN-HEADNECK` is also viable. The same flow works for any TCIA collection that includes the **RTSTRUCT** modality (check the collection's modality table before downloading). + +--- + +## Option 2 — The bundled SimpleFantom, repackaged as a ZIP + +Fastest path. Not "clinical" data — it's the synthetic phantom that already lives in the repo — but it does prove the upload pipeline end-to-end with zero downloads. + +```bash +cd tests/opentps/core/opentps-testData +zip -r ~/simplefantom_upload.zip SimpleFantomWithStruct +python demo/show_geometry.py --upload ~/simplefantom_upload.zip +``` + +Useful as a smoke test before pulling a real TCIA study. + +--- + +## Option 3 — An anonymized clinical study from your professor or collaborator + +If you have access to an anonymized CT+RTSTRUCT pair from a research collaborator, the requirements are: + +- All files must be valid DICOM Part 10 (`.dcm` extension preferred, but the upload endpoint will sniff for the `DICM` magic header at byte 128 if the extension is missing). +- The CT series and the RTSTRUCT must reference the same `FrameOfReferenceUID` — otherwise the contours won't line up with the voxel grid. The Geometry Service won't crash, but the structure masks will be all zeros. +- ZIP the whole patient or study folder. Subdirectories inside the ZIP are fine; the upload endpoint walks recursively. + +Re-anonymize first if there's any doubt — Radiarch keeps the bytes you upload until you `DELETE /api/v1/uploads/{upload_id}`. + +--- + +## What the upload endpoint expects + +`POST /api/v1/uploads/dicom` accepts a single multipart file part named `file`, value type `application/zip`. Example with curl: + +```bash +curl -F "file=@LCTSC-Test-S1-101.zip" \ + http://localhost:8000/api/v1/uploads/dicom +``` + +Response: + +```json +{ + "upload_id": "8c2f3d4e-...", + "file_count": 102, + "dicom_count": 102, + "ct_slice_count": 101, + "rtstruct_count": 1, + "total_bytes": 53412934, + "storage_path": "/data/artifacts/uploads/8c2f3d4e-..." +} +``` + +You then pass that `upload_id` into a geometry build: + +```bash +curl -X POST http://localhost:8000/api/v1/geometry/build \ + -H "Content-Type: application/json" \ + -d '{ + "patient_ref": {"upload_id": "8c2f3d4e-..."}, + "hu_to_density_model": "STOICHIOMETRIC" + }' +``` + +Cleanup when you're done: + +```bash +curl -X DELETE http://localhost:8000/api/v1/uploads/8c2f3d4e-... +``` diff --git a/demo/show_beam_model.py b/demo/show_beam_model.py new file mode 100644 index 0000000..1761f4f --- /dev/null +++ b/demo/show_beam_model.py @@ -0,0 +1,273 @@ +"""Live demo of Service 2 — Beam Model Service. + +Runs the Beam Model Service end-to-end against the OpenTPS sample data +that ships with the repo. Builds a geometry first (Service 1), then +builds a *proton* beam model and a *photon* beam model against it, +printing per-build timings and fluence-element summaries. Each build +is run twice so the cache hit shows up clearly. + +Usage: + python demo/show_beam_model.py # both modalities + python demo/show_beam_model.py --proton-only + python demo/show_beam_model.py --photon-only + +Requirements: + Whatever's in src/.venv (numpy, pydantic, fastapi, SimpleITK, + SQLAlchemy, OpenTPS + MCsquare for the proton path). No docker, + no Postgres, no Redis, no Celery worker. +""" + +from __future__ import annotations + +import os +import sys +import time +import traceback +from pathlib import Path +from typing import Optional + + +# --------------------------------------------------------------------------- +# Path + env setup BEFORE importing radiarch +# --------------------------------------------------------------------------- + +_REPO_ROOT = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(_REPO_ROOT / "src")) + +# Same env shim as the geometry demo: in-memory store, no broker, no Orthanc. +os.environ["RADIARCH_ORTHANC_USE_MOCK"] = "true" +os.environ["RADIARCH_DATABASE_URL"] = "" +os.environ["RADIARCH_BROKER_URL"] = "memory://" +os.environ["RADIARCH_RESULT_BACKEND"] = "cache+memory://" +os.environ["RADIARCH_DICOMWEB_URL"] = "" +os.environ["RADIARCH_ARTIFACT_DIR"] = str(_REPO_ROOT / "data" / "artifacts") + +# Point the OpenTPS data loader at the SimpleFantom sample. +_TEST_DATA = ( + _REPO_ROOT + / "tests" + / "opentps" + / "core" + / "opentps-testData" + / "SimpleFantomWithStruct" +) +os.environ["RADIARCH_OPENTPS_DATA_ROOT"] = str(_TEST_DATA) + + +from radiarch.models.beam_model import ( # noqa: E402 + BeamModelBuildRequest, + BeamModelStage, + BeamSetSpec, + BeamSpec, + DeliveryParams, + Modality, +) +from radiarch.models.geometry import ( # noqa: E402 + GeometryBuildRequest, + HUDensityModel, + PatientRef, +) +from radiarch.services.beam_model import BeamModelService # noqa: E402 +from radiarch.services.geometry import GeometryService # noqa: E402 + + +# --------------------------------------------------------------------------- +# Pretty-printing helpers +# --------------------------------------------------------------------------- + +BAR = "─" * 64 + + +def _h(label: str) -> None: + print(f"\n{BAR}\n {label}\n{BAR}") + + +def _row(key: str, value) -> None: + print(f" {key:<26} {value}") + + +def _on_progress(stage: BeamModelStage, fraction: float, message: str) -> None: + pct = int(fraction * 100) + print(f" [{pct:3d}%] {stage.value:<24} {message}") + + +# --------------------------------------------------------------------------- +# Request builders — these are the same payloads an HTTP client would send +# --------------------------------------------------------------------------- + +def _proton_request(geometry_id: str) -> BeamModelBuildRequest: + """Two opposed proton fields with 5 mm spot + layer spacing.""" + return BeamModelBuildRequest( + geometry_id=geometry_id, + modality=Modality.proton_pbs, + machine_model_id=None, + beam_set=BeamSetSpec( + isocenter_mm=[0.0, 0.0, 0.0], + beams=[ + BeamSpec(beam_id="B1", gantry_deg=0.0, couch_deg=0.0, + collimator_deg=0.0), + BeamSpec(beam_id="B2", gantry_deg=180.0, couch_deg=0.0, + collimator_deg=0.0), + ], + ), + delivery_params=DeliveryParams( + spot_spacing_mm=5.0, + layer_spacing_mm=5.0, + ), + ) + + +def _photon_request(geometry_id: str) -> BeamModelBuildRequest: + """Three coplanar photon fields with 5 mm beamlets, 10x10 cm jaw.""" + return BeamModelBuildRequest( + geometry_id=geometry_id, + modality=Modality.photon_imrt, + machine_model_id=None, + beam_set=BeamSetSpec( + isocenter_mm=[0.0, 0.0, 0.0], + beams=[ + BeamSpec(beam_id="B1", gantry_deg=0.0), + BeamSpec(beam_id="B2", gantry_deg=120.0), + BeamSpec(beam_id="B3", gantry_deg=240.0), + ], + ), + delivery_params=DeliveryParams( + beamlet_size_mm=[5.0, 5.0], + jaw_opening_mm=[100.0, 100.0], + ), + ) + + +# --------------------------------------------------------------------------- +# Demo helpers +# --------------------------------------------------------------------------- + +def _build_geometry() -> str: + """Build (or hit cached) geometry — returns its geometry_id.""" + _h("Step 1 — Geometry Service (Service 1)") + request = GeometryBuildRequest( + patient_ref=PatientRef( + dicom_study_uid="demo-study-001", + ct_series_uid=None, + rtstruct_uid=None, + ), + grid_spec=None, + hu_to_density_model=HUDensityModel.stoichiometric, + ) + _row("cache_key:", request.compute_cache_key()[:16] + "…") + t0 = time.monotonic() + result = GeometryService().build(request) + elapsed_ms = (time.monotonic() - t0) * 1000.0 + _row("geometry_id:", result.geometry_id) + _row("elapsed:", f"{elapsed_ms:8.1f} ms") + _row("density NIfTI:", Path(result.density_grid_uri).name) + _row("structures:", list(result.structure_index.keys())) + return result.geometry_id + + +def _build_beam_model( + label: str, + request: BeamModelBuildRequest, + service: BeamModelService, + show_progress: bool = True, +) -> tuple[Optional[object], float]: + """Run service.build with timing. Returns (result, elapsed_ms).""" + _h(label) + _row("modality:", request.modality.value) + _row("beams:", [b.beam_id for b in request.beam_set.beams]) + _row("cache_key:", request.compute_cache_key()[:16] + "…") + t0 = time.monotonic() + try: + cb = _on_progress if show_progress else None + result = service.build(request, progress_callback=cb) + except Exception as exc: + elapsed_ms = (time.monotonic() - t0) * 1000.0 + _row("status:", "FAILED") + _row("error:", f"{type(exc).__name__}: {exc}") + traceback.print_exc() + return None, elapsed_ms + elapsed_ms = (time.monotonic() - t0) * 1000.0 + _row("beam_model_id:", result.beam_model_id) + _row("elapsed:", f"{elapsed_ms:8.1f} ms") + return result, elapsed_ms + + +def _print_result(result) -> None: + fe = result.fluence_elements + _h(f"BeamModelResult — {result.modality.value}") + _row("beam_model_id:", result.beam_model_id) + _row("cache_key:", result.cache_key[:16] + "…") + _row("machine_model_id:", result.machine_model_id or "(default)") + _row("plan artifact:", Path(result.beam_model_ref_uri).name) + _row("fluence elements:", f"{fe.total_count} total") + for pb in fe.per_beam: + bits = [f"{pb.element_count} elements"] + if pb.energy_layers: + bits.append(f"{len(pb.energy_layers)} energy layers") + bits.append(f"E ∈ [{min(pb.energy_layers):.1f}, " + f"{max(pb.energy_layers):.1f}] MeV") + if pb.spots_per_layer: + bits.append(f"{sum(pb.spots_per_layer)} spots") + if pb.grid_dims: + bits.append(f"grid {pb.grid_dims[0]}×{pb.grid_dims[1]}") + if pb.active_beamlets is not None: + bits.append(f"{pb.active_beamlets} active beamlets") + print(f" • {pb.beam_id}: " + ", ".join(bits)) + + +def _run_modality( + name: str, + request: BeamModelBuildRequest, + service: BeamModelService, +) -> None: + """Build twice — first call exercises the full pipeline, second hits cache.""" + r1, t1 = _build_beam_model(f"{name} build #1 (cache miss)", request, + service, show_progress=True) + if r1 is None: + return + r2, t2 = _build_beam_model(f"{name} build #2 (same request)", request, + service, show_progress=False) + if r2 is not None and r1.beam_model_id == r2.beam_model_id: + speedup = t1 / max(t2, 0.01) + _row("speedup:", f"{speedup:6.1f}×") + print(" ✓ same beam_model_id — cache hit confirmed") + _print_result(r1) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main() -> None: + if not _TEST_DATA.exists(): + print(f"ERROR: test data not found at {_TEST_DATA}", file=sys.stderr) + sys.exit(1) + + do_proton = "--photon-only" not in sys.argv + do_photon = "--proton-only" not in sys.argv + + _h("Radiarch Beam Model Service — Live Demo") + _row("test data:", _TEST_DATA.relative_to(_REPO_ROOT)) + _row("artifact dir:", os.environ["RADIARCH_ARTIFACT_DIR"]) + _row("modalities:", ", ".join( + ([Modality.proton_pbs.value] if do_proton else []) + + ([Modality.photon_imrt.value] if do_photon else []) + )) + + geometry_id = _build_geometry() + + service = BeamModelService() + + if do_proton: + _run_modality("Proton (PROTON_PBS)", _proton_request(geometry_id), service) + if do_photon: + _run_modality("Photon (PHOTON_IMRT)", _photon_request(geometry_id), service) + + _h("Done") + print(f" Cached beam models live in: " + f"{Path(os.environ['RADIARCH_ARTIFACT_DIR']) / 'beam_models'}") + print() + + +if __name__ == "__main__": + main() diff --git a/demo/show_geometry.py b/demo/show_geometry.py new file mode 100644 index 0000000..60f8486 --- /dev/null +++ b/demo/show_geometry.py @@ -0,0 +1,260 @@ +"""Live demo of Service 1 — Geometry Service. + +Runs the Geometry Service end-to-end and prints the GeometryResult, +optionally rendering an axial slice with structure masks overlaid. + +Usage: + # 1. Bundled SimpleFantom sample (dev fallback) + python demo/show_geometry.py + python demo/show_geometry.py --show + + # 2. A real DICOM study you provide as a ZIP — exercises the same + # code path the production HTTP upload endpoint uses. + python demo/show_geometry.py --upload /path/to/study.zip + python demo/show_geometry.py --upload /path/to/study.zip --show + +Requirements: + Whatever's in src/.venv (numpy, scipy, pydantic, SimpleITK, + matplotlib if you pass --show). No docker, no Orthanc, no + Postgres, no Celery worker. + +Where to get a real DICOM study: + See demo/README_DICOM.md — points at TCIA's free LCTSC collection + (CT + RTSTRUCT, anonymized). +""" + +from __future__ import annotations + +import os +import sys +import time +from pathlib import Path +from typing import Optional + + +# --------------------------------------------------------------------------- +# Path + env setup BEFORE importing radiarch +# --------------------------------------------------------------------------- + +_REPO_ROOT = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(_REPO_ROOT / "src")) + +# Mock Orthanc → service falls back to disk loading. +os.environ["RADIARCH_ORTHANC_USE_MOCK"] = "true" +# No Postgres/Redis needed for direct service usage. +os.environ["RADIARCH_DATABASE_URL"] = "" +os.environ["RADIARCH_BROKER_URL"] = "memory://" +os.environ["RADIARCH_RESULT_BACKEND"] = "cache+memory://" +os.environ["RADIARCH_DICOMWEB_URL"] = "" +# Override docker-mode artifact dir (which lives at /data/artifacts inside the +# container) with a local path so the demo writes to the repo, not the root fs. +os.environ["RADIARCH_ARTIFACT_DIR"] = str(_REPO_ROOT / "data" / "artifacts") +# Point at the test fixture that ships with the repo. +_TEST_DATA = ( + _REPO_ROOT + / "tests" + / "opentps" + / "core" + / "opentps-testData" + / "SimpleFantomWithStruct" +) +os.environ["RADIARCH_OPENTPS_DATA_ROOT"] = str(_TEST_DATA) + + +from radiarch.config import get_settings # noqa: E402 +from radiarch.models.geometry import ( # noqa: E402 + GeometryBuildRequest, + HUDensityModel, + PatientRef, +) +from radiarch.services.geometry import GeometryService # noqa: E402 + + +# --------------------------------------------------------------------------- +# Upload helper — mirrors what POST /uploads/dicom does, in-process so the +# demo doesn't need a running server. +# --------------------------------------------------------------------------- + +def _ingest_upload_zip(zip_path: Path) -> str: + """Extract a DICOM ZIP into the configured upload_dir, return an upload_id.""" + import shutil + import uuid + import zipfile + + settings = get_settings() + base = settings.upload_dir or str(Path(settings.artifact_dir) / "uploads") + upload_root = Path(base).expanduser().resolve() + upload_root.mkdir(parents=True, exist_ok=True) + + upload_id = str(uuid.uuid4()) + dest = upload_root / upload_id + dest.mkdir() + + # Refuse zip-slip traversal — same check as the upload endpoint. + with zipfile.ZipFile(zip_path) as zf: + dest_resolved = dest.resolve() + for member in zf.infolist(): + target = (dest / member.filename).resolve() + if dest_resolved not in target.parents and target != dest_resolved: + shutil.rmtree(dest, ignore_errors=True) + raise ValueError(f"Refusing unsafe ZIP entry: {member.filename!r}") + zf.extractall(dest) + + # Quick sanity check. + dcm_count = sum(1 for p in dest.rglob("*.dcm") if p.is_file()) + print(f" extracted {dcm_count} .dcm files into {dest}") + return upload_id + + +# --------------------------------------------------------------------------- +# Pretty-printing helpers +# --------------------------------------------------------------------------- + +BAR = "─" * 64 + + +def _h(label: str) -> None: + print(f"\n{BAR}\n {label}\n{BAR}") + + +def _row(key: str, value) -> None: + print(f" {key:<26} {value}") + + +# --------------------------------------------------------------------------- +# Demo +# --------------------------------------------------------------------------- + +def _parse_upload_arg() -> Optional[Path]: + if "--upload" not in sys.argv: + return None + idx = sys.argv.index("--upload") + if idx + 1 >= len(sys.argv): + print("ERROR: --upload requires a path to a ZIP file", file=sys.stderr) + sys.exit(2) + return Path(sys.argv[idx + 1]).expanduser().resolve() + + +def main() -> None: + upload_zip = _parse_upload_arg() + + _h("Radiarch Geometry Service — Live Demo") + + if upload_zip is not None: + if not upload_zip.is_file(): + print(f"ERROR: upload ZIP not found at {upload_zip}", file=sys.stderr) + sys.exit(1) + _row("source:", "ZIP upload") + _row("zip path:", upload_zip) + upload_id = _ingest_upload_zip(upload_zip) + _row("upload_id:", upload_id) + patient_ref = PatientRef(upload_id=upload_id) + else: + if not _TEST_DATA.exists(): + print(f"ERROR: test data not found at {_TEST_DATA}", file=sys.stderr) + print("Either provide --upload or place the SimpleFantom") + print("sample at the expected location.") + sys.exit(1) + _row("source:", "bundled SimpleFantom (dev fallback)") + _row("test data:", _TEST_DATA.relative_to(_REPO_ROOT)) + patient_ref = PatientRef( + dicom_study_uid="demo-study-001", + ct_series_uid=None, + rtstruct_uid=None, + ) + + # Build the request — same shape an HTTP client would send. + request = GeometryBuildRequest( + patient_ref=patient_ref, + grid_spec=None, # match CT grid + hu_to_density_model=HUDensityModel.stoichiometric, + ) + + _h("Request") + _row("dicom_study_uid:", request.patient_ref.dicom_study_uid) + _row("hu_to_density_model:", request.hu_to_density_model.value) + _row("cache_key:", request.compute_cache_key()[:16] + "…") + + service = GeometryService() + + # First call — could be a cache miss (full build) or a cache hit + # if you've run this script before. + _h("Build #1") + t0 = time.monotonic() + result1 = service.build(request) + t1_ms = (time.monotonic() - t0) * 1000.0 + _row("geometry_id:", result1.geometry_id) + _row("elapsed:", f"{t1_ms:8.1f} ms") + + # Second call — guaranteed cache hit. + _h("Build #2 (same request)") + t0 = time.monotonic() + result2 = service.build(request) + t2_ms = (time.monotonic() - t0) * 1000.0 + _row("geometry_id:", result2.geometry_id) + _row("elapsed:", f"{t2_ms:8.1f} ms") + _row("speedup:", f"{t1_ms / max(t2_ms, 0.01):6.1f}×") + if result1.geometry_id == result2.geometry_id: + print(" ✓ same geometry_id — cache hit confirmed") + + _h("GeometryResult") + _row("modality:", result1.ct_metadata.modality) + _row("patient_name:", result1.ct_metadata.patient_name) + _row("num_slices:", result1.ct_metadata.num_slices) + _row("frame_of_reference_uid:", result1.frame_of_reference_uid or "(none)") + _row("grid spacing (mm):", result1.grid_spec.spacing_mm) + _row("grid origin (mm):", result1.grid_spec.origin_mm) + _row("grid size (vox):", result1.grid_spec.size) + _row("structure_index:", dict(result1.structure_index)) + _row("density NIfTI:", result1.density_grid_uri) + _row("masks NIfTI:", result1.structure_masks_uri) + + if "--show" in sys.argv: + _show_axial_slices(result1) + else: + print(f"\n (pass --show to also render axial slices)") + print() + + +# --------------------------------------------------------------------------- +# Optional visualization +# --------------------------------------------------------------------------- + +def _show_axial_slices(result) -> None: + """Render the middle axial slice — density, masks, overlay.""" + try: + import numpy as np + import SimpleITK as sitk + import matplotlib.pyplot as plt + except ImportError as exc: + print(f"\n (--show needs matplotlib + SimpleITK: {exc})") + return + + density = sitk.GetArrayFromImage(sitk.ReadImage(result.density_grid_uri)) + masks = sitk.GetArrayFromImage(sitk.ReadImage(result.structure_masks_uri)) + z = density.shape[0] // 2 # middle axial slice + + fig, axes = plt.subplots(1, 3, figsize=(12, 4)) + axes[0].imshow(density[z], cmap="gray") + axes[0].set_title(f"Density (slice {z})") + axes[1].imshow(masks[z], cmap="tab10", vmin=0, vmax=10) + axes[1].set_title("Structure masks") + axes[2].imshow(density[z], cmap="gray") + masked = np.ma.masked_where(masks[z] == 0, masks[z]) + axes[2].imshow(masked, cmap="tab10", alpha=0.5, vmin=0, vmax=10) + axes[2].set_title("Overlay") + for ax in axes: + ax.set_xticks([]) + ax.set_yticks([]) + + plt.suptitle( + f"Geometry {result.geometry_id[:8]}… " + f"({', '.join(result.structure_index.keys())})", + fontsize=11, + ) + plt.tight_layout() + plt.show() + + +if __name__ == "__main__": + main() diff --git a/package-lock.json b/package-lock.json new file mode 100644 index 0000000..dc889ce --- /dev/null +++ b/package-lock.json @@ -0,0 +1,6 @@ +{ + "name": "radiarch", + "lockfileVersion": 3, + "requires": true, + "packages": {} +} diff --git a/src/migrations/versions/b1e7c1f3a4f2_geometry_jobs.py b/src/migrations/versions/b1e7c1f3a4f2_geometry_jobs.py new file mode 100644 index 0000000..dd7ef64 --- /dev/null +++ b/src/migrations/versions/b1e7c1f3a4f2_geometry_jobs.py @@ -0,0 +1,49 @@ +"""geometry_jobs + +Revision ID: b1e7c1f3a4f2 +Revises: a0d6cb2919ee +Create Date: 2026-04-19 20:00:00.000000 + +Adds the ``geometry_jobs`` table used by the async-mode Geometry Service +(POST /api/v1/geometry/build -> 202 + job_id; GET /api/v1/geometry/jobs/ +{job_id}). +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "b1e7c1f3a4f2" +down_revision: Union[str, None] = "a0d6cb2919ee" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "geometry_jobs", + sa.Column("id", sa.String(length=36), nullable=False), + sa.Column("cache_key", sa.String(length=64), nullable=False), + sa.Column("state", sa.String(length=20), nullable=False), + sa.Column("progress", sa.Float(), nullable=True), + sa.Column("stage", sa.String(length=32), nullable=True), + sa.Column("message", sa.Text(), nullable=True), + sa.Column("geometry_id", sa.String(length=36), nullable=True), + sa.Column("started_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("finished_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "ix_geometry_jobs_cache_key", + "geometry_jobs", + ["cache_key"], + unique=False, + ) + + +def downgrade() -> None: + op.drop_index("ix_geometry_jobs_cache_key", table_name="geometry_jobs") + op.drop_table("geometry_jobs") diff --git a/src/migrations/versions/c2f8d2a4b5e3_beam_model_jobs.py b/src/migrations/versions/c2f8d2a4b5e3_beam_model_jobs.py new file mode 100644 index 0000000..0530eef --- /dev/null +++ b/src/migrations/versions/c2f8d2a4b5e3_beam_model_jobs.py @@ -0,0 +1,49 @@ +"""beam_model_jobs + +Revision ID: c2f8d2a4b5e3 +Revises: b1e7c1f3a4f2 +Create Date: 2026-05-08 21:00:00.000000 + +Adds the ``beam_model_jobs`` table used by the async-mode Beam Model +Service (POST /api/v1/beam-model/build -> 202 + job_id; GET +/api/v1/beam-model/jobs/{job_id}). Mirrors the geometry_jobs table. +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = "c2f8d2a4b5e3" +down_revision: Union[str, None] = "b1e7c1f3a4f2" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "beam_model_jobs", + sa.Column("id", sa.String(length=36), nullable=False), + sa.Column("cache_key", sa.String(length=64), nullable=False), + sa.Column("state", sa.String(length=20), nullable=False), + sa.Column("progress", sa.Float(), nullable=True), + sa.Column("stage", sa.String(length=32), nullable=True), + sa.Column("message", sa.Text(), nullable=True), + sa.Column("beam_model_id", sa.String(length=36), nullable=True), + sa.Column("started_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("finished_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "ix_beam_model_jobs_cache_key", + "beam_model_jobs", + ["cache_key"], + unique=False, + ) + + +def downgrade() -> None: + op.drop_index("ix_beam_model_jobs_cache_key", table_name="beam_model_jobs") + op.drop_table("beam_model_jobs") diff --git a/src/radiarch/api/routes/beam_model.py b/src/radiarch/api/routes/beam_model.py new file mode 100644 index 0000000..6487558 --- /dev/null +++ b/src/radiarch/api/routes/beam_model.py @@ -0,0 +1,171 @@ +"""FastAPI routes for the Beam Model Service (Service 2). + +Status-code contract for ``POST /beam-model/build``: + +* ``200 OK`` + full :class:`BeamModelResult` — cache hit, no job created. +* ``202 Accepted`` + :class:`BeamModelBuildResponse` carrying ``job_id`` + — cache miss, build dispatched to Celery. Client polls + ``GET /beam-model/jobs/{job_id}`` until ``state == succeeded``, then + fetches ``GET /beam-model/{beam_model_id}``. +* ``422 Unprocessable Entity`` — request validation failed (unknown + ``geometry_id``, modality/params mismatch, unknown machine model id). + +Endpoints +--------- +``POST /api/v1/beam-model/build`` — build / reuse cached. +``GET /api/v1/beam-model/{id}`` — retrieve result. +``GET /api/v1/beam-model/{id}/artifact`` — stream pickled plan. +``DELETE /api/v1/beam-model/{id}`` — remove from cache. +``GET /api/v1/beam-model/jobs/{job_id}`` — async job status. +""" + +from __future__ import annotations + +import os +from functools import lru_cache + +from fastapi import APIRouter, HTTPException, Response, status +from fastapi.responses import FileResponse +from pydantic import BaseModel + +from ...core.store import store +from ...models.beam_model import ( + BeamModelBuildRequest, + BeamModelJobStatus, + BeamModelResult, +) +from ...services.beam_model import BeamModelService +from ...services.beam_persistence import PLAN_FILENAME + +router = APIRouter(prefix="/beam-model", tags=["beam-model"]) + + +@lru_cache(maxsize=1) +def _service() -> BeamModelService: + """Singleton service instance, reused across requests.""" + return BeamModelService() + + +# --------------------------------------------------------------------------- +# Async-dispatch response shape +# --------------------------------------------------------------------------- + +class BeamModelBuildResponse(BaseModel): + """Returned by ``POST /build`` when a Celery job is dispatched.""" + + job_id: str + cache_key: str + state: str = "queued" + message: str = ( + "Build dispatched; poll /beam-model/jobs/{job_id} for progress." + ) + + +# --------------------------------------------------------------------------- +# POST /build +# --------------------------------------------------------------------------- + +@router.post( + "/build", + summary="Build (or reuse cached) beam model from a geometry + beam set.", + responses={ + 200: {"description": "Cache hit — returned the existing beam model inline."}, + 202: {"description": "Cache miss — Celery job dispatched."}, + 422: {"description": "Request validation error."}, + }, +) +async def build_beam_model(request: BeamModelBuildRequest, response: Response): + try: + cache_key = request.compute_cache_key() + service = _service() + cached = service.store.lookup_by_cache_key(cache_key) + if cached is not None: + response.status_code = status.HTTP_200_OK + return cached + + job = store.create_beam_model_job(cache_key) + # Lazy task import: this module imports cleanly even if Celery + # isn't configured (tests without a broker). + from ...tasks.beam_model_tasks import build_beam_model_job + + build_beam_model_job.delay(job.id, request.model_dump(mode="json")) + response.status_code = status.HTTP_202_ACCEPTED + return BeamModelBuildResponse(job_id=job.id, cache_key=cache_key) + + except ValueError as exc: + raise HTTPException(status_code=422, detail=str(exc)) from exc + + +# --------------------------------------------------------------------------- +# GET /jobs/{job_id} +# --------------------------------------------------------------------------- + +@router.get( + "/jobs/{job_id}", + response_model=BeamModelJobStatus, + summary="Poll an async beam-model build job.", +) +async def get_beam_model_job(job_id: str) -> BeamModelJobStatus: + job = store.get_beam_model_job(job_id) + if job is None: + raise HTTPException( + status_code=404, detail=f"Beam model job not found: {job_id}" + ) + return job + + +# --------------------------------------------------------------------------- +# Result retrieval + artifact stream + delete +# --------------------------------------------------------------------------- + +@router.get( + "/{beam_model_id}", + response_model=BeamModelResult, + summary="Retrieve completed beam-model metadata.", +) +async def get_beam_model(beam_model_id: str) -> BeamModelResult: + result = _service().store.get_by_id(beam_model_id) + if result is None: + raise HTTPException( + status_code=404, detail=f"Beam model not found: {beam_model_id}" + ) + return result + + +@router.delete( + "/{beam_model_id}", + status_code=204, + response_class=Response, + summary="Delete a cached beam model.", +) +async def delete_beam_model(beam_model_id: str): + svc = _service() + if svc.store.get_by_id(beam_model_id) is None: + raise HTTPException( + status_code=404, detail=f"Beam model not found: {beam_model_id}" + ) + svc.store.delete_by_id(beam_model_id) + return Response(status_code=204) + + +@router.get( + "/{beam_model_id}/artifact", + summary="Stream the serialized OpenTPS plan (pickled).", + response_class=FileResponse, +) +async def get_beam_model_artifact(beam_model_id: str): + svc = _service() + if svc.store.get_by_id(beam_model_id) is None: + raise HTTPException( + status_code=404, detail=f"Beam model not found: {beam_model_id}" + ) + plan_path = svc.store.base_dir / beam_model_id / PLAN_FILENAME + if not os.path.isfile(plan_path): + raise HTTPException( + status_code=410, detail=f"{PLAN_FILENAME} no longer on disk" + ) + return FileResponse( + path=str(plan_path), + media_type="application/octet-stream", + filename=PLAN_FILENAME, + ) diff --git a/src/radiarch/api/routes/geometry.py b/src/radiarch/api/routes/geometry.py index bc0cfc1..69d511d 100644 --- a/src/radiarch/api/routes/geometry.py +++ b/src/radiarch/api/routes/geometry.py @@ -1,27 +1,45 @@ """FastAPI routes for the Geometry Service (Service 1). -v1 is synchronous — ``POST /geometry/build`` runs the build in-process -and returns the :class:`GeometryResult` directly. The ``/jobs`` variant -listed in the spec comes with the async-mode PR (Celery + DB); in the -meantime, we expose a placeholder so existing clients can evolve. +The service runs either synchronously (cache hits — the answer is +already on disk) or asynchronously (cache misses — a Celery task builds +the geometry in the background and the client polls the jobs endpoint +for progress). Endpoints --------- -``POST /api/v1/geometry/build`` — build (or reuse cached) geometry. -``GET /api/v1/geometry/{id}`` — retrieve cached geometry metadata. -``GET /api/v1/geometry/{id}/density``— stream density NIfTI. -``GET /api/v1/geometry/{id}/masks`` — stream multi-label mask NIfTI. +``POST /api/v1/geometry/build`` — build or reuse cached. +``GET /api/v1/geometry/{id}`` — retrieve geometry metadata. +``GET /api/v1/geometry/{id}/density`` — stream density NIfTI. +``GET /api/v1/geometry/{id}/masks`` — stream multi-label mask NIfTI. +``DELETE /api/v1/geometry/{id}`` — remove a cached geometry. +``GET /api/v1/geometry/jobs/{job_id}`` — async job status / progress. + +Status-code contract for ``POST /build``: + +* ``200 OK`` + full :class:`GeometryResult` — cache hit, no job created. +* ``202 Accepted`` + :class:`GeometryBuildResponse` with ``job_id`` — + cache miss, build dispatched to Celery. Client polls the jobs + endpoint and then fetches the geometry once ``state=succeeded``. +* ``422 Unprocessable Entity`` — request-level validation error + (e.g. underspecified grid, rotated affine). """ from __future__ import annotations import os from functools import lru_cache +from typing import Optional -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, Response, status from fastapi.responses import FileResponse +from pydantic import BaseModel -from ...models.geometry import GeometryBuildRequest, GeometryResult +from ...core.store import store +from ...models.geometry import ( + GeometryBuildRequest, + GeometryJobStatus, + GeometryResult, +) from ...services.geometry import GeometryService from ...services.persistence import DENSITY_FILENAME, MASKS_FILENAME @@ -35,22 +53,81 @@ def _service() -> GeometryService: return GeometryService() +# --------------------------------------------------------------------------- +# Response shape for the async build case +# --------------------------------------------------------------------------- + +class GeometryBuildResponse(BaseModel): + """Returned by ``POST /build`` when a Celery job is dispatched. + + When the cache hits, the endpoint returns a full + :class:`GeometryResult` instead — ``response_model=Union[...]`` on + the route union-types the OpenAPI schema accordingly. + """ + + job_id: str + cache_key: str + state: str = "queued" + message: str = "Build dispatched; poll /geometry/jobs/{job_id} for progress." + + +# --------------------------------------------------------------------------- +# POST /build — cache-hit fast path or async dispatch +# --------------------------------------------------------------------------- + @router.post( "/build", - response_model=GeometryResult, summary="Build (or reuse cached) geometry from a DICOM study.", + responses={ + 200: {"description": "Cache hit — returned the existing geometry inline."}, + 202: {"description": "Cache miss — Celery job dispatched; poll the jobs endpoint."}, + 422: {"description": "Request validation error."}, + }, ) -async def build_geometry(request: GeometryBuildRequest) -> GeometryResult: +async def build_geometry(request: GeometryBuildRequest, response: Response): try: - return _service().build(request) - except FileNotFoundError as exc: - raise HTTPException(status_code=404, detail=str(exc)) from exc + cache_key = request.compute_cache_key() + service = _service() + cached = service.store.lookup_by_cache_key(cache_key) + if cached is not None: + # Fast path: no job, no Celery round-trip. Return 200 + full result. + response.status_code = status.HTTP_200_OK + return cached + + # Cache miss → create job row and dispatch to Celery. + job = store.create_geometry_job(cache_key) + # Lazy import of the task so this module imports cleanly even + # when Celery isn't configured (e.g., unit tests without broker). + from ...tasks.geometry_tasks import build_geometry_job + + build_geometry_job.delay(job.id, request.model_dump(mode="json")) + response.status_code = status.HTTP_202_ACCEPTED + return GeometryBuildResponse(job_id=job.id, cache_key=cache_key) + except ValueError as exc: - # Request-level validation problems — e.g., rotated affine, - # underspecified grid. Bubble up as 422 so clients can act. raise HTTPException(status_code=422, detail=str(exc)) from exc +# --------------------------------------------------------------------------- +# GET /jobs/{job_id} — poll progress +# --------------------------------------------------------------------------- + +@router.get( + "/jobs/{job_id}", + response_model=GeometryJobStatus, + summary="Poll an async geometry build job.", +) +async def get_geometry_job(job_id: str) -> GeometryJobStatus: + job = store.get_geometry_job(job_id) + if job is None: + raise HTTPException(status_code=404, detail=f"Geometry job not found: {job_id}") + return job + + +# --------------------------------------------------------------------------- +# Geometry retrieval + streams +# --------------------------------------------------------------------------- + @router.get( "/{geometry_id}", response_model=GeometryResult, @@ -63,6 +140,24 @@ async def get_geometry(geometry_id: str) -> GeometryResult: return result +@router.delete( + "/{geometry_id}", + status_code=204, + response_class=Response, + summary="Delete a cached geometry (NIfTI files + metadata + cache index entry).", +) +async def delete_geometry(geometry_id: str): + svc = _service() + # Load first so we can fail with 404 if it never existed. Also gives + # us the ``cache_key`` to scrub from the index. + result = svc.store.get_by_id(geometry_id) + if result is None: + raise HTTPException(status_code=404, detail=f"Geometry not found: {geometry_id}") + + svc.store.delete_by_id(geometry_id) + return Response(status_code=204) + + @router.get( "/{geometry_id}/density", summary="Stream the density NIfTI volume.", diff --git a/src/radiarch/api/routes/uploads.py b/src/radiarch/api/routes/uploads.py new file mode 100644 index 0000000..0892047 --- /dev/null +++ b/src/radiarch/api/routes/uploads.py @@ -0,0 +1,250 @@ +"""DICOM upload endpoints — feeder for Service 1. + +Lets a client POST a single ZIP containing one patient study (CT series + +RTSTRUCT) and get back an ``upload_id`` that can be passed to +``POST /geometry/build`` via ``patient_ref.upload_id``. The ZIP is +extracted under ``{settings.upload_dir}/{upload_id}/`` and the contents +walked once to give the caller a quick sanity check (how many CT +slices, how many RTSTRUCTs, total bytes). + +Endpoints +--------- +``POST /api/v1/uploads/dicom`` — accept a multipart ZIP. +``GET /api/v1/uploads/{upload_id}`` — inspect what was uploaded. +``DELETE /api/v1/uploads/{upload_id}`` — remove the extracted bundle. +""" + +from __future__ import annotations + +import os +import shutil +import uuid +import zipfile +from functools import lru_cache +from pathlib import Path +from typing import List, Optional + +from fastapi import APIRouter, File, HTTPException, Response, UploadFile, status +from loguru import logger +from pydantic import BaseModel, Field + +from ...config import get_settings + +router = APIRouter(prefix="/uploads", tags=["uploads"]) + + +# --------------------------------------------------------------------------- +# Settings helper +# --------------------------------------------------------------------------- + +@lru_cache(maxsize=1) +def _upload_root() -> Path: + """Resolve the configured upload directory, creating it if needed. + + When ``settings.upload_dir`` is empty (the default), falls back to + ``{settings.artifact_dir}/uploads``. + """ + settings = get_settings() + base = settings.upload_dir or str(Path(settings.artifact_dir) / "uploads") + path = Path(base).expanduser().resolve() + path.mkdir(parents=True, exist_ok=True) + return path + + +def _upload_path(upload_id: str) -> Path: + return _upload_root() / upload_id + + +# --------------------------------------------------------------------------- +# Response shapes +# --------------------------------------------------------------------------- + +class UploadResponse(BaseModel): + upload_id: str = Field(..., description="Pass this as patient_ref.upload_id.") + file_count: int = Field(..., description="Total number of files extracted.") + dicom_count: int = Field(..., description="Files with a .dcm extension.") + ct_slice_count: int = Field(..., description="Best-effort count of CT instances.") + rtstruct_count: int = Field(..., description="Best-effort count of RTSTRUCT instances.") + total_bytes: int = Field(..., description="Sum of all extracted file sizes.") + storage_path: str = Field(..., description="Server-side directory holding the extracted bundle.") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +# DICOM modality is at tag (0008,0060). Reading it requires pydicom; we +# only need it to give the upload response useful counts. Best-effort: +# if pydicom isn't available at request time, we still accept the upload +# and report zeros for the modality-specific fields. + +def _scan_directory(root: Path) -> UploadResponse: + file_count = 0 + dicom_count = 0 + ct_slice_count = 0 + rtstruct_count = 0 + total_bytes = 0 + + try: + import pydicom # type: ignore + has_pydicom = True + except Exception: # pragma: no cover — pydicom is a hard dep but be safe + has_pydicom = False + + for path in root.rglob("*"): + if not path.is_file(): + continue + file_count += 1 + total_bytes += path.stat().st_size + is_dcm = path.suffix.lower() == ".dcm" + if not is_dcm: + # Some clinical exports drop the .dcm extension. Sniff for it. + try: + with path.open("rb") as fh: + fh.seek(128) + if fh.read(4) == b"DICM": + is_dcm = True + except OSError: + pass + if not is_dcm: + continue + dicom_count += 1 + if not has_pydicom: + continue + try: + ds = pydicom.dcmread(str(path), stop_before_pixels=True, force=True) + modality = (getattr(ds, "Modality", "") or "").upper() + if modality == "CT": + ct_slice_count += 1 + elif modality == "RTSTRUCT": + rtstruct_count += 1 + except Exception as exc: # pragma: no cover — defensive + logger.debug("Could not parse DICOM modality for %s: %s", path, exc) + + return UploadResponse( + upload_id="", # filled in by caller + file_count=file_count, + dicom_count=dicom_count, + ct_slice_count=ct_slice_count, + rtstruct_count=rtstruct_count, + total_bytes=total_bytes, + storage_path=str(root), + ) + + +def _safe_extract_zip(zip_path: Path, dest: Path) -> None: + """Extract ``zip_path`` into ``dest``, refusing zip-slip traversal.""" + with zipfile.ZipFile(zip_path) as zf: + dest_resolved = dest.resolve() + for member in zf.infolist(): + target = (dest / member.filename).resolve() + if dest_resolved not in target.parents and target != dest_resolved: + raise HTTPException( + status_code=400, + detail=f"Refusing unsafe ZIP entry: {member.filename!r}", + ) + zf.extractall(dest) + + +# --------------------------------------------------------------------------- +# POST /uploads/dicom +# --------------------------------------------------------------------------- + +@router.post( + "/dicom", + response_model=UploadResponse, + status_code=status.HTTP_201_CREATED, + summary="Upload a ZIP of one patient's DICOM study (CT + RTSTRUCT).", + responses={ + 201: {"description": "Upload accepted; extracted bundle is ready."}, + 400: {"description": "File missing, not a ZIP, or unsafe entries."}, + 413: {"description": "Upload exceeds configured size limit."}, + }, +) +async def upload_dicom(file: UploadFile = File(...)) -> UploadResponse: + if not file.filename: + raise HTTPException(status_code=400, detail="Upload must include a filename.") + if not file.filename.lower().endswith(".zip"): + raise HTTPException( + status_code=400, + detail=f"Expected a .zip file, got {file.filename!r}.", + ) + + upload_id = str(uuid.uuid4()) + dest = _upload_path(upload_id) + dest.mkdir(parents=True, exist_ok=False) + + # Stream the upload to a temporary path inside the upload dir, then + # extract — keeps everything on one filesystem so we don't pay a + # cross-FS copy. + tmp_zip = dest / "_incoming.zip" + try: + with tmp_zip.open("wb") as out: + shutil.copyfileobj(file.file, out) + try: + _safe_extract_zip(tmp_zip, dest) + except zipfile.BadZipFile as exc: + shutil.rmtree(dest, ignore_errors=True) + raise HTTPException( + status_code=400, detail=f"Not a valid ZIP archive: {exc}" + ) from exc + finally: + try: + tmp_zip.unlink() + except FileNotFoundError: + pass + + summary = _scan_directory(dest) + if summary.dicom_count == 0: + shutil.rmtree(dest, ignore_errors=True) + raise HTTPException( + status_code=400, + detail="ZIP contained no DICOM files (.dcm). Upload rejected.", + ) + + summary = summary.model_copy(update={"upload_id": upload_id}) + logger.info( + "Upload %s accepted: %d files (%d DICOM, %d CT, %d RTSTRUCT, %.1f MB)", + upload_id, + summary.file_count, + summary.dicom_count, + summary.ct_slice_count, + summary.rtstruct_count, + summary.total_bytes / (1024 * 1024), + ) + return summary + + +# --------------------------------------------------------------------------- +# GET /uploads/{upload_id} +# --------------------------------------------------------------------------- + +@router.get( + "/{upload_id}", + response_model=UploadResponse, + summary="Inspect a previously-uploaded DICOM bundle.", +) +async def get_upload(upload_id: str) -> UploadResponse: + path = _upload_path(upload_id) + if not path.is_dir(): + raise HTTPException(status_code=404, detail=f"Upload not found: {upload_id}") + summary = _scan_directory(path) + return summary.model_copy(update={"upload_id": upload_id}) + + +# --------------------------------------------------------------------------- +# DELETE /uploads/{upload_id} +# --------------------------------------------------------------------------- + +@router.delete( + "/{upload_id}", + status_code=204, + response_class=Response, + summary="Delete an extracted DICOM bundle.", +) +async def delete_upload(upload_id: str): + path = _upload_path(upload_id) + if not path.is_dir(): + raise HTTPException(status_code=404, detail=f"Upload not found: {upload_id}") + shutil.rmtree(path) + return Response(status_code=204) diff --git a/src/radiarch/app.py b/src/radiarch/app.py index 6f6d4ea..f7a9fa2 100644 --- a/src/radiarch/app.py +++ b/src/radiarch/app.py @@ -4,7 +4,7 @@ from fastapi.middleware.cors import CORSMiddleware from .config import get_settings -from .api.routes import info, plans, jobs, artifacts, workflows, sessions, simulations, geometry +from .api.routes import info, plans, jobs, artifacts, workflows, sessions, simulations, geometry, beam_model, uploads from .adapters import build_orthanc_adapter from .core.database import init_db @@ -42,6 +42,8 @@ def create_app() -> FastAPI: app.include_router(sessions.router, prefix=settings.api_prefix) app.include_router(simulations.router, prefix=settings.api_prefix) app.include_router(geometry.router, prefix=settings.api_prefix) + app.include_router(beam_model.router, prefix=settings.api_prefix) + app.include_router(uploads.router, prefix=settings.api_prefix) @app.get("/") async def root(): diff --git a/src/radiarch/config.py b/src/radiarch/config.py index 7838f33..bb661dd 100644 --- a/src/radiarch/config.py +++ b/src/radiarch/config.py @@ -26,6 +26,13 @@ class Settings(BaseSettings): # Database and artifact storage database_url: str = Field(default="", description="Leave empty to use InMemoryStore; set to sqlite:///./radiarch.db or postgresql+psycopg://... for persistence") artifact_dir: str = Field(default="./data/artifacts") + upload_dir: str = Field( + default="", + description=( + "Where uploaded DICOM ZIPs get extracted to. " + "Empty (the default) resolves to {artifact_dir}/uploads at runtime." + ), + ) # Session TTL session_ttl: int = Field(default=3600, description="Session expiration in seconds") diff --git a/src/radiarch/core/db_models.py b/src/radiarch/core/db_models.py index 593f0b9..1b951e2 100644 --- a/src/radiarch/core/db_models.py +++ b/src/radiarch/core/db_models.py @@ -63,3 +63,50 @@ class ArtifactRow(Base): created_at = Column(DateTime(timezone=True), default=_utcnow, nullable=False) plan = relationship("PlanRow", back_populates="artifacts") + + +class GeometryJobRow(Base): + """Async tracking row for one ``POST /geometry/build`` invocation. + + Unlike ``JobRow`` this isn't tied to a plan. ``cache_key`` is indexed + so we can look up in-flight builds for the same inputs (future dedup + work) and ``geometry_id`` is populated when the build succeeds so + clients polling the job endpoint can follow the link to the result. + """ + + __tablename__ = "geometry_jobs" + + id = Column(String(36), primary_key=True) + cache_key = Column(String(64), nullable=False, index=True) + state = Column(String(20), nullable=False, default="queued") + progress = Column(Float, default=0.0) + # Persist stage across restarts — clients polling want to see the + # current pipeline phase (loading_dicom / rasterizing_contours / …). + stage = Column(String(32), nullable=True, default="queued") + message = Column(Text, nullable=True) + geometry_id = Column(String(36), nullable=True) + started_at = Column(DateTime(timezone=True), nullable=True) + finished_at = Column(DateTime(timezone=True), nullable=True) + created_at = Column(DateTime(timezone=True), default=_utcnow, nullable=False) + + +class BeamModelJobRow(Base): + """Async tracking row for one ``POST /beam-model/build`` invocation. + + Mirrors :class:`GeometryJobRow` exactly — same fields, different + table. ``beam_model_id`` is populated when the build succeeds so + clients polling the jobs endpoint can deep-link to the result. + """ + + __tablename__ = "beam_model_jobs" + + id = Column(String(36), primary_key=True) + cache_key = Column(String(64), nullable=False, index=True) + state = Column(String(20), nullable=False, default="queued") + progress = Column(Float, default=0.0) + stage = Column(String(32), nullable=True, default="queued") + message = Column(Text, nullable=True) + beam_model_id = Column(String(36), nullable=True) + started_at = Column(DateTime(timezone=True), nullable=True) + finished_at = Column(DateTime(timezone=True), nullable=True) + created_at = Column(DateTime(timezone=True), default=_utcnow, nullable=False) diff --git a/src/radiarch/core/store.py b/src/radiarch/core/store.py index 6999cf7..c9abc9e 100644 --- a/src/radiarch/core/store.py +++ b/src/radiarch/core/store.py @@ -14,6 +14,8 @@ from ..models.plan import PlanDetail, PlanRequest, PlanSummary from ..models.job import JobStatus, JobState from ..models.artifact import ArtifactRecord +from ..models.geometry import GeometryJobStatus, GeometryStage +from ..models.beam_model import BeamModelJobStatus, BeamModelStage def _utcnow(): @@ -74,6 +76,50 @@ def delete_plan(self, plan_id: str) -> bool: """Delete a plan and all associated jobs/artifacts. Returns True if deleted.""" ... + # ---- Geometry async jobs ----------------------------------------- + + @abc.abstractmethod + def create_geometry_job(self, cache_key: str) -> GeometryJobStatus: + """Record a new queued geometry build. Returns the new status row.""" + ... + + @abc.abstractmethod + def get_geometry_job(self, job_id: str) -> Optional[GeometryJobStatus]: ... + + @abc.abstractmethod + def update_geometry_job( + self, + job_id: str, + *, + state: Optional[JobState] = None, + progress: Optional[float] = None, + message: Optional[str] = None, + stage: Optional[GeometryStage] = None, + geometry_id: Optional[str] = None, + ) -> Optional[GeometryJobStatus]: ... + + # ---- Beam model async jobs --------------------------------------- + + @abc.abstractmethod + def create_beam_model_job(self, cache_key: str) -> BeamModelJobStatus: + """Record a new queued beam-model build.""" + ... + + @abc.abstractmethod + def get_beam_model_job(self, job_id: str) -> Optional[BeamModelJobStatus]: ... + + @abc.abstractmethod + def update_beam_model_job( + self, + job_id: str, + *, + state: Optional[JobState] = None, + progress: Optional[float] = None, + message: Optional[str] = None, + stage: Optional[BeamModelStage] = None, + beam_model_id: Optional[str] = None, + ) -> Optional[BeamModelJobStatus]: ... + # --------------------------------------------------------------------------- # In-memory implementation (used in tests & dev) @@ -84,6 +130,8 @@ def __init__(self): self._plans: Dict[str, PlanDetail] = {} self._jobs: Dict[str, JobStatus] = {} self._artifacts: Dict[str, ArtifactRecord] = {} + self._geometry_jobs: Dict[str, GeometryJobStatus] = {} + self._beam_model_jobs: Dict[str, BeamModelJobStatus] = {} def create_plan(self, payload: PlanRequest) -> tuple[PlanDetail, JobStatus]: plan_id = str(uuid.uuid4()) @@ -209,6 +257,110 @@ def delete_plan(self, plan_id: str) -> bool: self._artifacts.pop(aid, None) return True + # ---- Geometry async jobs ----------------------------------------- + + def create_geometry_job(self, cache_key: str) -> GeometryJobStatus: + job_id = str(uuid.uuid4()) + now = _utcnow() + job = GeometryJobStatus( + id=job_id, + cache_key=cache_key, + state=JobState.queued, + progress=0.0, + stage=GeometryStage.queued, + created_at=now, + ) + self._geometry_jobs[job_id] = job + return job + + def get_geometry_job(self, job_id: str) -> Optional[GeometryJobStatus]: + return self._geometry_jobs.get(job_id) + + def update_geometry_job( + self, + job_id: str, + *, + state: Optional[JobState] = None, + progress: Optional[float] = None, + message: Optional[str] = None, + stage: Optional[GeometryStage] = None, + geometry_id: Optional[str] = None, + ) -> Optional[GeometryJobStatus]: + job = self._geometry_jobs.get(job_id) + if not job: + return None + data = job.model_dump() + now = _utcnow() + if state: + data["state"] = state + if state == JobState.running and not data.get("started_at"): + data["started_at"] = now + if state in {JobState.succeeded, JobState.failed, JobState.cancelled}: + data["finished_at"] = now + if progress is not None: + data["progress"] = progress + if message is not None: + data["message"] = message + if stage is not None: + data["stage"] = stage + if geometry_id is not None: + data["geometry_id"] = geometry_id + updated = GeometryJobStatus(**data) + self._geometry_jobs[job_id] = updated + return updated + + # ---- Beam model async jobs --------------------------------------- + + def create_beam_model_job(self, cache_key: str) -> BeamModelJobStatus: + job_id = str(uuid.uuid4()) + now = _utcnow() + job = BeamModelJobStatus( + id=job_id, + cache_key=cache_key, + state=JobState.queued, + progress=0.0, + stage=BeamModelStage.queued, + created_at=now, + ) + self._beam_model_jobs[job_id] = job + return job + + def get_beam_model_job(self, job_id: str) -> Optional[BeamModelJobStatus]: + return self._beam_model_jobs.get(job_id) + + def update_beam_model_job( + self, + job_id: str, + *, + state: Optional[JobState] = None, + progress: Optional[float] = None, + message: Optional[str] = None, + stage: Optional[BeamModelStage] = None, + beam_model_id: Optional[str] = None, + ) -> Optional[BeamModelJobStatus]: + job = self._beam_model_jobs.get(job_id) + if not job: + return None + data = job.model_dump() + now = _utcnow() + if state: + data["state"] = state + if state == JobState.running and not data.get("started_at"): + data["started_at"] = now + if state in {JobState.succeeded, JobState.failed, JobState.cancelled}: + data["finished_at"] = now + if progress is not None: + data["progress"] = progress + if message is not None: + data["message"] = message + if stage is not None: + data["stage"] = stage + if beam_model_id is not None: + data["beam_model_id"] = beam_model_id + updated = BeamModelJobStatus(**data) + self._beam_model_jobs[job_id] = updated + return updated + # --------------------------------------------------------------------------- # SQL implementation (production) @@ -416,6 +568,180 @@ def delete_plan(self, plan_id: str) -> bool: finally: session.close() + # ---- Geometry async jobs ----------------------------------------- + + def create_geometry_job(self, cache_key: str) -> GeometryJobStatus: + from .db_models import GeometryJobRow + + job_id = str(uuid.uuid4()) + now = _utcnow() + session = self._session() + try: + row = GeometryJobRow( + id=job_id, + cache_key=cache_key, + state=JobState.queued.value, + progress=0.0, + stage=GeometryStage.queued.value, + created_at=now, + ) + session.add(row) + session.commit() + return self._geometry_job_row_to_status(row) + finally: + session.close() + + def get_geometry_job(self, job_id: str) -> Optional[GeometryJobStatus]: + from .db_models import GeometryJobRow + + session = self._session() + try: + row = session.query(GeometryJobRow).filter_by(id=job_id).first() + if not row: + return None + return self._geometry_job_row_to_status(row) + finally: + session.close() + + def update_geometry_job( + self, + job_id: str, + *, + state: Optional[JobState] = None, + progress: Optional[float] = None, + message: Optional[str] = None, + stage: Optional[GeometryStage] = None, + geometry_id: Optional[str] = None, + ) -> Optional[GeometryJobStatus]: + from .db_models import GeometryJobRow + + session = self._session() + try: + row = session.query(GeometryJobRow).filter_by(id=job_id).first() + if not row: + return None + now = _utcnow() + if state: + row.state = state.value if hasattr(state, "value") else state + if state == JobState.running and not row.started_at: + row.started_at = now + if state in {JobState.succeeded, JobState.failed, JobState.cancelled}: + row.finished_at = now + if progress is not None: + row.progress = progress + if message is not None: + row.message = message + if stage is not None: + row.stage = stage.value if hasattr(stage, "value") else stage + if geometry_id is not None: + row.geometry_id = geometry_id + session.commit() + return self._geometry_job_row_to_status(row) + finally: + session.close() + + @staticmethod + def _geometry_job_row_to_status(row) -> GeometryJobStatus: + return GeometryJobStatus( + id=row.id, + cache_key=row.cache_key, + state=row.state, + progress=row.progress or 0.0, + stage=row.stage, + message=row.message, + geometry_id=row.geometry_id, + started_at=row.started_at, + finished_at=row.finished_at, + created_at=row.created_at, + ) + + # ---- Beam model async jobs --------------------------------------- + + def create_beam_model_job(self, cache_key: str) -> BeamModelJobStatus: + from .db_models import BeamModelJobRow + + job_id = str(uuid.uuid4()) + now = _utcnow() + session = self._session() + try: + row = BeamModelJobRow( + id=job_id, + cache_key=cache_key, + state=JobState.queued.value, + progress=0.0, + stage=BeamModelStage.queued.value, + created_at=now, + ) + session.add(row) + session.commit() + return self._beam_model_job_row_to_status(row) + finally: + session.close() + + def get_beam_model_job(self, job_id: str) -> Optional[BeamModelJobStatus]: + from .db_models import BeamModelJobRow + + session = self._session() + try: + row = session.query(BeamModelJobRow).filter_by(id=job_id).first() + if not row: + return None + return self._beam_model_job_row_to_status(row) + finally: + session.close() + + def update_beam_model_job( + self, + job_id: str, + *, + state: Optional[JobState] = None, + progress: Optional[float] = None, + message: Optional[str] = None, + stage: Optional[BeamModelStage] = None, + beam_model_id: Optional[str] = None, + ) -> Optional[BeamModelJobStatus]: + from .db_models import BeamModelJobRow + + session = self._session() + try: + row = session.query(BeamModelJobRow).filter_by(id=job_id).first() + if not row: + return None + now = _utcnow() + if state: + row.state = state.value if hasattr(state, "value") else state + if state == JobState.running and not row.started_at: + row.started_at = now + if state in {JobState.succeeded, JobState.failed, JobState.cancelled}: + row.finished_at = now + if progress is not None: + row.progress = progress + if message is not None: + row.message = message + if stage is not None: + row.stage = stage.value if hasattr(stage, "value") else stage + if beam_model_id is not None: + row.beam_model_id = beam_model_id + session.commit() + return self._beam_model_job_row_to_status(row) + finally: + session.close() + + @staticmethod + def _beam_model_job_row_to_status(row) -> BeamModelJobStatus: + return BeamModelJobStatus( + id=row.id, + cache_key=row.cache_key, + state=row.state, + progress=row.progress or 0.0, + stage=row.stage, + message=row.message, + beam_model_id=row.beam_model_id, + started_at=row.started_at, + finished_at=row.finished_at, + created_at=row.created_at, + ) + # -- helpers -- @staticmethod diff --git a/src/radiarch/core/workflows/_helpers.py b/src/radiarch/core/workflows/_helpers.py index 3121289..e4d85fb 100644 --- a/src/radiarch/core/workflows/_helpers.py +++ b/src/radiarch/core/workflows/_helpers.py @@ -80,6 +80,13 @@ def load_ct_and_patient(data_root: Optional[str] = None): # --------------------------------------------------------------------------- # Calibration & BDL # --------------------------------------------------------------------------- +# +# These two helpers used to load the BDL and MCsquare calibration directly. +# Service 2 (Beam Model) introduced ``ProtonMachineModel`` — a pluggable +# abstraction for per-machine calibration data — so both helpers are now +# thin shims that delegate to the canonical loader. Their public signatures +# are unchanged so legacy workflow modules (planner, dose, contour) keep +# working without modification. def setup_calibration(): """Load the default MCsquare CT calibration. @@ -87,17 +94,10 @@ def setup_calibration(): Returns: (calibration, mcsquare_path) tuple """ - import opentps.core.processing.doseCalculation.protons.MCsquare as MCsquareModule - from opentps.core.data.CTCalibrations.MCsquareCalibration._mcsquareCTCalibration import MCsquareCTCalibration + from ...services.machine_model import ProtonMachineModel - mcsquare_path = str(MCsquareModule.__path__[0]) - scanner_path = os.path.join(mcsquare_path, "Scanners", "UCL_Toshiba") - calibration = MCsquareCTCalibration.fromFiles( - huDensityFile=os.path.join(scanner_path, "HU_Density_Conversion.txt"), - huMaterialFile=os.path.join(scanner_path, "HU_Material_Conversion.txt"), - materialsPath=os.path.join(mcsquare_path, "Materials"), - ) - return calibration, mcsquare_path + mm = ProtonMachineModel.from_default() + return mm.calibration, mm.mcsquare_path def load_bdl(): @@ -106,21 +106,9 @@ def load_bdl(): Returns: BDL object from mcsquareIO.readBDL """ - import opentps - import opentps.core.processing.doseCalculation.protons.MCsquare as MCsquareModule - from opentps.core.io import mcsquareIO - - base_path = os.path.dirname(opentps.__file__) - bdl_path = os.path.join( - base_path, "core", "processing", "doseCalculation", - "protons", "MCsquare", "BDL", "BDL_default_DN_RangeShifter.txt", - ) - if not os.path.exists(bdl_path): - bdl_path = os.path.join( - os.path.dirname(MCsquareModule.__file__), - "BDL", "BDL_default_DN_RangeShifter.txt", - ) - return mcsquareIO.readBDL(bdl_path) + from ...services.machine_model import ProtonMachineModel + + return ProtonMachineModel.from_default().bdl # --------------------------------------------------------------------------- diff --git a/src/radiarch/models/beam_model.py b/src/radiarch/models/beam_model.py new file mode 100644 index 0000000..d5d0c19 --- /dev/null +++ b/src/radiarch/models/beam_model.py @@ -0,0 +1,303 @@ +"""Pydantic I/O models for the Beam Model Service (Service 2). + +A "beam model" is the modality-specific representation of deliverable +radiation elements for a treatment plan: + +* For protons (PROTON_PBS), one element is a *spot* — a discrete + ``(x, y, energy)`` triple in beam's-eye-view coordinates, organized + into energy layers and scanned across the target. +* For photons (PHOTON_IMRT), one element is a *beamlet* — a 2D pixel + in a regular grid in beam's-eye-view, modulated by an MLC aperture. + +These two physical concepts are unified under the ``FluenceElementSet`` +abstraction: each element produces one independent unit of dose deposit +downstream, and the dose engine doesn't need to know which modality +made it. + +See ``docs/tps_services_implementation_plan.md`` (Service 2) for the +full specification. +""" + +from __future__ import annotations + +import hashlib +import json +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple + +from pydantic import BaseModel, Field, field_validator, model_validator + +from .job import JobState + + +# --------------------------------------------------------------------------- +# Enums +# --------------------------------------------------------------------------- + +class Modality(str, Enum): + """Supported treatment modalities. + + PHOTON_IMRT — Intensity-Modulated Radiation Therapy with photons, + delivered as a fluence grid through an MLC. + PROTON_PBS — Pencil-Beam Scanning with protons, delivered as a + discrete spot map in energy layers. + """ + + photon_imrt = "PHOTON_IMRT" + proton_pbs = "PROTON_PBS" + + +class BeamModelStage(str, Enum): + """Stages a :class:`BeamModelService.build` call passes through. + + Reported via the ``stage`` field on ``BeamModelJobStatus`` so + clients polling ``GET /beam-model/jobs/{id}`` can render progress. + """ + + queued = "queued" + loading_geometry = "loading_geometry" + loading_machine_model = "loading_machine_model" + building_beams = "building_beams" + computing_elements = "computing_elements" + persisting = "persisting" + done = "done" + + +# --------------------------------------------------------------------------- +# Beam set specification +# --------------------------------------------------------------------------- + +class BeamSpec(BaseModel): + """One beam in a treatment plan — gantry, couch, collimator angles.""" + + beam_id: str = Field(..., min_length=1, max_length=64) + gantry_deg: float = Field(..., ge=0.0, lt=360.0, + description="IEC 61217 gantry angle, [0, 360).") + couch_deg: float = Field(default=0.0, ge=-180.0, le=180.0, + description="IEC 61217 couch angle, [-180, 180].") + collimator_deg: float = Field(default=0.0, ge=-180.0, le=180.0, + description="Collimator rotation, [-180, 180].") + + +class BeamSetSpec(BaseModel): + """The geometric beam configuration for a plan. + + ``isocenter_mm`` is in patient LPS coordinates. ``beams`` carries one + entry per delivery direction (1–9 beams supported, matching the + existing PlanRequest constraint). + """ + + isocenter_mm: Tuple[float, float, float] + beams: List[BeamSpec] = Field(..., min_length=1, max_length=9) + + @field_validator("beams") + @classmethod + def _unique_beam_ids(cls, beams: List[BeamSpec]) -> List[BeamSpec]: + ids = [b.beam_id for b in beams] + if len(set(ids)) != len(ids): + raise ValueError(f"beam_id values must be unique, got {ids}") + return beams + + +# --------------------------------------------------------------------------- +# Delivery parameters (modality-tagged union, expressed as nullable union) +# --------------------------------------------------------------------------- + +class DeliveryParams(BaseModel): + """Delivery-system parameters. Modality determines which fields apply. + + For PROTON_PBS the proton fields are honored and photon fields are + ignored (and excluded from the cache key). For PHOTON_IMRT the + inverse holds. Defaults below match the existing project conventions + (5 mm spot/layer spacing for protons, 5×5 mm beamlets for photons). + """ + + # ---- PROTON_PBS ------------------------------------------------------ + spot_spacing_mm: Optional[float] = Field(default=5.0, gt=0) + layer_spacing_mm: Optional[float] = Field(default=5.0, gt=0) + energy_range: Optional[Tuple[float, float]] = Field( + default=None, + description="MeV range. None → derived from target depth.", + ) + + # ---- PHOTON_IMRT ----------------------------------------------------- + beamlet_size_mm: Optional[Tuple[float, float]] = Field( + default=(5.0, 5.0), + description="Beamlet pixel size in BEV (x, y) mm.", + ) + mlc_leaf_width_mm: Optional[float] = Field(default=None, gt=0) + jaw_opening_mm: Optional[Tuple[float, float]] = Field(default=None) + + # ---- Helpers --------------------------------------------------------- + + def for_modality(self, modality: Modality) -> Dict[str, Any]: + """Return only the fields that affect dose for ``modality``. + + This is what gets folded into the cache key — so a default + photon param change doesn't invalidate proton cached results. + """ + if modality is Modality.proton_pbs: + return { + "spot_spacing_mm": self.spot_spacing_mm, + "layer_spacing_mm": self.layer_spacing_mm, + "energy_range": list(self.energy_range) if self.energy_range else None, + } + if modality is Modality.photon_imrt: + return { + "beamlet_size_mm": list(self.beamlet_size_mm) if self.beamlet_size_mm else None, + "mlc_leaf_width_mm": self.mlc_leaf_width_mm, + "jaw_opening_mm": list(self.jaw_opening_mm) if self.jaw_opening_mm else None, + } + raise ValueError(f"Unhandled modality: {modality!r}") + + +# --------------------------------------------------------------------------- +# Build request +# --------------------------------------------------------------------------- + +class BeamModelBuildRequest(BaseModel): + """Input payload for POST /api/v1/beam-model/build.""" + + plan_id: Optional[str] = Field(default=None, max_length=36) + geometry_id: str = Field(..., min_length=1, max_length=36, + description="ID returned by Geometry Service.") + modality: Modality + machine_model_id: Optional[str] = Field( + default=None, + description="Custom machine model identifier; null = project default.", + ) + beam_set: BeamSetSpec + delivery_params: DeliveryParams = Field(default_factory=DeliveryParams) + + # ---- Cache key ------------------------------------------------------- + + def compute_cache_key(self) -> str: + """Deterministic sha256 over the inputs that affect output content. + + Excludes ``plan_id`` (a downstream reference, not an input to the + physics) and excludes the modality-irrelevant subset of + ``delivery_params``. + """ + # Sort beams by beam_id so equivalent BeamSetSpecs hash the same + # regardless of insertion order. + sorted_beams = sorted( + ( + { + "beam_id": b.beam_id, + "gantry_deg": b.gantry_deg, + "couch_deg": b.couch_deg, + "collimator_deg": b.collimator_deg, + } + for b in self.beam_set.beams + ), + key=lambda b: b["beam_id"], + ) + payload = { + "geometry_id": self.geometry_id, + "modality": self.modality.value, + "machine_model_id": self.machine_model_id, + "isocenter_mm": list(self.beam_set.isocenter_mm), + "beams": sorted_beams, + "delivery_params": self.delivery_params.for_modality(self.modality), + } + blob = json.dumps(payload, sort_keys=True, separators=(",", ":")) + return hashlib.sha256(blob.encode("utf-8")).hexdigest() + + +# --------------------------------------------------------------------------- +# Result types +# --------------------------------------------------------------------------- + +class PerBeamElements(BaseModel): + """Per-beam fluence-element breakdown. + + The proton fields (``energy_layers``, ``spots_per_layer``) are + populated for PROTON_PBS results. The photon fields (``grid_dims``, + ``active_beamlets``) are populated for PHOTON_IMRT. The other set is + ``None`` — same model serves both modalities. + """ + + beam_id: str + element_count: int = Field(..., ge=0) + + # PROTON_PBS extras + energy_layers: Optional[List[float]] = None + spots_per_layer: Optional[List[int]] = None + + # PHOTON_IMRT extras + grid_dims: Optional[Tuple[int, int]] = None + active_beamlets: Optional[int] = None + + @model_validator(mode="after") + def _spot_arrays_align(self) -> "PerBeamElements": + """If both proton arrays are populated, their lengths must match.""" + if self.energy_layers is not None and self.spots_per_layer is not None: + if len(self.energy_layers) != len(self.spots_per_layer): + raise ValueError( + "energy_layers and spots_per_layer must have the same length" + ) + return self + + +class FluenceElementSet(BaseModel): + """Aggregate fluence-element summary across all beams. + + ``total_count`` is the sum of ``per_beam[i].element_count`` — + validated on construction so the two views can't drift. + """ + + total_count: int = Field(..., ge=0) + per_beam: List[PerBeamElements] + + @model_validator(mode="after") + def _totals_match(self) -> "FluenceElementSet": + s = sum(pb.element_count for pb in self.per_beam) + if s != self.total_count: + raise ValueError( + f"total_count={self.total_count} does not match " + f"sum(per_beam.element_count)={s}" + ) + return self + + +class BeamModelResult(BaseModel): + """Output of a completed beam-model build.""" + + beam_model_id: str + geometry_id: str + modality: Modality + fluence_elements: FluenceElementSet + beam_model_ref_uri: str = Field( + ..., description="URI/path to the serialized OpenTPS plan artifact." + ) + machine_model_id: str = Field( + ..., description="Resolved machine model id (never None on output)." + ) + cache_key: str + created_at: Optional[datetime] = None + + +# --------------------------------------------------------------------------- +# Async job tracking +# --------------------------------------------------------------------------- + +class BeamModelJobStatus(BaseModel): + """Tracks an async ``POST /beam-model/build`` invocation. + + Identified by its own ``id``. Carries ``cache_key`` so a second + request for the same inputs can short-circuit (or reuse an in-flight + job, in a future enhancement). ``beam_model_id`` is populated when + the job succeeds. + """ + + id: str + cache_key: str + state: JobState = JobState.queued + progress: float = 0.0 + stage: Optional[BeamModelStage] = BeamModelStage.queued + message: Optional[str] = None + beam_model_id: Optional[str] = None + started_at: Optional[datetime] = None + finished_at: Optional[datetime] = None + created_at: Optional[datetime] = None diff --git a/src/radiarch/models/geometry.py b/src/radiarch/models/geometry.py index 44c5948..00c84d4 100644 --- a/src/radiarch/models/geometry.py +++ b/src/radiarch/models/geometry.py @@ -28,6 +28,8 @@ import numpy as np from pydantic import BaseModel, Field, field_validator, model_validator +from .job import JobState + # --------------------------------------------------------------------------- # Enums @@ -127,9 +129,24 @@ def is_fully_specified(self) -> bool: # --------------------------------------------------------------------------- class PatientRef(BaseModel): - """Points the Geometry Service at a specific CT + RTSTRUCT pair.""" + """Points the Geometry Service at a specific CT + RTSTRUCT pair. + + Two mutually-compatible ways to identify the source data: + + * ``dicom_study_uid`` (+ optional series UIDs) — pulls from a PACS + via the configured Orthanc/DICOMweb adapter, or from the + ``opentps_data_root`` fallback in dev mode. + * ``upload_id`` — points at a previously-uploaded DICOM bundle + sitting under ``{settings.upload_dir}/{upload_id}/``. Takes + precedence when both are present. + + At least one of the two MUST be provided. + """ - dicom_study_uid: str = Field(..., description="DICOM Study Instance UID") + dicom_study_uid: Optional[str] = Field( + default=None, + description="DICOM Study Instance UID. Required unless upload_id is set.", + ) ct_series_uid: Optional[str] = Field( default=None, description="CT Series Instance UID. Null = auto-detect the primary CT.", @@ -138,6 +155,22 @@ class PatientRef(BaseModel): default=None, description="RTSTRUCT Series Instance UID. Null = auto-detect.", ) + upload_id: Optional[str] = Field( + default=None, + description=( + "Upload id returned by POST /uploads/dicom. When set, the " + "geometry build reads CT + RTSTRUCT from the extracted upload " + "directory instead of going to PACS." + ), + ) + + @model_validator(mode="after") + def _require_one_source(self) -> "PatientRef": + if not self.dicom_study_uid and not self.upload_id: + raise ValueError( + "PatientRef requires either dicom_study_uid or upload_id." + ) + return self # --------------------------------------------------------------------------- @@ -180,6 +213,7 @@ def compute_cache_key(self) -> str: "study": self.patient_ref.dicom_study_uid, "ct": self.patient_ref.ct_series_uid, "rts": self.patient_ref.rtstruct_uid, + "upload": self.patient_ref.upload_id, "grid": self.grid_spec.model_dump() if self.grid_spec else None, "hu_model": self.hu_to_density_model.value, "name_map": self._normalized_name_map(), @@ -236,3 +270,46 @@ def _check_structure_index(self) -> "GeometryResult": if len(set(self.structure_index.values())) != len(self.structure_index): raise ValueError("structure_index labels must be unique") return self + + +# --------------------------------------------------------------------------- +# Async job tracking +# --------------------------------------------------------------------------- + +class GeometryStage(str, Enum): + """Stages a :class:`GeometryService.build` call passes through. + + Reported via the ``stage`` field on ``GeometryJobStatus`` so clients + polling ``GET /geometry/jobs/{job_id}`` can render a progress UI + without waiting for the final result. + """ + + queued = "queued" + loading_dicom = "loading_dicom" + converting_hu = "converting_hu" + rasterizing_contours = "rasterizing_contours" + resampling = "resampling" + persisting = "persisting" + done = "done" + + +class GeometryJobStatus(BaseModel): + """Tracks an async ``POST /geometry/build`` invocation. + + Unlike plan jobs this has no parent — it's identified by its own + ``id`` and carries the ``cache_key`` so a second request for the same + inputs can short-circuit to the cached geometry or reuse the + in-flight job. + """ + + id: str + cache_key: str + state: JobState = JobState.queued + progress: float = 0.0 + stage: Optional[GeometryStage] = GeometryStage.queued + message: Optional[str] = None + geometry_id: Optional[str] = None # populated when state == succeeded + eta_seconds: Optional[float] = None + started_at: Optional[datetime] = None + finished_at: Optional[datetime] = None + created_at: Optional[datetime] = None diff --git a/src/radiarch/services/beam_model.py b/src/radiarch/services/beam_model.py new file mode 100644 index 0000000..8383d35 --- /dev/null +++ b/src/radiarch/services/beam_model.py @@ -0,0 +1,276 @@ +"""``BeamModelService`` — geometry + beam set → modality-specific beam model. + +This is the public entry point for Service 2. One method, one contract: +``build(request) -> BeamModelResult``. + +Pipeline +-------- +1. Compute ``cache_key`` and short-circuit if already built. +2. Load the geometry — verify it exists in the Geometry Service's store + (404 → 422 to the API caller). Then load OpenTPS objects we'll + need to feed the modality builder. v1: re-uses the existing + ``load_ct_and_patient`` helper (which is fast on subsequent calls + thanks to OS file cache + the Geometry Service's NIfTI cache for the + density grid). Future work (B5) eliminates this re-load entirely by + wiring workflows to consume geometries by id. +3. Resolve the machine model. +4. Dispatch to ``_build_proton`` or ``_build_photon`` based on modality. +5. Persist the OpenTPS plan + meta.json atomically. +6. Return :class:`BeamModelResult`. + +Testability +----------- +Same ``_load_*`` / ``_process`` seam as :class:`GeometryService`. Tests +stub out ``_load_geometry`` and ``_load_machine_model`` (or the modality +builders directly) so they never invoke OpenTPS or MCsquare. +""" + +from __future__ import annotations + +import uuid +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Optional + +from loguru import logger + +from ..config import get_settings +from ..models.beam_model import ( + BeamModelBuildRequest, + BeamModelResult, + BeamModelStage, + FluenceElementSet, + Modality, +) +from .beam_persistence import BeamModelPaths, BeamModelStore +from .machine_model import ( + MachineModelBase, + MachineModelError, + PhotonMachineModel, + ProtonMachineModel, + get_machine_model, +) + + +# Same callback shape as Geometry Service: (stage, fraction, message) → None. +ProgressCallback = Callable[[BeamModelStage, float, str], None] + + +@dataclass +class _LoadedGeometry: + """Internal bundle returned by ``_load_geometry``. + + Carries the OpenTPS objects the modality builders need plus the + ``geometry_id`` so the result can reference its upstream input. + """ + + geometry_id: str + ct: Any # OpenTPS CTImage (or test double) + patient: Any # OpenTPS Patient (or test double) + target_contour: Any # First target ROI; may be None + + +class BeamModelService: + """Stateless service. One instance can serve many requests. + + The persistent state lives on disk via :class:`BeamModelStore`. A + second optional dependency — the geometry store — is resolved lazily + so tests can construct a service with no env / no settings. + """ + + def __init__(self, base_dir: Optional[str | Path] = None) -> None: + if base_dir is None: + settings = get_settings() + base_dir = Path(settings.artifact_dir) / "beam_models" + self.store = BeamModelStore(base_dir) + + # ----------------------------------------------------------------- + # Public entry point + # ----------------------------------------------------------------- + + def build( + self, + request: BeamModelBuildRequest, + progress_callback: Optional[ProgressCallback] = None, + ) -> BeamModelResult: + on_progress = progress_callback or _noop_progress + cache_key = request.compute_cache_key() + + cached = self.store.lookup_by_cache_key(cache_key) + if cached is not None: + logger.info( + "Beam model cache hit for key %s → %s", + cache_key[:10], cached.beam_model_id, + ) + on_progress(BeamModelStage.done, 1.0, "cache hit") + return cached + + logger.info("Building beam model (cache miss, key %s)", cache_key[:10]) + + on_progress(BeamModelStage.loading_geometry, 0.05, "Loading geometry") + geometry = self._load_geometry(request.geometry_id) + + on_progress(BeamModelStage.loading_machine_model, 0.15, "Loading machine model") + machine_model = self._load_machine_model(request.modality, request.machine_model_id) + + return self._process(request, geometry, machine_model, cache_key, on_progress) + + # ----------------------------------------------------------------- + # Loading — the testability seam + # ----------------------------------------------------------------- + + def _load_geometry(self, geometry_id: str) -> _LoadedGeometry: + """Verify the geometry exists, then load CT + patient via OpenTPS. + + v1 simplification: we don't reconstruct an OpenTPS ``Patient`` + from the persisted NIfTI artifacts. Instead we look up the + ``GeometryResult`` to confirm the id is valid, then call the + existing ``load_ct_and_patient`` helper (which reads from the + configured ``opentps_data_root`` or the request's source DICOM, + depending on adapter mode). This matches what the four legacy + workflows do today. + + Future work (B5): wire ``load_ct_and_patient`` to fetch by + ``geometry_id`` so we share Geometry Service's DicomFetcher and + eliminate the second DICOM load. + """ + from .geometry import GeometryService # lazy + from ..core.workflows._helpers import ( # lazy + find_target_roi, + load_ct_and_patient, + ) + + geom_store = GeometryService().store + if geom_store.get_by_id(geometry_id) is None: + raise ValueError( + f"geometry_id {geometry_id!r} not found — build the geometry " + "first via POST /api/v1/geometry/build." + ) + + ct, patient, _ = load_ct_and_patient(data_root=None) + target = find_target_roi(patient, fallback_to_first=True) + return _LoadedGeometry( + geometry_id=geometry_id, + ct=ct, + patient=patient, + target_contour=target, + ) + + @staticmethod + def _load_machine_model( + modality: Modality, + machine_model_id: Optional[str], + ) -> MachineModelBase: + """Resolve the machine model. Wraps the factory so tests can stub.""" + try: + return get_machine_model(modality, machine_model_id) + except MachineModelError as exc: + # Surface as a request-validation error to the API layer. + raise ValueError(str(exc)) from exc + + # ----------------------------------------------------------------- + # Modality dispatch + persistence + # ----------------------------------------------------------------- + + def _process( + self, + request: BeamModelBuildRequest, + geometry: _LoadedGeometry, + machine_model: MachineModelBase, + cache_key: str, + on_progress: ProgressCallback, + ) -> BeamModelResult: + on_progress(BeamModelStage.building_beams, 0.30, f"Building {request.modality.value}") + + if request.modality is Modality.proton_pbs: + built = self._build_proton(request, geometry, machine_model) + elif request.modality is Modality.photon_imrt: + built = self._build_photon(request, geometry, machine_model) + else: # pragma: no cover — enum-exhaustive + raise ValueError(f"Unsupported modality: {request.modality!r}") + + on_progress( + BeamModelStage.computing_elements, + 0.75, + f"{built.fluence_elements.total_count} fluence elements", + ) + + beam_model_id = str(uuid.uuid4()) + paths = BeamModelPaths.for_id(self.store.base_dir, beam_model_id) + result = BeamModelResult( + beam_model_id=beam_model_id, + geometry_id=geometry.geometry_id, + modality=request.modality, + fluence_elements=built.fluence_elements, + beam_model_ref_uri=str(paths.plan), + machine_model_id=machine_model.machine_model_id, + cache_key=cache_key, + ) + + on_progress(BeamModelStage.persisting, 0.92, "Pickling plan") + self.store.save( + beam_model_id=beam_model_id, + cache_key=cache_key, + plan=built.plan, + result=result, + ) + + logger.info( + "Beam model %s built: modality=%s elements=%d", + beam_model_id, + request.modality.value, + built.fluence_elements.total_count, + ) + on_progress(BeamModelStage.done, 1.0, f"beam_model_id={beam_model_id}") + return result + + @staticmethod + def _build_proton( + request: BeamModelBuildRequest, + geometry: _LoadedGeometry, + machine_model: MachineModelBase, + ): + from .proton_spots import generate_proton_spots # lazy + + if not isinstance(machine_model, ProtonMachineModel): + raise ValueError( + f"Proton modality requires a ProtonMachineModel, " + f"got {type(machine_model).__name__}" + ) + return generate_proton_spots( + ct=geometry.ct, + patient=geometry.patient, + target_contour=geometry.target_contour, + machine_model=machine_model, + beam_set=request.beam_set, + params=request.delivery_params, + ) + + @staticmethod + def _build_photon( + request: BeamModelBuildRequest, + geometry: _LoadedGeometry, + machine_model: MachineModelBase, + ): + from .photon_beamlets import generate_photon_beamlets # lazy + + if not isinstance(machine_model, PhotonMachineModel): + raise ValueError( + f"Photon modality requires a PhotonMachineModel, " + f"got {type(machine_model).__name__}" + ) + return generate_photon_beamlets( + ct=geometry.ct, + patient=geometry.patient, + machine_model=machine_model, + beam_set=request.beam_set, + params=request.delivery_params, + ) + + +def _noop_progress(stage: BeamModelStage, fraction: float, message: str) -> None: + """Default ``progress_callback`` when none is supplied.""" + del stage, fraction, message + + +__all__ = ["BeamModelService", "ProgressCallback"] diff --git a/src/radiarch/services/beam_persistence.py b/src/radiarch/services/beam_persistence.py new file mode 100644 index 0000000..6ef6231 --- /dev/null +++ b/src/radiarch/services/beam_persistence.py @@ -0,0 +1,209 @@ +"""On-disk persistence for Beam Model Service outputs. + +Layout under ``{artifact_dir}/beam_models/``:: + + beam_models/ + _index.json # cache_key → beam_model_id + {beam_model_id}/ + plan.pkl # pickled OpenTPS plan (proton or photon) + meta.json # full BeamModelResult + +The plan file is whatever ``ProtonPlan`` or ``PhotonPlan`` instance the +modality builder produced; it's pickled rather than serialized to DICOM +RT to keep round-trip fidelity for the in-process dose engine. Future +work could add an alternate DICOM RT export for clinical interop. + +Atomicity follows the same belt-and-suspenders pattern as +``GeometryStore``: + +* writes go to a sibling ``.tmp.*`` directory and are renamed + atomically into ``{beam_model_id}/`` on success; +* the cache index entry is written *after* the directory is in place + (so a crash mid-write leaves orphan dirs, not dangling cache entries); +* deletes scrub the index entry first, then ``rmtree`` the directory. +""" + +from __future__ import annotations + +import json +import os +import pickle +import shutil +import tempfile +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Optional + +from ..models.beam_model import BeamModelResult + + +PLAN_FILENAME = "plan.pkl" +META_FILENAME = "meta.json" +INDEX_FILENAME = "_index.json" + + +@dataclass +class BeamModelPaths: + """Convenience bundle of the on-disk paths for one beam model.""" + + root: Path + plan: Path + meta: Path + + @classmethod + def for_id(cls, base_dir: Path, beam_model_id: str) -> "BeamModelPaths": + root = base_dir / beam_model_id + return cls( + root=root, + plan=root / PLAN_FILENAME, + meta=root / META_FILENAME, + ) + + +class BeamModelStore: + """File-backed beam-model persistence with a JSON cache index. + + Concurrency story is identical to :class:`GeometryStore`: not + process-safe across the cache index file, fine for synchronous + a + single Celery worker. Production multi-worker deployments will want + a DB row with a unique index on cache_key. + """ + + def __init__(self, base_dir: str | os.PathLike[str]) -> None: + self.base_dir = Path(base_dir).resolve() + self.base_dir.mkdir(parents=True, exist_ok=True) + + # ---- cache index -------------------------------------------------- + + @property + def _index_path(self) -> Path: + return self.base_dir / INDEX_FILENAME + + def _load_index(self) -> Dict[str, str]: + if not self._index_path.exists(): + return {} + try: + return json.loads(self._index_path.read_text()) + except (OSError, json.JSONDecodeError): + # Treat a corrupt index as empty; the next successful save + # overwrites it cleanly. + return {} + + def _save_index(self, index: Dict[str, str]) -> None: + tmp = self._index_path.with_suffix(".json.tmp") + tmp.write_text(json.dumps(index, indent=2, sort_keys=True)) + os.replace(tmp, self._index_path) + + def lookup_by_cache_key(self, cache_key: str) -> Optional[BeamModelResult]: + index = self._load_index() + beam_model_id = index.get(cache_key) + if not beam_model_id: + return None + return self.get_by_id(beam_model_id) + + def get_by_id(self, beam_model_id: str) -> Optional[BeamModelResult]: + paths = BeamModelPaths.for_id(self.base_dir, beam_model_id) + if not paths.meta.exists(): + return None + try: + data = json.loads(paths.meta.read_text()) + except (OSError, json.JSONDecodeError): + return None + return BeamModelResult.model_validate(data) + + def load_plan(self, beam_model_id: str) -> Any: + """Read and unpickle the saved OpenTPS plan object.""" + paths = BeamModelPaths.for_id(self.base_dir, beam_model_id) + if not paths.plan.exists(): + raise FileNotFoundError(f"plan artifact missing for {beam_model_id}") + with paths.plan.open("rb") as fh: + return pickle.load(fh) + + # ---- writes ------------------------------------------------------- + + def save( + self, + *, + beam_model_id: str, + cache_key: str, + plan: Any, + result: BeamModelResult, + ) -> BeamModelPaths: + """Pickle plan + write meta + update cache index, all atomically.""" + paths = BeamModelPaths.for_id(self.base_dir, beam_model_id) + with tempfile.TemporaryDirectory( + dir=self.base_dir, + prefix=f".{beam_model_id}.tmp.", + ) as tmp: + tmp_path = Path(tmp) + tmp_plan = tmp_path / PLAN_FILENAME + tmp_meta = tmp_path / META_FILENAME + + with tmp_plan.open("wb") as fh: + pickle.dump(plan, fh, protocol=pickle.HIGHEST_PROTOCOL) + tmp_meta.write_text(result.model_dump_json(indent=2)) + + # Atomic replace: nuke any existing dir first (retry-safe). + if paths.root.exists(): + shutil.rmtree(paths.root) + os.replace(tmp_path, paths.root) + # Recreate the tempdir reference so the context manager's + # cleanup is a no-op (the original tmp_path has moved). + os.makedirs(tmp_path, exist_ok=True) + + # Update the index *after* files are in place — order matters. + index = self._load_index() + index[cache_key] = beam_model_id + self._save_index(index) + return paths + + # ---- deletes ------------------------------------------------------ + + def delete_by_id(self, beam_model_id: str) -> bool: + """Remove a beam model dir and scrub its cache_key entry. + + Returns True if a model was actually deleted, False if the id + was unknown. Safe against partial state — each step is defensive. + """ + root = self.base_dir / beam_model_id + if not root.exists(): + return False + + cache_key = self._read_cache_key(root) + if cache_key is not None: + index = self._load_index() + if index.get(cache_key) == beam_model_id: + index.pop(cache_key) + self._save_index(index) + + shutil.rmtree(root, ignore_errors=True) + return True + + def _read_cache_key(self, root: Path) -> Optional[str]: + meta = root / META_FILENAME + if not meta.exists(): + return None + try: + return json.loads(meta.read_text()).get("cache_key") + except (OSError, json.JSONDecodeError): + return None + + # ---- debugging helpers ------------------------------------------- + + def list_ids(self) -> list[str]: + if not self.base_dir.exists(): + return [] + return sorted( + p.name + for p in self.base_dir.iterdir() + if p.is_dir() and not p.name.startswith(".") and (p / META_FILENAME).exists() + ) + + +__all__ = [ + "PLAN_FILENAME", + "META_FILENAME", + "INDEX_FILENAME", + "BeamModelPaths", + "BeamModelStore", +] diff --git a/src/radiarch/services/geometry.py b/src/radiarch/services/geometry.py index 6fa20c9..9e5da8d 100644 --- a/src/radiarch/services/geometry.py +++ b/src/radiarch/services/geometry.py @@ -36,7 +36,7 @@ import uuid from dataclasses import dataclass from pathlib import Path -from typing import Any, Optional, Tuple +from typing import Any, Callable, Optional, Tuple import numpy as np from loguru import logger @@ -46,9 +46,14 @@ CTMetadata, GeometryBuildRequest, GeometryResult, + GeometryStage, GridSpec, HUDensityModel, ) + + +# Type alias: (stage, progress_fraction_0_to_1, human_message) -> None. +ProgressCallback = Callable[[GeometryStage, float, str], None] from .hu_density import get_model as get_hu_density_model from .persistence import ( DENSITY_FILENAME, @@ -103,36 +108,62 @@ def _get_adapter(self): # Public entry point # ----------------------------------------------------------------- - def build(self, request: GeometryBuildRequest) -> GeometryResult: + def build( + self, + request: GeometryBuildRequest, + progress_callback: Optional[ProgressCallback] = None, + ) -> GeometryResult: + """Run the geometry pipeline for ``request``. + + ``progress_callback`` is invoked as the pipeline advances through + its stages so async callers (the Celery task) can update job + status rows in real time. A no-op is used when None. + """ + on_progress = progress_callback or _noop_progress + cache_key = request.compute_cache_key() cached = self.store.lookup_by_cache_key(cache_key) if cached is not None: logger.info("Geometry cache hit for key %s → %s", cache_key[:10], cached.geometry_id) + on_progress(GeometryStage.done, 1.0, "cache hit") return cached logger.info("Building geometry (cache miss, key %s)", cache_key[:10]) + on_progress(GeometryStage.loading_dicom, 0.05, "Loading CT and RTSTRUCT") loaded = self._load(request) - return self._process(request, loaded, cache_key) + return self._process(request, loaded, cache_key, on_progress) # ----------------------------------------------------------------- # DICOM loading. # - # Two paths, chosen at request time: + # Three paths, chosen at request time (in priority order): # - # 1. PACS path — if the adapter exposes ``can_retrieve_instances()`` + # 1. Upload path — if ``patient_ref.upload_id`` is set, read from + # ``{settings.upload_dir}/{upload_id}/``. This is the production + # entry point (client uploads a ZIP, gets back an upload_id, + # then references it in a build request). + # 2. PACS path — if the adapter exposes ``can_retrieve_instances()`` # (i.e. a real Orthanc / DICOMweb backend), we fetch the study # into a temp dir and point OpenTPS at it. - # 2. Disk path — fall back to the legacy ``load_ct_and_patient`` + # 3. Disk path — fall back to the legacy ``load_ct_and_patient`` # helper which reads from ``opentps_data_root`` (or the request's - # ``data_root_override``). This keeps the existing dev/test - # flows working when Orthanc is mocked or unreachable. + # ``data_root_override``). Dev-only convenience that keeps the + # existing tests + SimpleFantom demo working when Orthanc is + # mocked and no upload was provided. # # This is also the seam that tests stub out (monkeypatch ``_load`` # to return a synthetic CT + fake contours). # ----------------------------------------------------------------- def _load(self, request: GeometryBuildRequest) -> _LoadedCT: + # 1. Upload path — highest priority. The client explicitly + # uploaded files; we should read those, never silently fall + # through to anything else. + if request.patient_ref.upload_id: + upload_path = self._resolve_upload_path(request.patient_ref.upload_id) + return self._load_from_disk(str(upload_path)) + # If the caller forced a data_root, honor it — useful for tests # and one-off debugging against local fixtures even when Orthanc # is reachable. @@ -160,6 +191,26 @@ def _load(self, request: GeometryBuildRequest) -> _LoadedCT: return self._load_from_pacs(fetcher, request) + # ---- Upload path ------------------------------------------------- + + @staticmethod + def _resolve_upload_path(upload_id: str) -> Path: + """Map an upload_id to its extracted directory. + + Raises ``ValueError`` (→ 422 at the API) if the upload directory + doesn't exist — covers the case where a client passes a stale or + bogus upload_id. + """ + settings = get_settings() + base = settings.upload_dir or str(Path(settings.artifact_dir) / "uploads") + path = Path(base).expanduser().resolve() / upload_id + if not path.is_dir(): + raise ValueError( + f"Upload id not found: {upload_id!r}. " + "POST a ZIP to /api/v1/uploads/dicom first." + ) + return path + # ---- Disk path ---------------------------------------------------- @staticmethod @@ -230,7 +281,10 @@ def _process( request: GeometryBuildRequest, loaded: _LoadedCT, cache_key: str, + on_progress: Optional[ProgressCallback] = None, ) -> GeometryResult: + on_progress = on_progress or _noop_progress + ct = loaded.ct ct_array = np.asarray(ct.imageArray) if ct_array.ndim != 3: @@ -249,6 +303,7 @@ def _process( src_affine = source_grid.to_numpy_affine() # 1. HU → density on the NATIVE grid. + on_progress(GeometryStage.converting_hu, 0.25, "HU → density") hu_model = get_hu_density_model(request.hu_to_density_model) density_native = hu_model.convert(ct_array) @@ -256,11 +311,13 @@ def _process( target_grid = self._resolve_target_grid(request, source_grid) # 3. Resample density if target ≠ source. + on_progress(GeometryStage.resampling, 0.45, "Resampling density to target grid") density_final = self._maybe_resample( density_native, src_affine, source_grid, target_grid ) # 4. Rasterize contours directly on the target grid. + on_progress(GeometryStage.rasterizing_contours, 0.70, "Rasterizing contours") masks, structure_index = rasterize_contours( loaded.contours, target_grid, @@ -268,6 +325,7 @@ def _process( ) # 5. Persist + build the result. + on_progress(GeometryStage.persisting, 0.90, "Writing NIfTI + cache index") geometry_id = str(uuid.uuid4()) paths = GeometryPaths.for_id(self.store.base_dir, geometry_id) ct_meta = self._ct_metadata(ct) @@ -296,6 +354,7 @@ def _process( target_grid.size, list(structure_index.keys()), ) + on_progress(GeometryStage.done, 1.0, f"geometry_id={geometry_id}") return result # ----------------------------------------------------------------- @@ -370,4 +429,9 @@ def _frame_of_reference(ct: Any) -> str: return str(for_uid) -__all__ = ["GeometryService"] +def _noop_progress(stage: GeometryStage, fraction: float, message: str) -> None: + """Default ``progress_callback`` when none is supplied.""" + del stage, fraction, message # unused + + +__all__ = ["GeometryService", "ProgressCallback"] diff --git a/src/radiarch/services/machine_model.py b/src/radiarch/services/machine_model.py new file mode 100644 index 0000000..63834cb --- /dev/null +++ b/src/radiarch/services/machine_model.py @@ -0,0 +1,273 @@ +"""Pluggable equipment-specific calibration data. + +The Beam Model Service needs equipment configuration to turn beam-set +specifications into deliverable plans: + +* For protons, that's the Beam Data Library (BDL) — a per-machine, + per-energy calibration table — plus the HU→density / HU→stopping-power + CT calibration. Today these are loaded inline by ``_helpers.load_bdl`` + and ``_helpers.setup_calibration``; this module replaces that with a + pluggable abstraction. +* For photons, that's the MLC leaf width, jaw extents, and beam quality. + Today these are hard-coded inside ``photon_ccc.py``; this module + extracts them. + +Both are lazy-loaded — instantiating a ``ProtonMachineModel`` does no +disk I/O until you access ``.bdl`` or ``.calibration``. + +Custom machine models (different LINAC, different proton system) load +from ``{settings.opentps_beam_library}/{machine_model_id}/`` if that +directory exists. ``machine_model_id=None`` means "use project default." +""" + +from __future__ import annotations + +import abc +import os +from typing import Any, Optional + +from loguru import logger + +from ..config import get_settings +from ..models.beam_model import Modality + + +class MachineModelError(RuntimeError): + """Raised when a machine model cannot be loaded (missing files, + unknown id, etc.).""" + + +# --------------------------------------------------------------------------- +# Default file-path resolution (proton) +# --------------------------------------------------------------------------- + +def _default_mcsquare_path() -> str: + """Path to the vendored MCsquare module — base for BDL + calibration.""" + import opentps.core.processing.doseCalculation.protons.MCsquare as MCsquareModule + return str(MCsquareModule.__path__[0]) + + +def _default_bdl_path() -> str: + return os.path.join( + _default_mcsquare_path(), + "BDL", + "BDL_default_DN_RangeShifter.txt", + ) + + +def _default_scanner_dir() -> str: + return os.path.join(_default_mcsquare_path(), "Scanners", "UCL_Toshiba") + + +# --------------------------------------------------------------------------- +# Base +# --------------------------------------------------------------------------- + +class MachineModelBase(abc.ABC): + """Common interface — every machine model knows its modality and id.""" + + #: Symbolic id used to look this model up; "default" for the project default. + machine_model_id: str + + @property + @abc.abstractmethod + def modality(self) -> Modality: ... + + +# --------------------------------------------------------------------------- +# Proton machine model +# --------------------------------------------------------------------------- + +class ProtonMachineModel(MachineModelBase): + """BDL + MCsquare CT calibration for a proton therapy system. + + Wraps the vendored OpenTPS data files. Properties are loaded lazily + so constructing the object is cheap (good for cache-key hashing + where we don't want to touch disk). + """ + + def __init__( + self, + machine_model_id: str = "default", + *, + bdl_path: Optional[str] = None, + scanner_dir: Optional[str] = None, + ) -> None: + self.machine_model_id = machine_model_id + self._bdl_path = bdl_path + self._scanner_dir = scanner_dir + # Lazy caches. + self._bdl: Any = None + self._calibration: Any = None + + @property + def modality(self) -> Modality: + return Modality.proton_pbs + + # ---- Classmethod constructors ------------------------------------ + + @classmethod + def from_default(cls) -> "ProtonMachineModel": + """The project default — vendored BDL + UCL_Toshiba calibration.""" + return cls( + machine_model_id="default", + bdl_path=_default_bdl_path(), + scanner_dir=_default_scanner_dir(), + ) + + @classmethod + def from_id(cls, machine_model_id: str) -> "ProtonMachineModel": + """Load a custom machine model from the configured beam library. + + Looks for ``{opentps_beam_library}/{id}/BDL.txt`` and + ``{opentps_beam_library}/{id}/Scanner/`` — falls back to the + default if those don't exist (with a logged warning). + """ + settings = get_settings() + base = os.path.join(settings.opentps_beam_library, machine_model_id) + bdl = os.path.join(base, "BDL.txt") + scanner = os.path.join(base, "Scanner") + if not os.path.isfile(bdl): + raise MachineModelError( + f"Proton machine model {machine_model_id!r} not found at {base}" + ) + return cls( + machine_model_id=machine_model_id, + bdl_path=bdl, + scanner_dir=scanner, + ) + + # ---- Lazy properties --------------------------------------------- + + @property + def bdl_path(self) -> str: + return self._bdl_path or _default_bdl_path() + + @property + def scanner_dir(self) -> str: + return self._scanner_dir or _default_scanner_dir() + + @property + def bdl(self) -> Any: + """Loaded ``BDL`` object (an OpenTPS BeamModel instance).""" + if self._bdl is None: + from opentps.core.io import mcsquareIO # lazy + path = self.bdl_path + if not os.path.isfile(path): + raise MachineModelError(f"BDL file not found: {path}") + self._bdl = mcsquareIO.readBDL(path) + return self._bdl + + @property + def calibration(self) -> Any: + """Loaded ``MCsquareCTCalibration``.""" + if self._calibration is None: + from opentps.core.data.CTCalibrations.MCsquareCalibration._mcsquareCTCalibration import ( + MCsquareCTCalibration, + ) + scanner = self.scanner_dir + mcsquare_path = _default_mcsquare_path() + self._calibration = MCsquareCTCalibration.fromFiles( + huDensityFile=os.path.join(scanner, "HU_Density_Conversion.txt"), + huMaterialFile=os.path.join(scanner, "HU_Material_Conversion.txt"), + materialsPath=os.path.join(mcsquare_path, "Materials"), + ) + return self._calibration + + @property + def mcsquare_path(self) -> str: + return _default_mcsquare_path() + + @property + def energy_range_mev(self) -> tuple: + """(min, max) proton energy in MeV for this BDL. + + Walks the loaded BDL's energy entries. Falls back to the + clinical-default range if the BDL doesn't expose energies in a + way we can introspect. + """ + try: + energies = [layer.NominalEnergy for layer in self.bdl.layers] + return (float(min(energies)), float(max(energies))) + except (AttributeError, ValueError): + # Conservative default for pencil-beam scanning systems. + return (70.0, 230.0) + + +# --------------------------------------------------------------------------- +# Photon machine model +# --------------------------------------------------------------------------- + +class PhotonMachineModel(MachineModelBase): + """MLC + jaw + beam-quality config for a photon LINAC.""" + + def __init__( + self, + machine_model_id: str = "default", + *, + mlc_leaf_width_mm: float = 10.0, + max_jaw_opening_mm: float = 200.0, + beam_quality_mv: float = 6.0, + ) -> None: + self.machine_model_id = machine_model_id + self.mlc_leaf_width_mm = mlc_leaf_width_mm + self.max_jaw_opening_mm = max_jaw_opening_mm + self.beam_quality_mv = beam_quality_mv + + @property + def modality(self) -> Modality: + return Modality.photon_imrt + + @classmethod + def from_default(cls) -> "PhotonMachineModel": + """The project default — Varian-class 6 MV LINAC, 10 mm MLC.""" + return cls(machine_model_id="default") + + @classmethod + def from_id(cls, machine_model_id: str) -> "PhotonMachineModel": + """Load a custom photon machine config. + + v1: only "default" is recognized (the spec leaves photon machine + configs as a future enhancement). Unknown ids raise rather than + falling back silently. + """ + if machine_model_id == "default": + return cls.from_default() + raise MachineModelError( + f"Photon machine model {machine_model_id!r} not found. " + "Only 'default' is recognized in v1." + ) + + +# --------------------------------------------------------------------------- +# Factory +# --------------------------------------------------------------------------- + +def get_machine_model( + modality: Modality, + machine_model_id: Optional[str] = None, +) -> MachineModelBase: + """Resolve a machine model instance for ``modality``. + + ``machine_model_id=None`` (or "default") returns the project default. + Any other id is looked up in the beam library; missing ids raise + :class:`MachineModelError` so the caller can surface a clean 404. + """ + if modality is Modality.proton_pbs: + if machine_model_id in (None, "default"): + return ProtonMachineModel.from_default() + return ProtonMachineModel.from_id(machine_model_id) + if modality is Modality.photon_imrt: + if machine_model_id in (None, "default"): + return PhotonMachineModel.from_default() + return PhotonMachineModel.from_id(machine_model_id) + raise MachineModelError(f"Unsupported modality: {modality!r}") + + +__all__ = [ + "MachineModelBase", + "MachineModelError", + "PhotonMachineModel", + "ProtonMachineModel", + "get_machine_model", +] diff --git a/src/radiarch/services/persistence.py b/src/radiarch/services/persistence.py index 625e488..c1829e4 100644 --- a/src/radiarch/services/persistence.py +++ b/src/radiarch/services/persistence.py @@ -219,6 +219,41 @@ def save( self._save_index(index) return paths + # ---- deletes ------------------------------------------------------ + + def delete_by_id(self, geometry_id: str) -> bool: + """Remove a geometry directory and scrub its cache_key from the index. + + Returns True when a geometry was actually deleted, False when the + id was unknown. Safe to call against partial state (missing meta, + missing files, stale index entry) — each step is defensive. + """ + root = self.base_dir / geometry_id + if not root.exists(): + return False + + # Scrub the cache index first. If we're about to delete files, + # the index entry should disappear atomically enough that a + # racing reader never sees "cache hit → files gone". + cache_key = self._read_cache_key(root) + if cache_key is not None: + index = self._load_index() + if index.get(cache_key) == geometry_id: + index.pop(cache_key) + self._save_index(index) + + shutil.rmtree(root, ignore_errors=True) + return True + + def _read_cache_key(self, root: Path) -> Optional[str]: + meta = root / META_FILENAME + if not meta.exists(): + return None + try: + return json.loads(meta.read_text()).get("cache_key") + except (OSError, json.JSONDecodeError): + return None + # ---- debugging helpers ------------------------------------------- def list_ids(self) -> list[str]: diff --git a/src/radiarch/services/photon_beamlets.py b/src/radiarch/services/photon_beamlets.py new file mode 100644 index 0000000..fda6a98 --- /dev/null +++ b/src/radiarch/services/photon_beamlets.py @@ -0,0 +1,152 @@ +"""Generate a photon-IMRT beam model from a geometry + beam set. + +Adapter between Service 2's modality-neutral types and OpenTPS's +``PhotonPlan`` construction. The function below is the only place that +knows how to call OpenTPS for photon plan assembly; the rest of the +service treats beamlets as opaque "fluence elements." + +v1 scope: each beam carries one ``PlanPhotonSegment`` (one MLC aperture +shape). The "fluence element" count for that beam is the number of +beamlets that fit inside the jaw opening at the beamlet pixel size. +True multi-segment IMRT (multiple MLC apertures per beam, computed by +an inverse planner) is a v2 enhancement. + +Beamlet grid math: beamlets are square-ish pixels in beam's-eye-view +(BEV) coordinates at isocenter. The grid extent is the jaw opening; the +per-pixel size comes from ``DeliveryParams.beamlet_size_mm``. So a +20×20 cm jaw with 5×5 mm beamlets gives a 40×40 = 1600 element grid. +After MLC clipping (open beamlets only, no shielded regions in v1) +``active_beamlets`` is the same as ``element_count`` for now — when +multi-segment / inverse-planning lands, this will diverge. +""" + +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import Any, List, Optional, Tuple + +from loguru import logger + +from ..models.beam_model import ( + BeamSetSpec, + DeliveryParams, + FluenceElementSet, + Modality, + PerBeamElements, +) +from .machine_model import PhotonMachineModel + + +@dataclass +class PhotonBuildResult: + """Internal bundle — the FluenceElementSet plus the OpenTPS plan.""" + + fluence_elements: FluenceElementSet + plan: Any # OpenTPS PhotonPlan (or test double) + + +def generate_photon_beamlets( + ct: Any, + patient: Any, + machine_model: PhotonMachineModel, + beam_set: BeamSetSpec, + params: DeliveryParams, + monitor_units_per_beam: float = 5000.0, +) -> PhotonBuildResult: + """Build a photon plan and summarize its beamlet grid. + + Parameters + ---------- + ct, patient + Loaded OpenTPS objects (kept in the signature for symmetry with + the proton builder; v1 photon construction doesn't need patient + contours, but downstream dose calculation will). + machine_model + Resolved :class:`PhotonMachineModel` — supplies the default jaw + opening when the request doesn't override it. + beam_set + Geometric beam configuration. + params + :class:`DeliveryParams` carrying ``beamlet_size_mm`` and + ``jaw_opening_mm``. + monitor_units_per_beam + MU per beam. Default 5000 — clinical norm. Same value used by + the existing photon_ccc workflow. + """ + # Lazy import — OpenTPS is heavy. + from opentps.core.data.plan._photonPlan import PhotonPlan + from opentps.core.data.plan._planPhotonBeam import PlanPhotonBeam + from opentps.core.data.plan._planPhotonSegment import PlanPhotonSegment + + # Resolve the actual jaw opening: request → machine default. + jaw_x, jaw_y = _resolve_jaw_opening(params, machine_model) + bx, by = params.beamlet_size_mm or (5.0, 5.0) + grid_dims = ( + max(1, math.ceil(jaw_x / bx)), + max(1, math.ceil(jaw_y / by)), + ) + elements_per_beam = grid_dims[0] * grid_dims[1] + + logger.info( + "Building photon plan: %d beams, jaw=%.1f×%.1f mm, " + "beamlet=%.1f×%.1f mm → grid %s (%d elements/beam)", + len(beam_set.beams), jaw_x, jaw_y, bx, by, grid_dims, elements_per_beam, + ) + + photon_plan = PhotonPlan() + per_beam: List[PerBeamElements] = [] + + for beam_spec in beam_set.beams: + beam = PlanPhotonBeam() + beam.gantryAngle = beam_spec.gantry_deg + beam.couchAngle = beam_spec.couch_deg + + segment = PlanPhotonSegment() + segment.monitorUnits = monitor_units_per_beam + # OpenTPS expects [-half, +half] for jaw opening — convert from + # an absolute opening size centered on isocenter. + segment.jawOpeningMM = [-jaw_x / 2.0, jaw_x / 2.0] + beam.segments = [segment] + + photon_plan.beams.append(beam) + per_beam.append( + PerBeamElements( + beam_id=beam_spec.beam_id, + element_count=elements_per_beam, + grid_dims=grid_dims, + active_beamlets=elements_per_beam, # see module docstring + ) + ) + + fluence = FluenceElementSet( + total_count=elements_per_beam * len(beam_set.beams), + per_beam=per_beam, + ) + return PhotonBuildResult(fluence_elements=fluence, plan=photon_plan) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _resolve_jaw_opening( + params: DeliveryParams, + machine_model: PhotonMachineModel, +) -> Tuple[float, float]: + """Use request override if present, else machine default. + + Treats the machine's ``max_jaw_opening_mm`` as a square default. + Real LINACs have rectangular jaws — when a custom machine model + starts carrying x/y separately, this resolves separately too. + """ + if params.jaw_opening_mm is not None: + # The request gives [-half, +half] semantics or [width_x, width_y]? + # The Pydantic field is a Tuple[float, float]; we treat both + # entries as positive widths in mm. + return float(params.jaw_opening_mm[0]), float(params.jaw_opening_mm[1]) + default = float(machine_model.max_jaw_opening_mm) + return default, default + + +__all__ = ["PhotonBuildResult", "generate_photon_beamlets"] diff --git a/src/radiarch/services/proton_spots.py b/src/radiarch/services/proton_spots.py new file mode 100644 index 0000000..65d8b0f --- /dev/null +++ b/src/radiarch/services/proton_spots.py @@ -0,0 +1,211 @@ +"""Generate a proton-PBS beam model from a geometry + beam set. + +This is the adapter between Service 2's modality-neutral types +(:class:`BeamSetSpec`, :class:`DeliveryParams`, :class:`FluenceElementSet`) +and OpenTPS's ``ProtonPlanDesign.buildPlan()``. The function below is +intentionally the *only* place that knows how to call OpenTPS for proton +plan construction; the rest of the Beam Model Service treats spots as +opaque "fluence elements." + +OpenTPS's ``buildPlan()`` returns a ``ProtonPlan`` whose ``beams`` list +contains ``PlanProtonBeam`` objects, each with ``layers`` (per energy) +that hold ``spots`` (per scanning position). We walk that structure to +populate :class:`PerBeamElements` with energy_layers + spots_per_layer. + +The returned ``ProtonPlan`` is the artifact persisted by the Beam Model +Service (pickled). Downstream dose engines consume it directly. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, List, Optional, Sequence, Tuple + +from loguru import logger + +from ..models.beam_model import ( + BeamSetSpec, + DeliveryParams, + FluenceElementSet, + Modality, + PerBeamElements, +) +from .machine_model import ProtonMachineModel + + +@dataclass +class ProtonBuildResult: + """Internal bundle — the FluenceElementSet plus the OpenTPS plan.""" + + fluence_elements: FluenceElementSet + plan: Any # OpenTPS ProtonPlan (or test double) + + +def generate_proton_spots( + ct: Any, + patient: Any, + target_contour: Any, + machine_model: ProtonMachineModel, + beam_set: BeamSetSpec, + params: DeliveryParams, + prescription_gy: float = 2.0, +) -> ProtonBuildResult: + """Build a proton plan and summarize its fluence elements. + + Parameters + ---------- + ct + OpenTPS ``CTImage`` for the patient. + patient + OpenTPS ``Patient`` carrying the RTStructs. + target_contour + The target ``ROIContour`` to define the prescription against. + May be ``None`` — the plan still builds but with no target mask, + which exercises the "no target" warning path inside OpenTPS. + machine_model + Resolved :class:`ProtonMachineModel`. + beam_set + Geometric beam configuration. + params + :class:`DeliveryParams` carrying ``spot_spacing_mm`` and + ``layer_spacing_mm``. + prescription_gy + Dose prescription, in Gy. Default 2.0 — clinical norm for one + fraction. Used by ``defineTargetMaskAndPrescription``. + + Returns + ------- + ProtonBuildResult + ``fluence_elements`` for the cache + the raw ``ProtonPlan`` for + persistence. + """ + # Lazy import — keeps fast tests from pulling in OpenTPS. + from opentps.core.data.plan import ProtonPlanDesign + + plan_design = ProtonPlanDesign() + plan_design.ct = ct + plan_design.patient = patient + plan_design.calibration = machine_model.calibration + + # Convert BeamSetSpec → flat angle lists in beam order. + plan_design.gantryAngles = [b.gantry_deg for b in beam_set.beams] + plan_design.couchAngles = [b.couch_deg for b in beam_set.beams] + + if params.spot_spacing_mm is not None: + plan_design.spotSpacing = params.spot_spacing_mm + if params.layer_spacing_mm is not None: + plan_design.layerSpacing = params.layer_spacing_mm + + if target_contour is not None: + plan_design.defineTargetMaskAndPrescription(target_contour, prescription_gy) + else: + logger.warning( + "No target contour supplied — proton plan will build without a " + "prescription mask and may have zero spots." + ) + + logger.info( + "Building proton plan: %d beams, spot_spacing=%s mm, layer_spacing=%s mm", + len(beam_set.beams), + plan_design.spotSpacing, + plan_design.layerSpacing, + ) + proton_plan = plan_design.buildPlan() + + fluence = _summarize_proton_plan(proton_plan, beam_set) + return ProtonBuildResult(fluence_elements=fluence, plan=proton_plan) + + +# --------------------------------------------------------------------------- +# Plan introspection +# --------------------------------------------------------------------------- + +def _summarize_proton_plan(plan: Any, beam_set: BeamSetSpec) -> FluenceElementSet: + """Walk a built ``ProtonPlan`` and produce a :class:`FluenceElementSet`. + + OpenTPS internal layout (best-effort, defensive to API drift): + plan.beams -> [PlanProtonBeam] + beam.layers -> [PlanProtonLayer] + layer.nominalEnergy : float (MeV) + layer.spots : list (or .numberOfSpots / .scanSpotPositions) + """ + per_beam: List[PerBeamElements] = [] + total = 0 + + beams = list(getattr(plan, "beams", []) or []) + # Align to BeamSetSpec by index when possible — OpenTPS preserves + # construction order from gantryAngles. + for idx, beam in enumerate(beams): + beam_id = ( + beam_set.beams[idx].beam_id + if idx < len(beam_set.beams) + else f"beam_{idx}" + ) + energies, spots = _walk_layers(beam) + element_count = sum(spots) + total += element_count + per_beam.append( + PerBeamElements( + beam_id=beam_id, + element_count=element_count, + energy_layers=energies if energies else None, + spots_per_layer=spots if spots else None, + ) + ) + + # If the plan ended up empty (no target / no spots placed), still + # emit one PerBeamElements per beam in the request so downstream + # code can find them by beam_id. + if not per_beam: + per_beam = [ + PerBeamElements(beam_id=b.beam_id, element_count=0) + for b in beam_set.beams + ] + + return FluenceElementSet(total_count=total, per_beam=per_beam) + + +def _walk_layers(beam: Any) -> Tuple[List[float], List[int]]: + """Extract (energies, spots-per-layer) from one PlanProtonBeam. + + OpenTPS exposes spot lists under a few different attribute names + across versions — try the most common shapes and fall back to zero + if none match. + """ + energies: List[float] = [] + spots: List[int] = [] + + layers = list(getattr(beam, "layers", []) or []) + for layer in layers: + energy = ( + getattr(layer, "nominalEnergy", None) + or getattr(layer, "energy", None) + or 0.0 + ) + count = _count_spots(layer) + if count == 0 and energy == 0.0: + # Bogus layer entry — skip it. + continue + energies.append(float(energy)) + spots.append(int(count)) + return energies, spots + + +def _count_spots(layer: Any) -> int: + """Try several OpenTPS attribute paths to count spots in a layer.""" + # Most direct: a Python list of spot objects. + spots = getattr(layer, "spots", None) + if spots is not None and hasattr(spots, "__len__"): + return len(spots) + # Some versions expose a count attribute directly. + n = getattr(layer, "numberOfSpots", None) + if isinstance(n, int): + return n + # Or a 2D array of (x, y) scan positions. + pos = getattr(layer, "scanSpotPositions", None) + if pos is not None and hasattr(pos, "__len__"): + return len(pos) + return 0 + + +__all__ = ["ProtonBuildResult", "generate_proton_spots"] diff --git a/src/radiarch/tasks/beam_model_tasks.py b/src/radiarch/tasks/beam_model_tasks.py new file mode 100644 index 0000000..1d6c4d1 --- /dev/null +++ b/src/radiarch/tasks/beam_model_tasks.py @@ -0,0 +1,91 @@ +"""Celery task for async beam-model builds. + +Mirrors ``build_geometry_job`` from ``geometry_tasks.py``: pull the job +from the store, execute :meth:`BeamModelService.build` with a progress +callback that mirrors stages and progress to the DB row, and on success +stash the resulting ``beam_model_id`` so polling clients can deep-link. +""" + +from __future__ import annotations + +import time + +from celery.exceptions import SoftTimeLimitExceeded +from loguru import logger + +from .celery_app import celery_app +from ..core.store import store +from ..models.beam_model import BeamModelBuildRequest, BeamModelStage +from ..models.job import JobState + + +@celery_app.task( + name="radiarch.beam_model.build", + autoretry_for=(ConnectionError, OSError), + retry_backoff=True, + retry_backoff_max=120, + max_retries=3, +) +def build_beam_model_job(job_id: str, request_payload: dict): + """Execute :meth:`BeamModelService.build` and mirror progress to the job row.""" + # Lazy import keeps Celery worker boot fast and avoids circular imports. + from ..services.beam_model import BeamModelService + + job = store.get_beam_model_job(job_id) + if not job: + logger.error("build_beam_model_job called with unknown job_id=%s", job_id) + return + + t0 = time.monotonic() + + def _on_progress(stage: BeamModelStage, fraction: float, message: str) -> None: + store.update_beam_model_job( + job_id, + state=JobState.running if fraction < 1.0 else job.state, + progress=round(fraction, 3), + stage=stage, + message=message, + ) + + try: + store.update_beam_model_job( + job_id, + state=JobState.running, + progress=0.0, + stage=BeamModelStage.queued, + message="Queued → running", + ) + request = BeamModelBuildRequest.model_validate(request_payload) + service = BeamModelService() + result = service.build(request, progress_callback=_on_progress) + + except SoftTimeLimitExceeded: + logger.error("Beam model build timed out for job %s", job_id) + store.update_beam_model_job( + job_id, + state=JobState.failed, + progress=1.0, + stage=BeamModelStage.done, + message="Timed out", + ) + return + + except Exception as exc: # pragma: no cover — exercised in tests + logger.exception("Beam model build failed for job %s", job_id) + store.update_beam_model_job( + job_id, + state=JobState.failed, + progress=1.0, + stage=BeamModelStage.done, + message=f"{type(exc).__name__}: {exc}", + ) + return + + store.update_beam_model_job( + job_id, + state=JobState.succeeded, + progress=1.0, + stage=BeamModelStage.done, + message=f"Built in {time.monotonic() - t0:.1f}s", + beam_model_id=result.beam_model_id, + ) diff --git a/src/radiarch/tasks/celery_app.py b/src/radiarch/tasks/celery_app.py index a0c1168..e083936 100644 --- a/src/radiarch/tasks/celery_app.py +++ b/src/radiarch/tasks/celery_app.py @@ -15,7 +15,11 @@ "radiarch", broker=settings.broker_url, backend=settings.result_backend, - include=["radiarch.tasks.plan_tasks"], + include=[ + "radiarch.tasks.plan_tasks", + "radiarch.tasks.geometry_tasks", + "radiarch.tasks.beam_model_tasks", + ], ) celery_app.conf.update( diff --git a/src/radiarch/tasks/geometry_tasks.py b/src/radiarch/tasks/geometry_tasks.py new file mode 100644 index 0000000..e619060 --- /dev/null +++ b/src/radiarch/tasks/geometry_tasks.py @@ -0,0 +1,118 @@ +"""Celery task for async geometry builds. + +Lives alongside ``plan_tasks`` and follows the same pattern: + +* Accept a ``job_id`` and the serialized :class:`GeometryBuildRequest`. +* Update the DB-backed ``geometry_jobs`` row as the pipeline advances + through its stages, via the ``progress_callback`` that + :meth:`GeometryService.build` now accepts. +* On success, stash the returned ``geometry_id`` on the job row so the + ``GET /api/v1/geometry/jobs/{job_id}`` endpoint can deep-link clients + to the finished result. + +In ``environment=dev`` Celery is in eager mode, so the task runs +synchronously in the API worker process — tests exercise the same code +path as production without requiring a broker. +""" + +from __future__ import annotations + +import time + +from celery.exceptions import SoftTimeLimitExceeded +from loguru import logger + +from .celery_app import celery_app +from ..core.store import store +from ..models.geometry import GeometryBuildRequest, GeometryStage +from ..models.job import JobState + + +@celery_app.task( + name="radiarch.geometry.build", + autoretry_for=(ConnectionError, OSError), + retry_backoff=True, + retry_backoff_max=120, + max_retries=3, +) +def build_geometry_job(job_id: str, request_payload: dict): + """Execute :meth:`GeometryService.build` and mirror progress to the job row.""" + # Lazy import — keeps Celery worker start fast and avoids circular + # imports between tasks and the service layer. + from ..services.geometry import GeometryService + + job = store.get_geometry_job(job_id) + if not job: + logger.error("build_geometry_job called with unknown job_id=%s", job_id) + return + + t0 = time.monotonic() + + def _on_progress(stage: GeometryStage, fraction: float, message: str) -> None: + # Map the service's stage + fraction onto the job row. ETA is a + # rough extrapolation; rather than nothing. + elapsed = time.monotonic() - t0 + eta = None + if fraction > 0.05: + eta = max(0.0, elapsed * (1.0 - fraction) / fraction) + store.update_geometry_job( + job_id, + state=JobState.running if fraction < 1.0 else job.state, + progress=round(fraction, 3), + stage=stage, + message=message, + ) + # Separate write for ETA so we don't clobber stage transitions; + # update_geometry_job doesn't take eta_seconds (see note on + # SQL persistence above) — we just log it for now. + if eta is not None: + logger.debug( + "geometry job %s stage=%s progress=%.2f eta=%.1fs", + job_id, + stage.value if hasattr(stage, "value") else stage, + fraction, + eta, + ) + + try: + store.update_geometry_job( + job_id, + state=JobState.running, + progress=0.0, + stage=GeometryStage.queued, + message="Queued → running", + ) + request = GeometryBuildRequest.model_validate(request_payload) + service = GeometryService() + result = service.build(request, progress_callback=_on_progress) + + except SoftTimeLimitExceeded: + logger.error("Geometry build timed out for job %s", job_id) + store.update_geometry_job( + job_id, + state=JobState.failed, + progress=1.0, + stage=GeometryStage.done, + message="Timed out", + ) + return + + except Exception as exc: # pragma: no cover — defensive; exercised in tests + logger.exception("Geometry build failed for job %s", job_id) + store.update_geometry_job( + job_id, + state=JobState.failed, + progress=1.0, + stage=GeometryStage.done, + message=f"{type(exc).__name__}: {exc}", + ) + return + + store.update_geometry_job( + job_id, + state=JobState.succeeded, + progress=1.0, + stage=GeometryStage.done, + message=f"Built in {time.monotonic() - t0:.1f}s", + geometry_id=result.geometry_id, + ) diff --git a/tests/opentps/core/test_mcsquare_interface.py b/tests/opentps/core/test_mcsquare_interface.py index e680a21..e045429 100644 --- a/tests/opentps/core/test_mcsquare_interface.py +++ b/tests/opentps/core/test_mcsquare_interface.py @@ -6,6 +6,10 @@ DVH, DICE, and WET/SPR computation. Test data: tests/opentps/core/MCsquare-python_interface/data/ + +Note: ``test_mcsquare_simulation`` requires a Linux MCsquare binary and +is skipped on macOS (the vendored OpenTPS strips Darwin binaries — see +README § "About OpenTPS"). """ import os @@ -13,6 +17,11 @@ import pytest import numpy as np +requires_mcsquare = pytest.mark.skipif( + sys.platform == "darwin", + reason="MCsquare binary not shipped for Darwin; vendored OpenTPS is Linux-only.", +) + # --------------------------------------------------------------------------- # 1. Patient data loading @@ -197,6 +206,7 @@ class TestMCsquareSimulation: """Test MCsquare dose calculation through the python_interface.""" @pytest.mark.slow + @requires_mcsquare def test_mcsquare_simulation(self, mcsquare_sample_data_dir, mcsquare_interface_dir, tmp_path): """Run MCsquare simulation; verify dose output exists. diff --git a/tests/opentps/core/test_opentps_core.py b/tests/opentps/core/test_opentps_core.py index b28e805..8ece2a2 100644 --- a/tests/opentps/core/test_opentps_core.py +++ b/tests/opentps/core/test_opentps_core.py @@ -6,12 +6,23 @@ Requires: vendored opentps.core at src/opentps/core/ Test data: tests/opentps/core/opentps-testData/SimpleFantomWithStruct/ + +Note: MCsquare dose-calculation tests require the Linux MCsquare binary, +which is the only variant shipped with the vendored OpenTPS copy (Darwin +and Windows binaries are stripped — see README § "About OpenTPS"). +Those tests skip automatically on macOS. """ import os +import sys import pytest import numpy as np +requires_mcsquare = pytest.mark.skipif( + sys.platform == "darwin", + reason="MCsquare binary not shipped for Darwin; vendored OpenTPS is Linux-only.", +) + # --------------------------------------------------------------------------- # 1. Data loading @@ -136,6 +147,7 @@ def test_mcsquare_dose_calculator_setup(self): print("MCsquareDoseCalculator configured successfully") @pytest.mark.slow + @requires_mcsquare def test_dose_computation_basic(self, simple_fantom_dir): """Run MCsquare dose calc with low primaries; verify dose is non-zero. @@ -233,6 +245,7 @@ class TestDVH: """Test Dose-Volume Histogram computation.""" @pytest.mark.slow + @requires_mcsquare def test_dvh_from_dose(self, simple_fantom_dir): """Compute DVH from dose + contour; verify dose stats are positive. diff --git a/tests/services/test_beam_model_models.py b/tests/services/test_beam_model_models.py new file mode 100644 index 0000000..48c06f8 --- /dev/null +++ b/tests/services/test_beam_model_models.py @@ -0,0 +1,221 @@ +"""Unit tests for the Pydantic schemas in radiarch.models.beam_model.""" + +from __future__ import annotations + +import pytest + +from radiarch.models.beam_model import ( + BeamModelBuildRequest, + BeamSetSpec, + BeamSpec, + DeliveryParams, + FluenceElementSet, + Modality, + PerBeamElements, +) + + +# --------------------------------------------------------------------------- +# BeamSpec / BeamSetSpec validation +# --------------------------------------------------------------------------- + +class TestBeamSpec: + def test_rejects_gantry_out_of_range(self) -> None: + with pytest.raises(ValueError): + BeamSpec(beam_id="B0", gantry_deg=361.0) + + def test_rejects_gantry_at_360(self) -> None: + # IEC 61217: gantry is [0, 360), 360 wraps to 0. + with pytest.raises(ValueError): + BeamSpec(beam_id="B0", gantry_deg=360.0) + + def test_rejects_couch_out_of_range(self) -> None: + with pytest.raises(ValueError): + BeamSpec(beam_id="B0", gantry_deg=0.0, couch_deg=200.0) + + def test_defaults_couch_and_collimator_to_zero(self) -> None: + b = BeamSpec(beam_id="B0", gantry_deg=90.0) + assert b.couch_deg == 0.0 + assert b.collimator_deg == 0.0 + + +class TestBeamSetSpec: + def test_unique_beam_ids_enforced(self) -> None: + with pytest.raises(ValueError, match="unique"): + BeamSetSpec( + isocenter_mm=(0, 0, 0), + beams=[ + BeamSpec(beam_id="A", gantry_deg=0), + BeamSpec(beam_id="A", gantry_deg=90), + ], + ) + + def test_minimum_one_beam(self) -> None: + with pytest.raises(ValueError): + BeamSetSpec(isocenter_mm=(0, 0, 0), beams=[]) + + def test_max_nine_beams(self) -> None: + beams = [BeamSpec(beam_id=f"B{i}", gantry_deg=0) for i in range(10)] + with pytest.raises(ValueError): + BeamSetSpec(isocenter_mm=(0, 0, 0), beams=beams) + + +# --------------------------------------------------------------------------- +# DeliveryParams modality filtering +# --------------------------------------------------------------------------- + +class TestDeliveryParams: + def test_proton_filter_keeps_only_proton_fields(self) -> None: + p = DeliveryParams( + spot_spacing_mm=4.0, + layer_spacing_mm=3.0, + beamlet_size_mm=(7, 7), # photon — should be filtered out + mlc_leaf_width_mm=2.5, # photon — should be filtered out + ) + out = p.for_modality(Modality.proton_pbs) + assert "spot_spacing_mm" in out + assert out["spot_spacing_mm"] == 4.0 + assert "beamlet_size_mm" not in out + assert "mlc_leaf_width_mm" not in out + + def test_photon_filter_keeps_only_photon_fields(self) -> None: + p = DeliveryParams( + spot_spacing_mm=4.0, # proton — filtered + beamlet_size_mm=(7, 7), + mlc_leaf_width_mm=2.5, + ) + out = p.for_modality(Modality.photon_imrt) + assert "beamlet_size_mm" in out + assert "spot_spacing_mm" not in out + + def test_tuple_normalized_to_list(self) -> None: + # Tuples and lists hash differently in JSON but are semantically + # equivalent — for_modality should normalize. + p = DeliveryParams(beamlet_size_mm=(5.0, 5.0)) + out = p.for_modality(Modality.photon_imrt) + assert out["beamlet_size_mm"] == [5.0, 5.0] + + +# --------------------------------------------------------------------------- +# Cache key behaviour — the centerpiece +# --------------------------------------------------------------------------- + +def _request(**overrides) -> BeamModelBuildRequest: + base = dict( + geometry_id="g-1", + modality=Modality.proton_pbs, + machine_model_id=None, + beam_set=BeamSetSpec( + isocenter_mm=(0, 0, 0), + beams=[BeamSpec(beam_id="B1", gantry_deg=0)], + ), + delivery_params=DeliveryParams(), + ) + base.update(overrides) + return BeamModelBuildRequest(**base) + + +class TestCacheKey: + def test_deterministic(self) -> None: + assert _request().compute_cache_key() == _request().compute_cache_key() + + def test_changes_with_geometry_id(self) -> None: + a = _request(geometry_id="g-1").compute_cache_key() + b = _request(geometry_id="g-2").compute_cache_key() + assert a != b + + def test_changes_with_modality(self) -> None: + a = _request(modality=Modality.proton_pbs).compute_cache_key() + b = _request(modality=Modality.photon_imrt).compute_cache_key() + assert a != b + + def test_invariant_to_beam_order(self) -> None: + beams_a = [ + BeamSpec(beam_id="B1", gantry_deg=0), + BeamSpec(beam_id="B2", gantry_deg=90), + ] + beams_b = list(reversed(beams_a)) + a = _request(beam_set=BeamSetSpec(isocenter_mm=(0, 0, 0), beams=beams_a)).compute_cache_key() + b = _request(beam_set=BeamSetSpec(isocenter_mm=(0, 0, 0), beams=beams_b)).compute_cache_key() + assert a == b + + def test_proton_param_change_does_not_bust_photon_cache(self) -> None: + """The headline invariant — modality filter actually filters.""" + photon_a = _request( + modality=Modality.photon_imrt, + delivery_params=DeliveryParams(spot_spacing_mm=4.0, beamlet_size_mm=(5, 5)), + ).compute_cache_key() + photon_b = _request( + modality=Modality.photon_imrt, + delivery_params=DeliveryParams(spot_spacing_mm=8.0, beamlet_size_mm=(5, 5)), + ).compute_cache_key() + assert photon_a == photon_b, "spot_spacing change must not affect photon hash" + + def test_photon_param_change_does_not_bust_proton_cache(self) -> None: + proton_a = _request( + modality=Modality.proton_pbs, + delivery_params=DeliveryParams(spot_spacing_mm=5.0, beamlet_size_mm=(5, 5)), + ).compute_cache_key() + proton_b = _request( + modality=Modality.proton_pbs, + delivery_params=DeliveryParams(spot_spacing_mm=5.0, beamlet_size_mm=(8, 8)), + ).compute_cache_key() + assert proton_a == proton_b + + def test_changes_with_relevant_param(self) -> None: + proton_a = _request( + delivery_params=DeliveryParams(spot_spacing_mm=4.0) + ).compute_cache_key() + proton_b = _request( + delivery_params=DeliveryParams(spot_spacing_mm=8.0) + ).compute_cache_key() + assert proton_a != proton_b + + def test_excludes_plan_id(self) -> None: + a = _request(plan_id="plan-A").compute_cache_key() + b = _request(plan_id="plan-B").compute_cache_key() + assert a == b, "plan_id is a downstream reference, not a build input" + + +# --------------------------------------------------------------------------- +# FluenceElementSet totals invariant +# --------------------------------------------------------------------------- + +class TestFluenceElementSet: + def test_totals_must_match(self) -> None: + with pytest.raises(ValueError, match="total_count"): + FluenceElementSet( + total_count=10, + per_beam=[ + PerBeamElements(beam_id="B1", element_count=3), + PerBeamElements(beam_id="B2", element_count=4), + ], + ) + + def test_happy_path(self) -> None: + fe = FluenceElementSet( + total_count=7, + per_beam=[ + PerBeamElements(beam_id="B1", element_count=3), + PerBeamElements(beam_id="B2", element_count=4), + ], + ) + assert fe.total_count == 7 + + def test_proton_layer_arrays_must_align(self) -> None: + with pytest.raises(ValueError, match="length"): + PerBeamElements( + beam_id="B1", + element_count=10, + energy_layers=[100.0, 110.0], + spots_per_layer=[5], + ) + + def test_proton_layer_arrays_aligned_ok(self) -> None: + pbe = PerBeamElements( + beam_id="B1", + element_count=12, + energy_layers=[100.0, 110.0, 120.0], + spots_per_layer=[4, 4, 4], + ) + assert sum(pbe.spots_per_layer) == pbe.element_count diff --git a/tests/services/test_beam_model_service.py b/tests/services/test_beam_model_service.py new file mode 100644 index 0000000..8168a80 --- /dev/null +++ b/tests/services/test_beam_model_service.py @@ -0,0 +1,229 @@ +"""Unit tests for radiarch.services.beam_model.BeamModelService. + +These tests stub out ``_load_geometry``, ``_load_machine_model``, and +the modality builders so the orchestrator runs end-to-end without +OpenTPS, without MCsquare, and without DICOM. We assert: cache lookup, +modality dispatch, persistence round-trip, validation errors. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import pytest + +from radiarch.models.beam_model import ( + BeamModelBuildRequest, + BeamSetSpec, + BeamSpec, + DeliveryParams, + FluenceElementSet, + Modality, + PerBeamElements, +) +from radiarch.services.beam_model import BeamModelService, _LoadedGeometry +from radiarch.services.machine_model import ( + PhotonMachineModel, + ProtonMachineModel, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _request(modality: Modality = Modality.proton_pbs, **overrides) -> BeamModelBuildRequest: + base = dict( + geometry_id="g-1", + modality=modality, + machine_model_id=None, + beam_set=BeamSetSpec( + isocenter_mm=(0, 0, 0), + beams=[BeamSpec(beam_id="B1", gantry_deg=0)], + ), + delivery_params=DeliveryParams(), + ) + base.update(overrides) + return BeamModelBuildRequest(**base) + + +@dataclass +class _MockPlan: + note: str = "mock plan" + + +def _stub_loaded_geometry() -> _LoadedGeometry: + return _LoadedGeometry( + geometry_id="g-1", + ct=object(), + patient=object(), + target_contour=object(), + ) + + +# --------------------------------------------------------------------------- +# Happy paths +# --------------------------------------------------------------------------- + +class TestProtonHappyPath: + def test_builds_end_to_end(self, tmp_path: Path, monkeypatch): + service = BeamModelService(base_dir=tmp_path) + monkeypatch.setattr(service, "_load_geometry", lambda gid: _stub_loaded_geometry()) + + # Stub the modality builder to return a deterministic result. + from radiarch.services import beam_model as bm_module + + def _fake_proton(req, geom, mm): + from radiarch.services.proton_spots import ProtonBuildResult + return ProtonBuildResult( + fluence_elements=FluenceElementSet( + total_count=15, + per_beam=[PerBeamElements(beam_id="B1", element_count=15, + energy_layers=[100.0, 110.0], + spots_per_layer=[7, 8])], + ), + plan=_MockPlan(note="proton"), + ) + monkeypatch.setattr(BeamModelService, "_build_proton", staticmethod(_fake_proton)) + + result = service.build(_request()) + assert result.modality is Modality.proton_pbs + assert result.geometry_id == "g-1" + assert result.fluence_elements.total_count == 15 + assert result.machine_model_id == "default" + assert (tmp_path / result.beam_model_id / "plan.pkl").exists() + assert (tmp_path / result.beam_model_id / "meta.json").exists() + + +class TestPhotonHappyPath: + def test_builds_end_to_end(self, tmp_path: Path, monkeypatch): + service = BeamModelService(base_dir=tmp_path) + monkeypatch.setattr(service, "_load_geometry", lambda gid: _stub_loaded_geometry()) + + def _fake_photon(req, geom, mm): + from radiarch.services.photon_beamlets import PhotonBuildResult + return PhotonBuildResult( + fluence_elements=FluenceElementSet( + total_count=400, + per_beam=[PerBeamElements(beam_id="B1", element_count=400, + grid_dims=(20, 20), + active_beamlets=400)], + ), + plan=_MockPlan(note="photon"), + ) + monkeypatch.setattr(BeamModelService, "_build_photon", staticmethod(_fake_photon)) + + result = service.build(_request(modality=Modality.photon_imrt)) + assert result.modality is Modality.photon_imrt + assert result.fluence_elements.per_beam[0].grid_dims == (20, 20) + + +# --------------------------------------------------------------------------- +# Caching +# --------------------------------------------------------------------------- + +class TestCaching: + def test_second_build_hits_cache(self, tmp_path: Path, monkeypatch): + service = BeamModelService(base_dir=tmp_path) + load_calls = {"n": 0} + + def _counting_load(gid): + load_calls["n"] += 1 + return _stub_loaded_geometry() + monkeypatch.setattr(service, "_load_geometry", _counting_load) + + def _fake_proton(req, geom, mm): + from radiarch.services.proton_spots import ProtonBuildResult + return ProtonBuildResult( + fluence_elements=FluenceElementSet( + total_count=1, + per_beam=[PerBeamElements(beam_id="B1", element_count=1)], + ), + plan=_MockPlan(), + ) + monkeypatch.setattr(BeamModelService, "_build_proton", staticmethod(_fake_proton)) + + req = _request() + r1 = service.build(req) + r2 = service.build(req) + assert r1.beam_model_id == r2.beam_model_id + assert load_calls["n"] == 1, "cache hit must skip _load_geometry" + + def test_modality_change_misses_cache(self, tmp_path: Path, monkeypatch): + service = BeamModelService(base_dir=tmp_path) + monkeypatch.setattr(service, "_load_geometry", lambda gid: _stub_loaded_geometry()) + monkeypatch.setattr(BeamModelService, "_build_proton", staticmethod( + lambda req, geom, mm: _build_dummy("proton"))) + monkeypatch.setattr(BeamModelService, "_build_photon", staticmethod( + lambda req, geom, mm: _build_dummy("photon"))) + + r_proton = service.build(_request(modality=Modality.proton_pbs)) + r_photon = service.build(_request(modality=Modality.photon_imrt)) + assert r_proton.beam_model_id != r_photon.beam_model_id + + +def _build_dummy(label: str): + from radiarch.services.proton_spots import ProtonBuildResult + return ProtonBuildResult( + fluence_elements=FluenceElementSet( + total_count=1, + per_beam=[PerBeamElements(beam_id="B1", element_count=1)], + ), + plan=_MockPlan(note=label), + ) + + +# --------------------------------------------------------------------------- +# Error surfaces +# --------------------------------------------------------------------------- + +class TestErrors: + def test_unknown_geometry_id_raises(self, tmp_path: Path, monkeypatch): + """When the upstream geometry doesn't exist, surface a clean error.""" + service = BeamModelService(base_dir=tmp_path) + + def _missing_geometry(gid): + raise ValueError(f"geometry_id {gid!r} not found") + monkeypatch.setattr(service, "_load_geometry", _missing_geometry) + + with pytest.raises(ValueError, match="not found"): + service.build(_request(geometry_id="does-not-exist")) + + def test_proton_modality_with_photon_machine_raises(self, tmp_path: Path, monkeypatch): + service = BeamModelService(base_dir=tmp_path) + monkeypatch.setattr(service, "_load_geometry", lambda gid: _stub_loaded_geometry()) + + # Force the loader to return the wrong machine model type. + def _wrong_machine(modality, mm_id): + return PhotonMachineModel.from_default() # but request is proton + monkeypatch.setattr(BeamModelService, "_load_machine_model", staticmethod(_wrong_machine)) + + with pytest.raises(ValueError, match="Proton modality requires"): + service.build(_request(modality=Modality.proton_pbs)) + + +# --------------------------------------------------------------------------- +# Progress callback +# --------------------------------------------------------------------------- + +class TestProgressCallback: + def test_callback_fires_for_each_stage(self, tmp_path: Path, monkeypatch): + service = BeamModelService(base_dir=tmp_path) + monkeypatch.setattr(service, "_load_geometry", lambda gid: _stub_loaded_geometry()) + monkeypatch.setattr(BeamModelService, "_build_proton", staticmethod( + lambda req, geom, mm: _build_dummy("p"))) + + events = [] + def _cb(stage, frac, msg): + events.append((stage.value, round(frac, 2))) + + service.build(_request(), progress_callback=_cb) + stages = [e[0] for e in events] + assert "loading_geometry" in stages + assert "loading_machine_model" in stages + assert "building_beams" in stages + assert "computing_elements" in stages + assert "persisting" in stages + assert stages[-1] == "done" diff --git a/tests/services/test_beam_persistence.py b/tests/services/test_beam_persistence.py new file mode 100644 index 0000000..e2bf1f1 --- /dev/null +++ b/tests/services/test_beam_persistence.py @@ -0,0 +1,175 @@ +"""Unit tests for radiarch.services.beam_persistence. + +Uses a small picklable plan stand-in so we don't need OpenTPS to test +the persistence layer's contract. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path + +import pytest + +from radiarch.models.beam_model import ( + BeamModelResult, + FluenceElementSet, + Modality, + PerBeamElements, +) +from radiarch.services.beam_persistence import ( + META_FILENAME, + PLAN_FILENAME, + BeamModelStore, +) + + +# --------------------------------------------------------------------------- +# Test doubles +# --------------------------------------------------------------------------- + +@dataclass +class _MockProtonPlan: + """Minimal picklable stand-in for OpenTPS's ProtonPlan.""" + + spots: int + layers: int + note: str = "mock proton plan" + + +def _make_result(plan_uri: str, cache_key: str = "abc") -> BeamModelResult: + return BeamModelResult( + beam_model_id="bm-1", + geometry_id="g-1", + modality=Modality.proton_pbs, + fluence_elements=FluenceElementSet( + total_count=12, + per_beam=[PerBeamElements(beam_id="B1", element_count=12)], + ), + beam_model_ref_uri=plan_uri, + machine_model_id="default", + cache_key=cache_key, + ) + + +# --------------------------------------------------------------------------- +# Roundtrip +# --------------------------------------------------------------------------- + +class TestPickleRoundtrip: + def test_save_then_load_plan_roundtrips(self, tmp_path: Path) -> None: + store = BeamModelStore(tmp_path) + plan = _MockProtonPlan(spots=42, layers=7) + result = _make_result(str(tmp_path / "bm-1" / PLAN_FILENAME)) + + store.save(beam_model_id="bm-1", cache_key="k1", plan=plan, result=result) + + loaded = store.load_plan("bm-1") + assert loaded.spots == 42 + assert loaded.layers == 7 + assert loaded.note == "mock proton plan" + + def test_save_writes_expected_files(self, tmp_path: Path) -> None: + store = BeamModelStore(tmp_path) + result = _make_result(str(tmp_path / "bm-1" / PLAN_FILENAME)) + store.save( + beam_model_id="bm-1", + cache_key="k1", + plan=_MockProtonPlan(spots=1, layers=1), + result=result, + ) + assert (tmp_path / "bm-1" / PLAN_FILENAME).exists() + assert (tmp_path / "bm-1" / META_FILENAME).exists() + + +# --------------------------------------------------------------------------- +# Cache index +# --------------------------------------------------------------------------- + +class TestCacheIndex: + def _save(self, store: BeamModelStore, tmp_path: Path, beam_model_id: str, cache_key: str): + result = _make_result(str(tmp_path / beam_model_id / PLAN_FILENAME), cache_key=cache_key) + store.save( + beam_model_id=beam_model_id, + cache_key=cache_key, + plan=_MockProtonPlan(spots=1, layers=1), + result=result, + ) + + def test_lookup_roundtrip(self, tmp_path: Path) -> None: + store = BeamModelStore(tmp_path) + self._save(store, tmp_path, "bm-1", "deadbeef") + hit = store.lookup_by_cache_key("deadbeef") + assert hit is not None + assert hit.beam_model_id == "bm-1" + + def test_lookup_miss_returns_none(self, tmp_path: Path) -> None: + store = BeamModelStore(tmp_path) + assert store.lookup_by_cache_key("nope") is None + assert store.get_by_id("nope") is None + + +# --------------------------------------------------------------------------- +# Atomic retry / delete +# --------------------------------------------------------------------------- + +class TestAtomicity: + def test_save_overwrites_cleanly_on_retry(self, tmp_path: Path) -> None: + store = BeamModelStore(tmp_path) + result = _make_result(str(tmp_path / "bm-1" / PLAN_FILENAME)) + + # First save with a small plan + store.save( + beam_model_id="bm-1", cache_key="k", + plan=_MockProtonPlan(spots=1, layers=1), + result=result, + ) + # Second save with a different plan, same id + store.save( + beam_model_id="bm-1", cache_key="k", + plan=_MockProtonPlan(spots=99, layers=99), + result=result, + ) + + loaded = store.load_plan("bm-1") + assert loaded.spots == 99 + + def test_delete_scrubs_index_and_files(self, tmp_path: Path) -> None: + store = BeamModelStore(tmp_path) + result = _make_result(str(tmp_path / "bm-1" / PLAN_FILENAME), cache_key="k1") + store.save( + beam_model_id="bm-1", cache_key="k1", + plan=_MockProtonPlan(spots=1, layers=1), + result=result, + ) + assert store.lookup_by_cache_key("k1") is not None + + deleted = store.delete_by_id("bm-1") + assert deleted is True + assert not (tmp_path / "bm-1").exists() + assert store.lookup_by_cache_key("k1") is None + + def test_delete_unknown_id_returns_false(self, tmp_path: Path) -> None: + store = BeamModelStore(tmp_path) + assert store.delete_by_id("nope") is False + + +# --------------------------------------------------------------------------- +# Listing +# --------------------------------------------------------------------------- + +class TestListing: + def test_list_ids_excludes_tmp_dirs(self, tmp_path: Path) -> None: + store = BeamModelStore(tmp_path) + # Simulate a leftover tmp dir from a crashed write + (tmp_path / ".bm-2.tmp.xyz").mkdir() + # And a directory without a meta.json (incomplete write) + (tmp_path / "bm-3").mkdir() + + result = _make_result(str(tmp_path / "bm-1" / PLAN_FILENAME)) + store.save( + beam_model_id="bm-1", cache_key="k", + plan=_MockProtonPlan(spots=1, layers=1), + result=result, + ) + assert store.list_ids() == ["bm-1"] diff --git a/tests/services/test_machine_model.py b/tests/services/test_machine_model.py new file mode 100644 index 0000000..43182e7 --- /dev/null +++ b/tests/services/test_machine_model.py @@ -0,0 +1,69 @@ +"""Unit tests for radiarch.services.machine_model. + +These tests exercise the factory's dispatch logic without actually +loading the BDL / calibration files — those are heavyweight and only +get touched when ``.bdl`` / ``.calibration`` are accessed. +""" + +from __future__ import annotations + +import pytest + +from radiarch.models.beam_model import Modality +from radiarch.services.machine_model import ( + MachineModelError, + PhotonMachineModel, + ProtonMachineModel, + get_machine_model, +) + + +class TestProtonMachineModel: + def test_default_constructor_does_no_io(self) -> None: + # Should be instant — no disk reads until .bdl / .calibration. + mm = ProtonMachineModel.from_default() + assert mm.modality is Modality.proton_pbs + assert mm.machine_model_id == "default" + + def test_from_id_raises_when_path_missing(self) -> None: + with pytest.raises(MachineModelError, match="not found"): + ProtonMachineModel.from_id("definitely-does-not-exist-xyz") + + +class TestPhotonMachineModel: + def test_default_has_sensible_values(self) -> None: + mm = PhotonMachineModel.from_default() + assert mm.modality is Modality.photon_imrt + assert mm.machine_model_id == "default" + assert mm.mlc_leaf_width_mm > 0 + assert mm.max_jaw_opening_mm > 0 + assert mm.beam_quality_mv > 0 + + def test_from_id_only_recognizes_default(self) -> None: + with pytest.raises(MachineModelError, match="not found"): + PhotonMachineModel.from_id("varian-truebeam-customized") + + def test_from_id_default_returns_default(self) -> None: + assert PhotonMachineModel.from_id("default").machine_model_id == "default" + + +class TestFactory: + def test_dispatches_to_proton_for_proton_modality(self) -> None: + mm = get_machine_model(Modality.proton_pbs) + assert isinstance(mm, ProtonMachineModel) + + def test_dispatches_to_photon_for_photon_modality(self) -> None: + mm = get_machine_model(Modality.photon_imrt) + assert isinstance(mm, PhotonMachineModel) + + def test_none_id_returns_default(self) -> None: + mm = get_machine_model(Modality.proton_pbs, None) + assert mm.machine_model_id == "default" + + def test_explicit_default_id_returns_default(self) -> None: + mm = get_machine_model(Modality.proton_pbs, "default") + assert mm.machine_model_id == "default" + + def test_unknown_id_raises_clean_error(self) -> None: + with pytest.raises(MachineModelError): + get_machine_model(Modality.proton_pbs, "not-a-real-machine") diff --git a/tests/services/test_photon_beamlets.py b/tests/services/test_photon_beamlets.py new file mode 100644 index 0000000..91e43d6 --- /dev/null +++ b/tests/services/test_photon_beamlets.py @@ -0,0 +1,181 @@ +"""Unit tests for radiarch.services.photon_beamlets. + +Patches the OpenTPS PhotonPlan / PlanPhotonBeam / PlanPhotonSegment +classes with simple stand-ins so the test exercises the adapter math +(grid dim derivation, jaw resolution) without OpenTPS available. +""" + +from __future__ import annotations + +import sys +import types +from dataclasses import dataclass, field +from typing import List + +import pytest + +from radiarch.models.beam_model import ( + BeamSetSpec, + BeamSpec, + DeliveryParams, +) +from radiarch.services.machine_model import PhotonMachineModel + + +# --------------------------------------------------------------------------- +# OpenTPS stand-ins +# --------------------------------------------------------------------------- + +@dataclass +class _FakeSegment: + monitorUnits: float = 0.0 + jawOpeningMM: list = field(default_factory=list) + + +@dataclass +class _FakeBeam: + gantryAngle: float = 0.0 + couchAngle: float = 0.0 + segments: list = field(default_factory=list) + + +@dataclass +class _FakePlan: + beams: list = field(default_factory=list) + + +@pytest.fixture(autouse=True) +def _patch_opentps_photon(monkeypatch): + """Insert fake modules at the import paths used by the lazy imports.""" + photon_plan_mod = types.SimpleNamespace(PhotonPlan=_FakePlan) + plan_beam_mod = types.SimpleNamespace(PlanPhotonBeam=_FakeBeam) + plan_seg_mod = types.SimpleNamespace(PlanPhotonSegment=_FakeSegment) + monkeypatch.setitem(sys.modules, + "opentps.core.data.plan._photonPlan", photon_plan_mod) + monkeypatch.setitem(sys.modules, + "opentps.core.data.plan._planPhotonBeam", plan_beam_mod) + monkeypatch.setitem(sys.modules, + "opentps.core.data.plan._planPhotonSegment", plan_seg_mod) + yield + + +def _beam_set(n: int = 2) -> BeamSetSpec: + return BeamSetSpec( + isocenter_mm=(0, 0, 0), + beams=[BeamSpec(beam_id=f"B{i+1}", gantry_deg=i * 90.0) for i in range(n)], + ) + + +# --------------------------------------------------------------------------- +# Grid math +# --------------------------------------------------------------------------- + +class TestGridMath: + def test_default_jaw_default_beamlet_yields_expected_grid(self): + from radiarch.services.photon_beamlets import generate_photon_beamlets + + # 200 mm jaw / 5 mm beamlets = 40×40 = 1600 elements per beam. + result = generate_photon_beamlets( + ct=object(), patient=object(), + machine_model=PhotonMachineModel(machine_model_id="default"), + beam_set=_beam_set(1), + params=DeliveryParams(), # default beamlet 5×5 + ) + pb = result.fluence_elements.per_beam[0] + assert pb.grid_dims == (40, 40) + assert pb.element_count == 1600 + assert pb.active_beamlets == 1600 + + def test_custom_jaw_overrides_machine_default(self): + from radiarch.services.photon_beamlets import generate_photon_beamlets + + # 100 mm jaw / 10 mm beamlets = 10×10 = 100 per beam. + result = generate_photon_beamlets( + ct=object(), patient=object(), + machine_model=PhotonMachineModel( + machine_model_id="default", max_jaw_opening_mm=300.0, + ), + beam_set=_beam_set(1), + params=DeliveryParams( + jaw_opening_mm=(100.0, 100.0), + beamlet_size_mm=(10.0, 10.0), + ), + ) + pb = result.fluence_elements.per_beam[0] + assert pb.grid_dims == (10, 10) + assert pb.element_count == 100 + + def test_total_count_is_per_beam_sum(self): + from radiarch.services.photon_beamlets import generate_photon_beamlets + + result = generate_photon_beamlets( + ct=object(), patient=object(), + machine_model=PhotonMachineModel(machine_model_id="default"), + beam_set=_beam_set(3), + params=DeliveryParams(), + ) + assert result.fluence_elements.total_count == 1600 * 3 + assert len(result.fluence_elements.per_beam) == 3 + + def test_non_divisible_jaw_rounds_up(self): + from radiarch.services.photon_beamlets import generate_photon_beamlets + + # 100 mm / 7 mm = 14.28… → ceil to 15 per axis. + result = generate_photon_beamlets( + ct=object(), patient=object(), + machine_model=PhotonMachineModel(machine_model_id="default"), + beam_set=_beam_set(1), + params=DeliveryParams( + jaw_opening_mm=(100.0, 100.0), + beamlet_size_mm=(7.0, 7.0), + ), + ) + pb = result.fluence_elements.per_beam[0] + assert pb.grid_dims == (15, 15) + + +# --------------------------------------------------------------------------- +# Plan construction +# --------------------------------------------------------------------------- + +class TestPlanConstruction: + def test_one_beam_per_request(self): + from radiarch.services.photon_beamlets import generate_photon_beamlets + + result = generate_photon_beamlets( + ct=object(), patient=object(), + machine_model=PhotonMachineModel(machine_model_id="default"), + beam_set=_beam_set(3), + params=DeliveryParams(), + ) + assert len(result.plan.beams) == 3 + + def test_gantry_angles_propagate(self): + from radiarch.services.photon_beamlets import generate_photon_beamlets + + result = generate_photon_beamlets( + ct=object(), patient=object(), + machine_model=PhotonMachineModel(machine_model_id="default"), + beam_set=BeamSetSpec( + isocenter_mm=(0, 0, 0), + beams=[ + BeamSpec(beam_id="B1", gantry_deg=72), + BeamSpec(beam_id="B2", gantry_deg=144), + ], + ), + params=DeliveryParams(), + ) + assert [b.gantryAngle for b in result.plan.beams] == [72.0, 144.0] + + def test_segment_jaw_centered_on_isocenter(self): + from radiarch.services.photon_beamlets import generate_photon_beamlets + + # Jaw opening 80 mm should produce [-40, +40] mm jaw bounds. + result = generate_photon_beamlets( + ct=object(), patient=object(), + machine_model=PhotonMachineModel(machine_model_id="default"), + beam_set=_beam_set(1), + params=DeliveryParams(jaw_opening_mm=(80.0, 80.0)), + ) + seg = result.plan.beams[0].segments[0] + assert seg.jawOpeningMM == [-40.0, 40.0] diff --git a/tests/services/test_proton_spots.py b/tests/services/test_proton_spots.py new file mode 100644 index 0000000..49bddbc --- /dev/null +++ b/tests/services/test_proton_spots.py @@ -0,0 +1,215 @@ +"""Unit tests for radiarch.services.proton_spots. + +We bypass OpenTPS by patching ``ProtonPlanDesign`` to return a hand-built +plan stub. The function-under-test is the *adapter* that maps +``BeamSetSpec`` + ``DeliveryParams`` onto OpenTPS conventions and walks +the resulting plan into a ``FluenceElementSet`` — that's what we cover. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import List +from unittest.mock import patch + +import pytest + +from radiarch.models.beam_model import ( + BeamSetSpec, + BeamSpec, + DeliveryParams, + Modality, +) +from radiarch.services.machine_model import ProtonMachineModel + + +# --------------------------------------------------------------------------- +# Test doubles for the OpenTPS plan structure walked by _summarize_proton_plan +# --------------------------------------------------------------------------- + +@dataclass +class _FakeLayer: + nominalEnergy: float + spots: List[int] # Just a list whose len() is the spot count + + +@dataclass +class _FakeBeam: + layers: List[_FakeLayer] = field(default_factory=list) + + +@dataclass +class _FakePlan: + beams: List[_FakeBeam] = field(default_factory=list) + + +class _FakeProtonPlanDesign: + """Stand-in for OpenTPS's ``ProtonPlanDesign``. + + Records the inputs it receives (so tests can assert mapping logic) + and returns a deterministic plan from ``buildPlan``. + """ + + instances: List["_FakeProtonPlanDesign"] = [] + + def __init__(self): + self.ct = None + self.patient = None + self.calibration = None + self.gantryAngles = None + self.couchAngles = None + self.spotSpacing = None + self.layerSpacing = None + self.target_call = None + # Build-time return shape — set by the test before buildPlan is called. + self._plan = None + _FakeProtonPlanDesign.instances.append(self) + + def defineTargetMaskAndPrescription(self, contour, prescription_gy): + self.target_call = (contour, prescription_gy) + + def buildPlan(self): + # Default plan: two energy layers per beam, 5 spots each. + if self._plan is not None: + return self._plan + beams = [] + for _ in (self.gantryAngles or []): + beams.append(_FakeBeam(layers=[ + _FakeLayer(nominalEnergy=100.0, spots=list(range(5))), + _FakeLayer(nominalEnergy=110.0, spots=list(range(5))), + ])) + return _FakePlan(beams=beams) + + +@pytest.fixture(autouse=True) +def _reset_design_instances(): + _FakeProtonPlanDesign.instances.clear() + + +@pytest.fixture +def patched_proton_plan_design(monkeypatch): + """Patch the OpenTPS ProtonPlanDesign import inside the generator.""" + import radiarch.services.proton_spots as ps + + # Insert a fake into the import path used by the lazy import inside + # generate_proton_spots. We patch sys.modules so the lazy + # ``from opentps.core.data.plan import ProtonPlanDesign`` returns + # our fake. + import sys + import types + + fake_mod = types.SimpleNamespace(ProtonPlanDesign=_FakeProtonPlanDesign) + monkeypatch.setitem(sys.modules, "opentps.core.data.plan", fake_mod) + yield + + +def _beam_set(n_beams: int = 2) -> BeamSetSpec: + return BeamSetSpec( + isocenter_mm=(0, 0, 0), + beams=[ + BeamSpec(beam_id=f"B{i+1}", gantry_deg=i * 90.0) + for i in range(n_beams) + ], + ) + + +# --------------------------------------------------------------------------- +# Mapping: request → ProtonPlanDesign attributes +# --------------------------------------------------------------------------- + +class TestRequestMapping: + def test_gantry_couch_passed_through_in_order(self, patched_proton_plan_design): + from radiarch.services.proton_spots import generate_proton_spots + + beam_set = BeamSetSpec( + isocenter_mm=(0, 0, 0), + beams=[ + BeamSpec(beam_id="A", gantry_deg=0, couch_deg=10), + BeamSpec(beam_id="B", gantry_deg=90, couch_deg=-5), + ], + ) + result = generate_proton_spots( + ct=object(), patient=object(), target_contour=object(), + machine_model=ProtonMachineModel(machine_model_id="default"), + beam_set=beam_set, + params=DeliveryParams(), + ) + design = _FakeProtonPlanDesign.instances[0] + assert design.gantryAngles == [0.0, 90.0] + assert design.couchAngles == [10.0, -5.0] + + def test_spot_and_layer_spacing_applied(self, patched_proton_plan_design): + from radiarch.services.proton_spots import generate_proton_spots + + params = DeliveryParams(spot_spacing_mm=3.5, layer_spacing_mm=2.5) + generate_proton_spots( + ct=object(), patient=object(), target_contour=object(), + machine_model=ProtonMachineModel(machine_model_id="default"), + beam_set=_beam_set(), + params=params, + ) + design = _FakeProtonPlanDesign.instances[0] + assert design.spotSpacing == 3.5 + assert design.layerSpacing == 2.5 + + def test_target_contour_supplied_to_design(self, patched_proton_plan_design): + from radiarch.services.proton_spots import generate_proton_spots + + sentinel = object() + generate_proton_spots( + ct=object(), patient=object(), target_contour=sentinel, + machine_model=ProtonMachineModel(machine_model_id="default"), + beam_set=_beam_set(1), + params=DeliveryParams(), + prescription_gy=4.0, + ) + design = _FakeProtonPlanDesign.instances[0] + assert design.target_call == (sentinel, 4.0) + + def test_no_target_skips_define(self, patched_proton_plan_design): + from radiarch.services.proton_spots import generate_proton_spots + + generate_proton_spots( + ct=object(), patient=object(), target_contour=None, + machine_model=ProtonMachineModel(machine_model_id="default"), + beam_set=_beam_set(1), + params=DeliveryParams(), + ) + design = _FakeProtonPlanDesign.instances[0] + assert design.target_call is None + + +# --------------------------------------------------------------------------- +# Result: built plan → FluenceElementSet +# --------------------------------------------------------------------------- + +class TestPlanSummarization: + def test_per_beam_layout_matches_request(self, patched_proton_plan_design): + from radiarch.services.proton_spots import generate_proton_spots + + beam_set = _beam_set(2) + result = generate_proton_spots( + ct=object(), patient=object(), target_contour=object(), + machine_model=ProtonMachineModel(machine_model_id="default"), + beam_set=beam_set, + params=DeliveryParams(), + ) + # Default fake plan: 2 layers × 5 spots per beam, 2 beams = 20 spots. + assert result.fluence_elements.total_count == 20 + assert [pb.beam_id for pb in result.fluence_elements.per_beam] == ["B1", "B2"] + for pb in result.fluence_elements.per_beam: + assert pb.element_count == 10 + assert pb.energy_layers == [100.0, 110.0] + assert pb.spots_per_layer == [5, 5] + + def test_plan_object_is_returned(self, patched_proton_plan_design): + from radiarch.services.proton_spots import generate_proton_spots + + result = generate_proton_spots( + ct=object(), patient=object(), target_contour=object(), + machine_model=ProtonMachineModel(machine_model_id="default"), + beam_set=_beam_set(1), + params=DeliveryParams(), + ) + assert isinstance(result.plan, _FakePlan) + assert len(result.plan.beams) == 1 diff --git a/tests/test_api_beam_model.py b/tests/test_api_beam_model.py new file mode 100644 index 0000000..e7b6f3a --- /dev/null +++ b/tests/test_api_beam_model.py @@ -0,0 +1,280 @@ +"""End-to-end tests for the /beam-model/* routes (sync + async + DELETE). + +Uses FastAPI's TestClient + the same monkey-patch fixture pattern as +``test_api_geometry.py``: stub out _load_geometry, force the modality +builder to return a deterministic plan, patch ``.delay`` to ``.run`` so +Celery doesn't try to reach Redis, patch init_db to avoid Postgres. +""" + +from __future__ import annotations + +import tempfile +import types +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import pytest +from fastapi.testclient import TestClient + +from radiarch import app as radiarch_app +from radiarch.api.routes import beam_model as beam_model_route +from radiarch.app import create_app +from radiarch.core import store as store_module +from radiarch.models.beam_model import ( + FluenceElementSet, + Modality, + PerBeamElements, +) +from radiarch.services.beam_model import BeamModelService, _LoadedGeometry +from radiarch.tasks import beam_model_tasks as beam_tasks_module + + +# --------------------------------------------------------------------------- +# Test plan double — picklable so persistence tests can run end-to-end. +# --------------------------------------------------------------------------- + +@dataclass +class _MockPlan: + note: str = "mock" + + +def _stub_geometry() -> _LoadedGeometry: + return _LoadedGeometry( + geometry_id="g-1", + ct=object(), + patient=object(), + target_contour=object(), + ) + + +def _stub_proton_build(req, geom, mm): + from radiarch.services.proton_spots import ProtonBuildResult + return ProtonBuildResult( + fluence_elements=FluenceElementSet( + total_count=20, + per_beam=[PerBeamElements(beam_id="B1", element_count=20, + energy_layers=[100.0, 110.0], + spots_per_layer=[10, 10])], + ), + plan=_MockPlan(note="proton"), + ) + + +def _stub_photon_build(req, geom, mm): + from radiarch.services.photon_beamlets import PhotonBuildResult + return PhotonBuildResult( + fluence_elements=FluenceElementSet( + total_count=400, + per_beam=[PerBeamElements(beam_id="B1", element_count=400, + grid_dims=(20, 20), + active_beamlets=400)], + ), + plan=_MockPlan(note="photon"), + ) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def client(monkeypatch): + tmp = tempfile.TemporaryDirectory() + svc = BeamModelService(base_dir=tmp.name) + + # Stub the dependency loaders so OpenTPS / DICOM are never touched. + monkeypatch.setattr(svc, "_load_geometry", lambda gid: _stub_geometry()) + monkeypatch.setattr(BeamModelService, "_load_geometry", + lambda self, gid: _stub_geometry()) + monkeypatch.setattr(BeamModelService, "_build_proton", + staticmethod(_stub_proton_build)) + monkeypatch.setattr(BeamModelService, "_build_photon", + staticmethod(_stub_photon_build)) + + # Force every BeamModelService instance built inside the Celery task + # to use the same tempdir. Celery instantiates its own service. + original_init = BeamModelService.__init__ + + def _init_to_tmp(self, base_dir=None): + original_init(self, base_dir=tmp.name) + + monkeypatch.setattr(BeamModelService, "__init__", _init_to_tmp) + + # Reset and re-point singletons. + store_module.reset_store() + beam_model_route._service.cache_clear() + monkeypatch.setattr(beam_model_route, "_service", lambda: svc) + + # Bypass the Postgres lifespan init. + monkeypatch.setattr(radiarch_app, "init_db", lambda: None) + + # Bypass Celery: call the task body synchronously. + def _eager_delay(job_id, request_payload): + beam_tasks_module.build_beam_model_job.run(job_id, request_payload) + return types.SimpleNamespace(id=job_id) + + monkeypatch.setattr( + beam_tasks_module.build_beam_model_job, "delay", _eager_delay + ) + + app = create_app() + with TestClient(app) as c: + yield c + tmp.cleanup() + store_module.reset_store() + + +def _proton_payload() -> dict: + return { + "geometry_id": "g-1", + "modality": "PROTON_PBS", + "machine_model_id": None, + "beam_set": { + "isocenter_mm": [0, 0, 0], + "beams": [{"beam_id": "B1", "gantry_deg": 0.0, "couch_deg": 0.0, + "collimator_deg": 0.0}], + }, + "delivery_params": { + "spot_spacing_mm": 5.0, + "layer_spacing_mm": 5.0, + }, + } + + +def _photon_payload() -> dict: + return { + "geometry_id": "g-1", + "modality": "PHOTON_IMRT", + "beam_set": { + "isocenter_mm": [0, 0, 0], + "beams": [{"beam_id": "B1", "gantry_deg": 0.0}], + }, + "delivery_params": { + "beamlet_size_mm": [5.0, 5.0], + "jaw_opening_mm": [100.0, 100.0], + }, + } + + +# --------------------------------------------------------------------------- +# Async dispatch +# --------------------------------------------------------------------------- + +class TestBuildAsyncDispatch: + def test_proton_cache_miss_returns_202(self, client: TestClient): + r = client.post("/api/v1/beam-model/build", json=_proton_payload()) + assert r.status_code == 202, r.text + body = r.json() + assert body["job_id"] + assert body["cache_key"] + assert "beam_model_id" not in body + + def test_photon_cache_miss_returns_202(self, client: TestClient): + r = client.post("/api/v1/beam-model/build", json=_photon_payload()) + assert r.status_code == 202, r.text + + +# --------------------------------------------------------------------------- +# Cache hit +# --------------------------------------------------------------------------- + +class TestBuildCacheHit: + def test_second_build_returns_200_with_full_result(self, client: TestClient): + first = client.post("/api/v1/beam-model/build", json=_proton_payload()) + assert first.status_code == 202 + + second = client.post("/api/v1/beam-model/build", json=_proton_payload()) + assert second.status_code == 200, second.text + body = second.json() + assert body["beam_model_id"] + assert body["modality"] == "PROTON_PBS" + assert body["fluence_elements"]["total_count"] == 20 + + def test_cache_hit_response_has_no_job_id(self, client: TestClient): + client.post("/api/v1/beam-model/build", json=_proton_payload()) + second = client.post("/api/v1/beam-model/build", json=_proton_payload()).json() + assert "job_id" not in second + + +# --------------------------------------------------------------------------- +# Jobs endpoint +# --------------------------------------------------------------------------- + +class TestJobsEndpoint: + def test_unknown_job_id_is_404(self, client: TestClient): + r = client.get("/api/v1/beam-model/jobs/does-not-exist") + assert r.status_code == 404 + + def test_succeeded_job_carries_beam_model_id(self, client: TestClient): + r = client.post("/api/v1/beam-model/build", json=_proton_payload()) + job_id = r.json()["job_id"] + + status = client.get(f"/api/v1/beam-model/jobs/{job_id}").json() + assert status["state"] == "succeeded" + assert status["beam_model_id"] + assert status["progress"] == 1.0 + assert status["stage"] == "done" + + def test_job_beam_model_id_resolves_to_actual_result(self, client: TestClient): + r = client.post("/api/v1/beam-model/build", json=_photon_payload()).json() + status = client.get(f"/api/v1/beam-model/jobs/{r['job_id']}").json() + bm = client.get(f"/api/v1/beam-model/{status['beam_model_id']}").json() + assert bm["beam_model_id"] == status["beam_model_id"] + assert bm["cache_key"] == r["cache_key"] + assert bm["modality"] == "PHOTON_IMRT" + + +# --------------------------------------------------------------------------- +# GET /{id} + /{id}/artifact +# --------------------------------------------------------------------------- + +class TestRetrieval: + def _submit_and_get_id(self, client: TestClient) -> str: + r = client.post("/api/v1/beam-model/build", json=_proton_payload()) + job_id = r.json()["job_id"] + return client.get(f"/api/v1/beam-model/jobs/{job_id}").json()["beam_model_id"] + + def test_get_returns_full_result(self, client: TestClient): + bm_id = self._submit_and_get_id(client) + r = client.get(f"/api/v1/beam-model/{bm_id}").json() + assert r["beam_model_id"] == bm_id + assert r["fluence_elements"]["per_beam"][0]["beam_id"] == "B1" + + def test_artifact_stream_returns_pickled_bytes(self, client: TestClient): + bm_id = self._submit_and_get_id(client) + r = client.get(f"/api/v1/beam-model/{bm_id}/artifact") + assert r.status_code == 200 + # Pickle protocol 5 starts with 0x80 0x05. + assert r.content[:1] == b"\x80" + + def test_unknown_id_is_404(self, client: TestClient): + assert client.get("/api/v1/beam-model/nope").status_code == 404 + assert client.get("/api/v1/beam-model/nope/artifact").status_code == 404 + + +# --------------------------------------------------------------------------- +# DELETE +# --------------------------------------------------------------------------- + +class TestDelete: + def _submit_and_get_id(self, client: TestClient) -> str: + r = client.post("/api/v1/beam-model/build", json=_proton_payload()) + job_id = r.json()["job_id"] + return client.get(f"/api/v1/beam-model/jobs/{job_id}").json()["beam_model_id"] + + def test_delete_returns_204_and_removes(self, client: TestClient): + bm_id = self._submit_and_get_id(client) + r = client.delete(f"/api/v1/beam-model/{bm_id}") + assert r.status_code == 204 + assert client.get(f"/api/v1/beam-model/{bm_id}").status_code == 404 + + def test_delete_scrubs_cache(self, client: TestClient): + bm_id = self._submit_and_get_id(client) + client.delete(f"/api/v1/beam-model/{bm_id}") + # Same payload should rebuild rather than cache-hit. + r = client.post("/api/v1/beam-model/build", json=_proton_payload()) + assert r.status_code == 202 + + def test_delete_unknown_id_is_404(self, client: TestClient): + assert client.delete("/api/v1/beam-model/nope").status_code == 404 diff --git a/tests/test_api_geometry.py b/tests/test_api_geometry.py index f9f0d23..b483d86 100644 --- a/tests/test_api_geometry.py +++ b/tests/test_api_geometry.py @@ -1,8 +1,11 @@ -"""End-to-end tests for the /geometry/* routes. - -Uses FastAPI's TestClient and a monkey-patched GeometryService so we can -drive the pipeline against synthetic CT + contours — no OpenTPS, no -real DICOM, no Celery. +"""End-to-end tests for the /geometry/* routes (sync + async paths). + +Uses FastAPI's TestClient plus a monkey-patched GeometryService so the +pipeline runs against synthetic CT + contours — no OpenTPS, no real +DICOM. In ``environment=dev`` Celery is configured to run tasks eagerly, +so the dispatch path ``build_geometry_job.delay()`` executes +synchronously in the API thread and we can poll the job endpoint +immediately. """ from __future__ import annotations @@ -19,11 +22,13 @@ from radiarch import app as radiarch_app from radiarch.api.routes import geometry as geometry_route from radiarch.app import create_app +from radiarch.core import store as store_module from radiarch.services.geometry import GeometryService, _LoadedCT +from radiarch.tasks import geometry_tasks as geometry_tasks_module # --------------------------------------------------------------------------- -# Fakes (trimmed copies of test_geometry_service.py) +# Fakes # --------------------------------------------------------------------------- @dataclass @@ -72,35 +77,61 @@ def _build_loaded_ct() -> _LoadedCT: @pytest.fixture def client(monkeypatch): - """FastAPI client with a sandboxed, stubbed GeometryService singleton. - - We intercept two things that would otherwise break an offline pytest run: - * ``init_db`` — the real implementation tries to connect to Postgres, - which isn't available outside docker-compose. Stub to a no-op. - * ``geometry_route._service`` — swap the lru_cached factory for a - lambda returning our stubbed service. Clear the lru cache *before* - monkeypatching so any prior test's cached instance doesn't leak. + """FastAPI client with a sandboxed GeometryService and an in-memory store. + + The Celery task ``build_geometry_job`` imports ``GeometryService`` + lazily and constructs its own instance inside the task — we + monkey-patch the class's ``_load`` method to use the stub, so both + the API-route service *and* the Celery-task service produce the same + synthetic data. """ tmp = tempfile.TemporaryDirectory() svc = GeometryService(base_dir=tmp.name) monkeypatch.setattr(svc, "_load", lambda _req: _build_loaded_ct()) - # Clear the real lru_cache before we replace the function — otherwise - # the *previous* test's cached service would leak via the module global - # once monkeypatch reverts at teardown. + # Reset any cached singletons so the test starts with a clean store. + store_module.reset_store() geometry_route._service.cache_clear() + # Swap the lru_cached service for our preconfigured instance. + monkeypatch.setattr(geometry_route, "_service", lambda: svc) + + # Celery task instantiates its own GeometryService — patch the class + # so any instance created in the task uses the same tempdir + stub. + monkeypatch.setattr( + GeometryService, + "_load", + lambda self, _req: _build_loaded_ct(), + ) + # Also force new instances into the same base_dir. + original_init = GeometryService.__init__ + + def _init_to_tmp(self, base_dir=None, adapter=None): + original_init(self, base_dir=tmp.name, adapter=adapter) + + monkeypatch.setattr(GeometryService, "__init__", _init_to_tmp) + # No-op init_db so the app lifespan doesn't try to reach Postgres. monkeypatch.setattr(radiarch_app, "init_db", lambda: None) - monkeypatch.setattr(geometry_route, "_service", lambda: svc) + + # Bypass Celery entirely: call the task body synchronously. Celery's + # eager mode still tries to use the Redis result backend, which isn't + # running locally; patching .delay sidesteps the broker completely. + import types + + def _eager_delay(job_id, request_payload): + geometry_tasks_module.build_geometry_job.run(job_id, request_payload) + return types.SimpleNamespace(id=job_id) + + monkeypatch.setattr( + geometry_tasks_module.build_geometry_job, "delay", _eager_delay + ) app = create_app() with TestClient(app) as c: yield c tmp.cleanup() - # NOTE: don't call geometry_route._service.cache_clear() here — at this - # point it's still the monkeypatched lambda (which has no cache_clear). - # monkeypatch teardown restores the real lru_cached function on exit. + store_module.reset_store() def _sample_payload(grid_spec=None) -> dict: @@ -116,50 +147,143 @@ def _sample_payload(grid_spec=None) -> dict: # --------------------------------------------------------------------------- -# Tests +# POST /build — async dispatch path # --------------------------------------------------------------------------- -class TestBuildEndpoint: - def test_happy_path_returns_geometry_result(self, client: TestClient) -> None: +class TestBuildAsyncDispatch: + def test_cache_miss_returns_202_and_job_id(self, client: TestClient) -> None: r = client.post("/api/v1/geometry/build", json=_sample_payload()) - assert r.status_code == 200, r.text + assert r.status_code == 202, r.text body = r.json() + assert "job_id" in body + assert body["cache_key"] + # No geometry fields in the 202 response shape. + assert "geometry_id" not in body + assert "structure_index" not in body + + def test_202_job_is_in_queued_or_succeeded_state(self, client: TestClient) -> None: + """With Celery eager mode the task finishes before the HTTP + response returns. Accept either queued (if the broker were real) + or succeeded (eager) — both valid.""" + r = client.post("/api/v1/geometry/build", json=_sample_payload()) + job_id = r.json()["job_id"] + status_r = client.get(f"/api/v1/geometry/jobs/{job_id}") + assert status_r.status_code == 200, status_r.text + assert status_r.json()["state"] in {"queued", "running", "succeeded"} + +# --------------------------------------------------------------------------- +# POST /build — cache-hit fast path +# --------------------------------------------------------------------------- + +class TestBuildCacheHit: + def test_second_build_returns_200_with_full_result(self, client: TestClient) -> None: + # First call: 202 (cache miss, builds and caches via eager Celery). + first = client.post("/api/v1/geometry/build", json=_sample_payload()) + assert first.status_code == 202 + + # Second call: 200 (cache hit) with the full geometry inline. + second = client.post("/api/v1/geometry/build", json=_sample_payload()) + assert second.status_code == 200, second.text + body = second.json() assert body["geometry_id"] assert body["structure_index"] == {"PTV": 1} - assert body["frame_of_reference_uid"] == "1.2.3.9" assert body["ct_metadata"]["num_slices"] == 8 - assert body["grid_spec"]["size"] == [8, 8, 8] - assert body["cache_key"] - def test_cached_second_build_returns_same_geometry_id(self, client: TestClient) -> None: - first = client.post("/api/v1/geometry/build", json=_sample_payload()).json() + def test_cache_hit_has_no_job_id_field(self, client: TestClient) -> None: + client.post("/api/v1/geometry/build", json=_sample_payload()) second = client.post("/api/v1/geometry/build", json=_sample_payload()).json() - assert first["geometry_id"] == second["geometry_id"] + # 200 response is a GeometryResult — no job_id key. + assert "job_id" not in second -class TestGetEndpoint: - def test_roundtrip_build_then_fetch(self, client: TestClient) -> None: - built = client.post("/api/v1/geometry/build", json=_sample_payload()).json() - fetched = client.get(f"/api/v1/geometry/{built['geometry_id']}").json() - assert fetched["geometry_id"] == built["geometry_id"] - assert fetched["cache_key"] == built["cache_key"] +# --------------------------------------------------------------------------- +# GET /jobs/{job_id} +# --------------------------------------------------------------------------- - def test_unknown_id_is_404(self, client: TestClient) -> None: - r = client.get("/api/v1/geometry/does-not-exist") +class TestJobsEndpoint: + def test_unknown_job_id_is_404(self, client: TestClient) -> None: + r = client.get("/api/v1/geometry/jobs/does-not-exist") assert r.status_code == 404 + def test_succeeded_job_carries_geometry_id(self, client: TestClient) -> None: + """In eager mode the job finishes synchronously → polling + immediately after dispatch should see state=succeeded plus a + populated geometry_id.""" + r = client.post("/api/v1/geometry/build", json=_sample_payload()) + job_id = r.json()["job_id"] + + status = client.get(f"/api/v1/geometry/jobs/{job_id}").json() + assert status["state"] == "succeeded" + assert status["geometry_id"] # non-null, points at a real result + assert status["progress"] == 1.0 + assert status["stage"] == "done" + + def test_job_geometry_id_resolves_to_actual_result( + self, client: TestClient + ) -> None: + r = client.post("/api/v1/geometry/build", json=_sample_payload()).json() + status = client.get(f"/api/v1/geometry/jobs/{r['job_id']}").json() + geom = client.get(f"/api/v1/geometry/{status['geometry_id']}").json() + assert geom["geometry_id"] == status["geometry_id"] + assert geom["cache_key"] == r["cache_key"] + + +# --------------------------------------------------------------------------- +# GET /{geometry_id} + streaming endpoints (regression, unchanged behavior) +# --------------------------------------------------------------------------- + +class TestGeometryRetrieval: + def _submit_and_get_id(self, client: TestClient) -> str: + r = client.post("/api/v1/geometry/build", json=_sample_payload()) + job_id = r.json()["job_id"] + return client.get(f"/api/v1/geometry/jobs/{job_id}").json()["geometry_id"] -class TestVolumeStreaming: def test_density_stream_returns_nifti_bytes(self, client: TestClient) -> None: - built = client.post("/api/v1/geometry/build", json=_sample_payload()).json() - r = client.get(f"/api/v1/geometry/{built['geometry_id']}/density") + gid = self._submit_and_get_id(client) + r = client.get(f"/api/v1/geometry/{gid}/density") assert r.status_code == 200 # gzipped NIfTI starts with the gzip magic 0x1f 0x8b. assert r.content[:2] == b"\x1f\x8b" def test_masks_stream_returns_nifti_bytes(self, client: TestClient) -> None: - built = client.post("/api/v1/geometry/build", json=_sample_payload()).json() - r = client.get(f"/api/v1/geometry/{built['geometry_id']}/masks") + gid = self._submit_and_get_id(client) + r = client.get(f"/api/v1/geometry/{gid}/masks") assert r.status_code == 200 assert r.content[:2] == b"\x1f\x8b" + + def test_unknown_geometry_id_is_404(self, client: TestClient) -> None: + assert client.get("/api/v1/geometry/nope").status_code == 404 + assert client.get("/api/v1/geometry/nope/density").status_code == 404 + assert client.get("/api/v1/geometry/nope/masks").status_code == 404 + + +# --------------------------------------------------------------------------- +# DELETE /{geometry_id} +# --------------------------------------------------------------------------- + +class TestDelete: + def _submit_and_get_id(self, client: TestClient) -> str: + r = client.post("/api/v1/geometry/build", json=_sample_payload()) + job_id = r.json()["job_id"] + return client.get(f"/api/v1/geometry/jobs/{job_id}").json()["geometry_id"] + + def test_delete_returns_204_and_removes_geometry(self, client: TestClient) -> None: + gid = self._submit_and_get_id(client) + r = client.delete(f"/api/v1/geometry/{gid}") + assert r.status_code == 204 + assert client.get(f"/api/v1/geometry/{gid}").status_code == 404 + + def test_delete_scrubs_cache_so_next_build_goes_through_pipeline( + self, client: TestClient + ) -> None: + """After deleting, the same request should 202 (rebuild) rather + than 200 (cache hit).""" + gid = self._submit_and_get_id(client) + client.delete(f"/api/v1/geometry/{gid}") + + resubmit = client.post("/api/v1/geometry/build", json=_sample_payload()) + assert resubmit.status_code == 202, resubmit.text + + def test_delete_unknown_id_is_404(self, client: TestClient) -> None: + assert client.delete("/api/v1/geometry/nope").status_code == 404 diff --git a/tests/test_api_uploads.py b/tests/test_api_uploads.py new file mode 100644 index 0000000..9922f4a --- /dev/null +++ b/tests/test_api_uploads.py @@ -0,0 +1,347 @@ +"""End-to-end tests for the /uploads/* routes. + +Generates a tiny valid-enough DICOM file in-memory (one CT instance + +one RTSTRUCT instance), packs them into a ZIP, POSTs the ZIP, and +checks: + +* The upload endpoint returns the right counts. +* GET /uploads/{id} returns the same shape. +* DELETE removes the directory and a second GET 404s. +* The geometry build endpoint can resolve an upload_id and dispatch + without trying to reach Orthanc. +""" + +from __future__ import annotations + +import io +import tempfile +import types +import zipfile +from dataclasses import dataclass, field +from pathlib import Path +from typing import Tuple + +import numpy as np +import pydicom +import pytest +from fastapi.testclient import TestClient +from pydicom.dataset import Dataset, FileDataset +from pydicom.uid import ExplicitVRLittleEndian, generate_uid + +from radiarch import app as radiarch_app +from radiarch.api.routes import geometry as geometry_route +from radiarch.api.routes import uploads as uploads_route +from radiarch.app import create_app +from radiarch.core import store as store_module +from radiarch.services.geometry import GeometryService, _LoadedCT +from radiarch.tasks import geometry_tasks as geometry_tasks_module + + +# --------------------------------------------------------------------------- +# DICOM file synthesis +# --------------------------------------------------------------------------- + +def _minimal_dicom(modality: str, sop_class_uid: str) -> bytes: + """Build a minimal valid DICOM Part-10 file with the given modality.""" + file_meta = Dataset() + file_meta.MediaStorageSOPClassUID = sop_class_uid + file_meta.MediaStorageSOPInstanceUID = generate_uid() + file_meta.TransferSyntaxUID = ExplicitVRLittleEndian + + ds = FileDataset( + "in-memory", + {}, + file_meta=file_meta, + preamble=b"\x00" * 128, + ) + ds.PatientName = "TEST^PATIENT" + ds.PatientID = "TEST_ID" + ds.StudyInstanceUID = "1.2.3" + ds.SeriesInstanceUID = generate_uid() + ds.SOPInstanceUID = file_meta.MediaStorageSOPInstanceUID + ds.SOPClassUID = sop_class_uid + ds.Modality = modality + if modality == "CT": + # Add minimal CT pixel info so pydicom's writer is happy. + ds.Rows = 2 + ds.Columns = 2 + ds.BitsAllocated = 16 + ds.BitsStored = 16 + ds.HighBit = 15 + ds.PixelRepresentation = 1 + ds.SamplesPerPixel = 1 + ds.PhotometricInterpretation = "MONOCHROME2" + ds.PixelData = (np.zeros((2, 2), dtype=np.int16)).tobytes() + + buf = io.BytesIO() + ds.save_as(buf, write_like_original=False) + return buf.getvalue() + + +def _make_zip(file_specs: list[tuple[str, bytes]]) -> bytes: + buf = io.BytesIO() + with zipfile.ZipFile(buf, mode="w", compression=zipfile.ZIP_DEFLATED) as zf: + for name, payload in file_specs: + zf.writestr(name, payload) + return buf.getvalue() + + +# CT SOP Class: "CT Image Storage" +_CT_SOP = "1.2.840.10008.5.1.4.1.1.2" +# RTSTRUCT SOP Class: "RT Structure Set Storage" +_RT_SOP = "1.2.840.10008.5.1.4.1.1.481.3" + + +@pytest.fixture +def study_zip_bytes() -> bytes: + """Two CT slices + one RTSTRUCT, all packed into a single ZIP.""" + return _make_zip([ + ("study/ct_0001.dcm", _minimal_dicom("CT", _CT_SOP)), + ("study/ct_0002.dcm", _minimal_dicom("CT", _CT_SOP)), + ("study/rtstruct.dcm", _minimal_dicom("RTSTRUCT", _RT_SOP)), + ]) + + +# --------------------------------------------------------------------------- +# Test fixtures — reuse the same sandbox pattern as test_api_geometry +# --------------------------------------------------------------------------- + +@dataclass +class _FakePatient: + name: str = "UPLOAD_TEST" + rtStructs: list = field(default_factory=list) + + +@dataclass +class _FakeCT: + imageArray: np.ndarray + origin: Tuple[float, float, float] + spacing: Tuple[float, float, float] + patient: _FakePatient + seriesInstanceUID: str = "1.2.3.4" + studyInstanceUID: str = "1.2.3" + frameOfReferenceUID: str = "1.2.3.9" + + +@dataclass +class _FakeMask: + imageArray: np.ndarray + + +@dataclass +class _FakeContour: + name: str + mask: np.ndarray + + def getBinaryMask(self, origin, gridSize, spacing): + return _FakeMask(imageArray=self.mask.astype(bool)) + + +def _build_loaded_ct() -> _LoadedCT: + ct_array = np.zeros((8, 8, 8), dtype=np.int16) + ct_array[2:6, 2:6, 2:6] = 50 + ptv = np.zeros(ct_array.shape, dtype=bool) + ptv[2:6, 2:6, 2:6] = True + ct = _FakeCT(imageArray=ct_array, origin=(0, 0, 0), + spacing=(1, 1, 1), patient=_FakePatient()) + return _LoadedCT(ct=ct, patient=ct.patient, + contours=[_FakeContour("PTV", ptv)]) + + +@pytest.fixture +def client(monkeypatch, tmp_path: Path): + """Same harness as test_api_geometry, plus a tempdir for uploads.""" + upload_root = tmp_path / "uploads" + upload_root.mkdir() + + # Point the upload helper at our tempdir. The real _upload_root is + # lru_cached, but monkeypatch.setattr replaces the symbol entirely so + # subsequent _upload_root() calls inside the route hit our lambda. + monkeypatch.setattr(uploads_route, "_upload_root", + lambda: upload_root, raising=True) + + # Point GeometryService at the same tempdir for upload_id resolution. + import radiarch.services.geometry as geometry_service_module + + def _resolve_upload_path_stub(upload_id: str) -> Path: + path = upload_root / upload_id + if not path.is_dir(): + raise ValueError(f"Upload id not found: {upload_id!r}.") + return path + + monkeypatch.setattr( + geometry_service_module.GeometryService, + "_resolve_upload_path", + staticmethod(_resolve_upload_path_stub), + ) + + # Stub the geometry build itself — we're testing the upload plumbing, + # not OpenTPS. The stub still honors the upload-id contract: if the + # request carries an upload_id we run the (patched) resolver so a + # stale id surfaces as a ValueError just like in production. + def _load_honoring_upload(self, req): + if req.patient_ref.upload_id: + self._resolve_upload_path(req.patient_ref.upload_id) + return _build_loaded_ct() + + artifacts = tempfile.TemporaryDirectory() + svc = GeometryService(base_dir=artifacts.name) + monkeypatch.setattr(svc, "_load", lambda req: _load_honoring_upload(svc, req)) + store_module.reset_store() + geometry_route._service.cache_clear() + monkeypatch.setattr(geometry_route, "_service", lambda: svc) + monkeypatch.setattr(GeometryService, "_load", _load_honoring_upload) + + original_init = GeometryService.__init__ + + def _init_to_tmp(self, base_dir=None, adapter=None): + original_init(self, base_dir=artifacts.name, adapter=adapter) + + monkeypatch.setattr(GeometryService, "__init__", _init_to_tmp) + monkeypatch.setattr(radiarch_app, "init_db", lambda: None) + + def _eager_delay(job_id, request_payload): + geometry_tasks_module.build_geometry_job.run(job_id, request_payload) + return types.SimpleNamespace(id=job_id) + + monkeypatch.setattr( + geometry_tasks_module.build_geometry_job, "delay", _eager_delay + ) + + app = create_app() + with TestClient(app) as c: + yield c + artifacts.cleanup() + store_module.reset_store() + + +# --------------------------------------------------------------------------- +# Upload happy path +# --------------------------------------------------------------------------- + +class TestUploadHappyPath: + def test_post_returns_201_with_counts( + self, client: TestClient, study_zip_bytes: bytes + ) -> None: + r = client.post( + "/api/v1/uploads/dicom", + files={"file": ("study.zip", study_zip_bytes, "application/zip")}, + ) + assert r.status_code == 201, r.text + body = r.json() + assert body["upload_id"] + assert body["file_count"] == 3 + assert body["dicom_count"] == 3 + assert body["ct_slice_count"] == 2 + assert body["rtstruct_count"] == 1 + assert body["total_bytes"] > 0 + + def test_get_returns_same_shape( + self, client: TestClient, study_zip_bytes: bytes + ) -> None: + upload_id = client.post( + "/api/v1/uploads/dicom", + files={"file": ("study.zip", study_zip_bytes, "application/zip")}, + ).json()["upload_id"] + + r = client.get(f"/api/v1/uploads/{upload_id}") + assert r.status_code == 200 + assert r.json()["upload_id"] == upload_id + assert r.json()["ct_slice_count"] == 2 + + def test_delete_returns_204_and_get_then_404s( + self, client: TestClient, study_zip_bytes: bytes + ) -> None: + upload_id = client.post( + "/api/v1/uploads/dicom", + files={"file": ("study.zip", study_zip_bytes, "application/zip")}, + ).json()["upload_id"] + + assert client.delete(f"/api/v1/uploads/{upload_id}").status_code == 204 + assert client.get(f"/api/v1/uploads/{upload_id}").status_code == 404 + + +# --------------------------------------------------------------------------- +# Upload error paths +# --------------------------------------------------------------------------- + +class TestUploadErrors: + def test_non_zip_is_rejected(self, client: TestClient) -> None: + r = client.post( + "/api/v1/uploads/dicom", + files={"file": ("not_a_zip.txt", b"hello", "text/plain")}, + ) + assert r.status_code == 400 + + def test_empty_zip_rejected(self, client: TestClient) -> None: + empty = _make_zip([]) + r = client.post( + "/api/v1/uploads/dicom", + files={"file": ("empty.zip", empty, "application/zip")}, + ) + assert r.status_code == 400 + assert "no DICOM" in r.json()["detail"] + + def test_zip_with_only_non_dicom_rejected(self, client: TestClient) -> None: + zip_bytes = _make_zip([("readme.txt", b"just text")]) + r = client.post( + "/api/v1/uploads/dicom", + files={"file": ("readme.zip", zip_bytes, "application/zip")}, + ) + assert r.status_code == 400 + + def test_unknown_upload_id_is_404(self, client: TestClient) -> None: + assert client.get("/api/v1/uploads/does-not-exist").status_code == 404 + assert client.delete("/api/v1/uploads/does-not-exist").status_code == 404 + + +# --------------------------------------------------------------------------- +# Geometry build can resolve an upload_id +# --------------------------------------------------------------------------- + +class TestGeometryWithUpload: + def test_build_with_upload_id_dispatches( + self, client: TestClient, study_zip_bytes: bytes + ) -> None: + upload_id = client.post( + "/api/v1/uploads/dicom", + files={"file": ("study.zip", study_zip_bytes, "application/zip")}, + ).json()["upload_id"] + + payload = { + "patient_ref": {"upload_id": upload_id}, + "grid_spec": None, + "hu_to_density_model": "LINEAR", + } + r = client.post("/api/v1/geometry/build", json=payload) + # Cache miss → 202 (Celery eager path runs the stubbed _load). + assert r.status_code == 202, r.text + job_id = r.json()["job_id"] + + status = client.get(f"/api/v1/geometry/jobs/{job_id}").json() + assert status["state"] == "succeeded" + assert status["geometry_id"] + + def test_build_with_stale_upload_id_is_422(self, client: TestClient) -> None: + payload = { + "patient_ref": {"upload_id": "bogus-id"}, + "grid_spec": None, + "hu_to_density_model": "LINEAR", + } + r = client.post("/api/v1/geometry/build", json=payload) + # The job dispatches but fails inside the task — surface check is + # via the jobs endpoint. + assert r.status_code == 202 + job_id = r.json()["job_id"] + status = client.get(f"/api/v1/geometry/jobs/{job_id}").json() + assert status["state"] == "failed" + assert "Upload id not found" in (status.get("message") or "") + + def test_patient_ref_requires_one_source(self, client: TestClient) -> None: + payload = { + "patient_ref": {}, # neither dicom_study_uid nor upload_id + "grid_spec": None, + "hu_to_density_model": "LINEAR", + } + r = client.post("/api/v1/geometry/build", json=payload) + assert r.status_code == 422 diff --git a/tests/test_opentps_integration.py b/tests/test_opentps_integration.py index f24fd36..d672d32 100644 --- a/tests/test_opentps_integration.py +++ b/tests/test_opentps_integration.py @@ -40,6 +40,10 @@ from radiarch.config import get_settings +@pytest.mark.skipif( + sys.platform == "darwin", + reason="MCsquare binary not shipped for Darwin; vendored OpenTPS is Linux-only.", +) def test_opentps_integration(): # Skip if test data is not available if not os.path.isdir(_test_data_root):