diff --git a/changes/11632.enhance.md b/changes/11632.enhance.md new file mode 100644 index 00000000000..c33f2a9890b --- /dev/null +++ b/changes/11632.enhance.md @@ -0,0 +1 @@ +Split route health Valkey record into ReplicaProbeTarget (probe config) and ReplicaHealthStatus (TTL-based result), removing initial_delay from Valkey and switching to DB-based timeout in check_warming_up_health. diff --git a/src/ai/backend/client/cli/v2/deployment/revision.py b/src/ai/backend/client/cli/v2/deployment/revision.py index 99d386c96f1..499168253ce 100644 --- a/src/ai/backend/client/cli/v2/deployment/revision.py +++ b/src/ai/backend/client/cli/v2/deployment/revision.py @@ -49,7 +49,8 @@ def add(deployment_id: str, config: str, preset_id: str | None, auto_activate: b data["deployment_id"] = deployment_id if preset_id is not None: data["revision_preset_id"] = preset_id - data["auto_activate"] = auto_activate + if auto_activate: + data.setdefault("options", {})["auto_activate"] = True body = AddRevisionInput.model_validate(data) async def _run() -> None: diff --git a/src/ai/backend/common/clients/valkey_client/valkey_schedule/__init__.py b/src/ai/backend/common/clients/valkey_client/valkey_schedule/__init__.py index 64fae3d692e..90acbe862c9 100644 --- a/src/ai/backend/common/clients/valkey_client/valkey_schedule/__init__.py +++ b/src/ai/backend/common/clients/valkey_client/valkey_schedule/__init__.py @@ -2,14 +2,16 @@ HealthCheckStatus, HealthStatus, KernelStatus, - RouteHealthRecord, ValkeyScheduleClient, ) +from .types import ReplicaHealthResult, ReplicaHealthStatus, ReplicaProbeTarget __all__ = [ "HealthCheckStatus", "HealthStatus", "KernelStatus", - "RouteHealthRecord", + "ReplicaHealthResult", + "ReplicaHealthStatus", + "ReplicaProbeTarget", "ValkeyScheduleClient", ] diff --git a/src/ai/backend/common/clients/valkey_client/valkey_schedule/client.py b/src/ai/backend/common/clients/valkey_client/valkey_schedule/client.py index 4d8ae23d5a1..3e7b5e32bf1 100644 --- 
a/src/ai/backend/common/clients/valkey_client/valkey_schedule/client.py +++ b/src/ai/backend/common/clients/valkey_client/valkey_schedule/client.py @@ -12,7 +12,13 @@ AbstractValkeyClient, create_valkey_client, ) +from ai.backend.common.clients.valkey_client.valkey_schedule.types import ( + ReplicaHealthResult, + ReplicaHealthStatus, + ReplicaProbeTarget, +) from ai.backend.common.exception import BackendAIError +from ai.backend.common.identifier.replica import ReplicaID from ai.backend.common.json import dump_json_str, load_json from ai.backend.common.metrics.metric import DomainType, LayerType from ai.backend.common.resilience import ( @@ -34,6 +40,8 @@ AGENT_LAST_CHECK_TTL_SEC = 1200 # 20 minutes - TTL for agent last check timestamp ORPHAN_KERNEL_THRESHOLD_SEC = 600 # 10 minutes - threshold for orphan kernel detection FORCE_TERMINATED_CLEANUP_TTL_SEC = 1200 # 20 minutes - TTL for force-terminated cleanup queue +ROUTE_PROBE_TTL_SEC = 3600 # 1 hour - TTL for route probe targets +ROUTE_HEALTH_STATUS_TTL_SEC = 120 # 2 minutes - TTL for route health status (expiry = DEGRADED) class HealthCheckStatus(enum.StrEnum): @@ -84,87 +92,6 @@ def get_status(self) -> HealthCheckStatus | None: return self.readiness -@dataclass -class RouteHealthRecord: - """Health record for a route stored in Valkey. - - Stored as a hash at key `route_health:{route_id}`. - Created when replica_host/port become available (session RUNNING). - Agent and Manager each update their own fields independently. 
- """ - - route_id: str - created_at: int # Unix timestamp when route was created - initial_delay_until: int # Unix timestamp = running_at + initial_delay - health_path: str # extracted from model_definition - inference_port: int # extracted from kernel - replica_host: str # extracted from kernel - - # Timestamp when route entered RUNNING state (set by coordinator) - running_at: int | None = None - - # Agent check results - agent_healthy: bool = False - agent_last_check: int = 0 # Unix timestamp - - # Manager check results - manager_healthy: bool = False - manager_last_check: int = 0 # Unix timestamp - - @property - def last_check(self) -> int: - """Most recent check timestamp from either agent or manager.""" - return max(self.agent_last_check, self.manager_last_check) - - @property - def healthy(self) -> bool: - """Health based on the most recent checker (agent or manager).""" - if self.agent_last_check >= self.manager_last_check: - return self.agent_healthy - return self.manager_healthy - - def is_stale(self, current_time: int, staleness_sec: int = MAX_HEALTH_STALENESS_SEC) -> bool: - """Check if health data is stale (no recent check from either side).""" - if self.last_check == 0: - return False # Never checked yet — not stale, just NOT_CHECKED - return (current_time - self.last_check) > staleness_sec - - def to_valkey_hash(self) -> Mapping[str, str]: - """Serialize to Valkey hash fields.""" - data: dict[str, str] = { - "route_id": self.route_id, - "created_at": str(self.created_at), - "initial_delay_until": str(self.initial_delay_until), - "health_path": self.health_path, - "inference_port": str(self.inference_port), - "replica_host": self.replica_host, - "agent_healthy": "1" if self.agent_healthy else "0", - "agent_last_check": str(self.agent_last_check), - "manager_healthy": "1" if self.manager_healthy else "0", - "manager_last_check": str(self.manager_last_check), - } - if self.running_at is not None: - data["running_at"] = str(self.running_at) - return data 
- - @classmethod - def from_valkey_hash(cls, data: Mapping[str, str]) -> RouteHealthRecord: - """Deserialize from Valkey hash fields.""" - return cls( - route_id=data["route_id"], - created_at=int(data["created_at"]), - initial_delay_until=int(data["initial_delay_until"]), - health_path=data["health_path"], - inference_port=int(data["inference_port"]), - replica_host=data["replica_host"], - running_at=int(raw) if (raw := data.get("running_at")) and raw != "0" else None, - agent_healthy=data.get("agent_healthy", "0") == "1", - agent_last_check=int(data.get("agent_last_check", "0")), - manager_healthy=data.get("manager_healthy", "0") == "1", - manager_last_check=int(data.get("manager_last_check", "0")), - ) - - @dataclass class KernelStatus: """Presence status for a kernel.""" @@ -260,6 +187,12 @@ def _get_kernel_presence_key(self, kernel_id: KernelId) -> str: """ return f"kernel:presence:{kernel_id}" + def _get_route_probe_key(self, replica_id: ReplicaID) -> str: + return f"route_probe:{replica_id}" + + def _get_route_health_status_key(self, replica_id: ReplicaID) -> str: + return f"route_health:{replica_id}" + def _get_agent_last_check_key(self, agent_id: AgentId) -> str: """ Generate the Redis key for agent last check timestamp. @@ -675,86 +608,6 @@ async def update_route_liveness(self, route_id: str, liveness: bool) -> None: async with self._client.client() as conn: await conn.exec(batch, raise_on_error=True) - @valkey_schedule_resilience.apply() - async def mark_route_running_at(self, route_id: str) -> None: - """ - Record the RUNNING transition timestamp for a route. - Called when a route transitions to RUNNING status. - Uses Redis time for consistency with health check comparisons. 
- - :param route_id: The route ID that entered RUNNING state - """ - key = self._get_route_health_key(route_id) - current_time = str(await self._get_redis_time()) - async with self._client.client() as conn: - await conn.hset(key, {"running_at": current_time}) - await conn.expire(key, ROUTE_HEALTH_TTL_SEC) - - @valkey_schedule_resilience.apply() - async def get_route_running_at_batch(self, route_ids: Sequence[str]) -> dict[str, int | None]: - """ - Batch read running_at field from route health hashes. - Works even on partial hashes (before full RouteHealthRecord is initialized). - - :param route_ids: Route IDs to look up - :return: Mapping of route_id to running_at timestamp (None if not set) - """ - if not route_ids: - return {} - - batch = Batch(is_atomic=False) - for route_id in route_ids: - key = self._get_route_health_key(route_id) - batch.hget(key, "running_at") - - async with self._client.client() as conn: - results = await conn.exec(batch, raise_on_error=False) - if results is None: - return dict.fromkeys(route_ids) - - running_at_map: dict[str, int | None] = {} - for i, route_id in enumerate(route_ids): - raw = results[i] if len(results) > i else None - if raw and raw != b"0": - running_at_map[route_id] = int(raw) - else: - running_at_map[route_id] = None - return running_at_map - - @valkey_schedule_resilience.apply() - async def refresh_route_health_ttl(self, route_id: str) -> None: - """ - Refresh TTL for a route health record without changing any data. - Called by observer on every check to prevent key expiry. - - :param route_id: The route ID to refresh - """ - key = self._get_route_health_key(route_id) - async with self._client.client() as conn: - await conn.expire(key, ROUTE_HEALTH_TTL_SEC) - - @valkey_schedule_resilience.apply() - async def update_route_manager_health(self, route_id: str, healthy: bool) -> None: - """ - Update manager health check result for a route. - Called by RouteHealthObserver after HTTP health check. 
- - :param route_id: The route ID to update - :param healthy: Whether the route passed the health check - """ - key = self._get_route_health_key(route_id) - current_time = str(await self._get_redis_time()) - data: Mapping[str | bytes, str | bytes] = { - "manager_healthy": "1" if healthy else "0", - "manager_last_check": current_time, - } - - batch = Batch(is_atomic=False) - batch.hset(key, data) - batch.expire(key, ROUTE_HEALTH_TTL_SEC) - async with self._client.client() as conn: - await conn.exec(batch, raise_on_error=True) - @valkey_schedule_resilience.apply() async def check_route_health_status( self, route_ids: list[str] @@ -842,92 +695,133 @@ async def update_routes_readiness_batch(self, route_readiness: Mapping[str, bool async with self._client.client() as conn: await conn.exec(batch, raise_on_error=True) - # ==================== RouteHealthRecord Methods ==================== + # ==================== ReplicaProbeTarget / ReplicaHealthStatus Methods ==================== @valkey_schedule_resilience.apply() - async def initialize_route_health_records_batch( - self, records: Sequence[RouteHealthRecord] + async def register_route_probe_targets_batch( + self, targets: Sequence[ReplicaProbeTarget] ) -> None: """ - Batch initialize RouteHealthRecord entries in Valkey. - Called when replica info becomes available (session RUNNING + kernel host/port known). + Batch register ReplicaProbeTarget entries in Valkey. + Called by coordinator when route enters WARMING_UP and replica host/port are known. 
- :param records: RouteHealthRecord instances to store + :param targets: ReplicaProbeTarget instances to store """ - if not records: + if not targets: return batch = Batch(is_atomic=False) - for record in records: - key = self._get_route_health_key(record.route_id) - batch.hset(key, record.to_valkey_hash()) - batch.expire(key, ROUTE_HEALTH_TTL_SEC) + for target in targets: + key = self._get_route_probe_key(target.replica_id) + batch.hset(key, target.to_valkey_hash()) + batch.expire(key, ROUTE_PROBE_TTL_SEC) async with self._client.client() as conn: await conn.exec(batch, raise_on_error=True) @valkey_schedule_resilience.apply() - async def get_route_health_record(self, route_id: str) -> RouteHealthRecord | None: + async def get_route_probe_targets_batch( + self, replica_ids: Sequence[ReplicaID] + ) -> Mapping[ReplicaID, ReplicaProbeTarget | None]: """ - Get a RouteHealthRecord from Valkey. + Batch get ReplicaProbeTargets from Valkey. - :param route_id: The route ID to look up - :return: RouteHealthRecord or None if not found + :param replica_ids: Replica IDs to look up + :return: Mapping of replica_id to ReplicaProbeTarget (None if missing or expired) """ - key = self._get_route_health_key(route_id) + if not replica_ids: + return {} + + batch = Batch(is_atomic=False) + for replica_id in replica_ids: + batch.hgetall(self._get_route_probe_key(replica_id)) + async with self._client.client() as conn: - result = await conn.hgetall(key) - if not result: - return None + results = await conn.exec(batch, raise_on_error=False) + if results is None: + return dict.fromkeys(replica_ids) - data = {k.decode(): v.decode() for k, v in result.items()} - if "route_id" not in data: - return None - return RouteHealthRecord.from_valkey_hash(data) + targets: dict[ReplicaID, ReplicaProbeTarget | None] = {} + for i, replica_id in enumerate(replica_ids): + hgetall_result = results[i] if len(results) > i else None + if not hgetall_result: + targets[replica_id] = None + continue + raw = 
cast(dict[bytes, bytes], hgetall_result) + if not raw or b"replica_id" not in raw: + targets[replica_id] = None + continue + data = {k.decode(): v.decode() for k, v in raw.items()} + targets[replica_id] = ReplicaProbeTarget.from_valkey_hash(data) + + return targets @valkey_schedule_resilience.apply() - async def get_route_health_records_batch( - self, route_ids: Sequence[str] - ) -> Mapping[str, RouteHealthRecord | None]: + async def record_route_health_statuses_batch( + self, results: Sequence[ReplicaHealthResult] + ) -> None: """ - Batch get RouteHealthRecords from Valkey. + Batch record health check results for multiple routes. + Fetches Redis time once and writes all statuses in a single pipeline. + Refreshes TTL on every call; key expiry signals DEGRADED. - :param route_ids: Route IDs to look up - :return: Mapping of route_id to RouteHealthRecord (None if missing) + :param results: Sequence of ReplicaHealthResult instances """ - if not route_ids: + if not results: + return + + current_time = str(await self._get_redis_time()) + batch = Batch(is_atomic=False) + for result in results: + key = self._get_route_health_status_key(result.replica_id) + data: Mapping[str | bytes, str | bytes] = { + "replica_id": str(result.replica_id), + "healthy": "1" if result.healthy else "0", + "last_check": current_time, + } + batch.hset(key, data) + batch.expire(key, ROUTE_HEALTH_STATUS_TTL_SEC) + + async with self._client.client() as conn: + await conn.exec(batch, raise_on_error=True) + + @valkey_schedule_resilience.apply() + async def get_route_health_statuses_batch( + self, replica_ids: Sequence[ReplicaID] + ) -> Mapping[ReplicaID, ReplicaHealthStatus | None]: + """ + Batch get ReplicaHealthStatus from Valkey. + None means no recent health check (key missing or TTL expired) → DEGRADED. 
+ + :param replica_ids: Replica IDs to look up + :return: Mapping of replica_id to ReplicaHealthStatus (None if missing or expired) + """ + if not replica_ids: return {} batch = Batch(is_atomic=False) - for route_id in route_ids: - key = self._get_route_health_key(route_id) - batch.hgetall(key) + for replica_id in replica_ids: + batch.hgetall(self._get_route_health_status_key(replica_id)) async with self._client.client() as conn: results = await conn.exec(batch, raise_on_error=False) if results is None: - return dict.fromkeys(route_ids) + return dict.fromkeys(replica_ids) - records: dict[str, RouteHealthRecord | None] = {} - for i, route_id in enumerate(route_ids): + statuses: dict[ReplicaID, ReplicaHealthStatus | None] = {} + for i, replica_id in enumerate(replica_ids): hgetall_result = results[i] if len(results) > i else None if not hgetall_result: - records[route_id] = None + statuses[replica_id] = None continue - raw = cast(dict[bytes, bytes], hgetall_result) - if not raw: - records[route_id] = None + if not (raw := cast(dict[bytes, bytes], hgetall_result)) or b"replica_id" not in raw: + statuses[replica_id] = None continue - data = {k.decode(): v.decode() for k, v in raw.items()} - if "route_id" not in data: - # Partial hash (e.g., only running_at set by mark_route_running_at) - records[route_id] = None - continue - records[route_id] = RouteHealthRecord.from_valkey_hash(data) + statuses[replica_id] = ReplicaHealthStatus.from_valkey_hash({k.decode(): v.decode() for k, v in raw.items()}) - return records + return statuses @valkey_schedule_resilience.apply() async def close(self) -> None: diff --git a/src/ai/backend/common/clients/valkey_client/valkey_schedule/types.py b/src/ai/backend/common/clients/valkey_client/valkey_schedule/types.py new file mode 100644 index 00000000000..5aae6e27867 --- /dev/null +++ b/src/ai/backend/common/clients/valkey_client/valkey_schedule/types.py @@ -0,0 +1,82 @@ +"""Valkey data types for route health management.""" + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import
dataclass +from uuid import UUID + +from ai.backend.common.identifier.replica import ReplicaID + + +@dataclass +class ReplicaProbeTarget: + """Probe configuration for a route stored in Valkey. + + Stored as a hash at key `route_probe:{replica_id}`. + Written once by the coordinator when the route enters WARMING_UP (host/port available). + Read by RouteHealthObserver to know what endpoint to probe. + """ + + replica_id: ReplicaID + health_path: str + inference_port: int + replica_host: str + + def to_valkey_hash(self) -> Mapping[str, str]: + return { + "replica_id": str(self.replica_id), + "health_path": self.health_path, + "inference_port": str(self.inference_port), + "replica_host": self.replica_host, + } + + @classmethod + def from_valkey_hash(cls, data: Mapping[str, str]) -> ReplicaProbeTarget: + return cls( + replica_id=ReplicaID(UUID(data["replica_id"])), + health_path=data["health_path"], + inference_port=int(data["inference_port"]), + replica_host=data["replica_host"], + ) + + +@dataclass +class ReplicaHealthResult: + """Input type for recording a health check outcome. + + Passed to ``record_route_health_statuses_batch``; ``last_check`` is + assigned by the client using the current Redis time. + """ + + replica_id: ReplicaID + healthy: bool + + +@dataclass +class ReplicaHealthStatus: + """Health check result for a route stored in Valkey. + + Stored as a hash at key `route_health:{replica_id}`. + Written by RouteHealthObserver after each HTTP probe. + Short TTL — key expiry signals DEGRADED (no recent check). 
+ """ + + replica_id: ReplicaID + healthy: bool + last_check: int # Unix timestamp (Redis time) + + def to_valkey_hash(self) -> Mapping[str, str]: + return { + "replica_id": str(self.replica_id), + "healthy": "1" if self.healthy else "0", + "last_check": str(self.last_check), + } + + @classmethod + def from_valkey_hash(cls, data: Mapping[str, str]) -> ReplicaHealthStatus: + return cls( + replica_id=ReplicaID(UUID(data["replica_id"])), + healthy=data.get("healthy", "0") == "1", + last_check=int(data.get("last_check", "0")), + ) diff --git a/src/ai/backend/manager/api/rest/service/handler.py b/src/ai/backend/manager/api/rest/service/handler.py index c5529faaa1c..6868a977316 100644 --- a/src/ai/backend/manager/api/rest/service/handler.py +++ b/src/ai/backend/manager/api/rest/service/handler.py @@ -167,11 +167,8 @@ def _serve_info_from_dto(dto: ServiceInfo, runtime_variant_name: RuntimeVariant) def _resolve_target_revision_data(info: DeploymentInfo) -> ModelRevisionData | None: - """Resolve the target revision data by id (current first, then deploying).""" - target_id = info.current_revision_id or info.deploying_revision_id - if target_id is None: - return None - return next((r for r in info.model_revisions if r.id == target_id), None) + """Resolve the target revision data (current first, then deploying).""" + return info.current_revision or info.deploying_revision def _serve_info_from_deployment_info( diff --git a/src/ai/backend/manager/data/deployment/types.py b/src/ai/backend/manager/data/deployment/types.py index f164cb412f7..bc85ea5dd36 100644 --- a/src/ai/backend/manager/data/deployment/types.py +++ b/src/ai/backend/manager/data/deployment/types.py @@ -12,7 +12,7 @@ import yarl from pydantic import ConfigDict, Field -from ai.backend.common.config import ModelDefinition, ModelDefinitionDraft +from ai.backend.common.config import ModelDefinition, ModelDefinitionDraft, ModelHealthCheck from ai.backend.common.data.endpoint.types import EndpointLifecycle, ScalingState 
from ai.backend.common.data.model_deployment.types import ( ActivenessStatus, @@ -29,7 +29,6 @@ from ai.backend.common.identifier.runtime_variant import RuntimeVariantID from ai.backend.common.identifier.vfolder import VFolderUUID from ai.backend.manager.data.session.options import HandlerOptions -from ai.backend.manager.errors.deployment import DeploymentRevisionNotFound if TYPE_CHECKING: from ai.backend.manager.data.session.types import SchedulingResult, SubStepResult @@ -713,26 +712,12 @@ class DeploymentInfo: state: DeploymentState replica: ReplicaData network: DeploymentNetworkData - model_revisions: list[ModelRevisionData] options: DeploymentOptions - current_revision_id: DeploymentRevisionID | None = None + current_revision: ModelRevisionData | None = None policy: DeploymentPolicyData | None = None - deploying_revision_id: DeploymentRevisionID | None = None + deploying_revision: ModelRevisionData | None = None sub_step: DeploymentLifecycleSubStep | None = None - def resolve_revision_data(self, revision_id: DeploymentRevisionID) -> ModelRevisionData: - """Find a ``ModelRevisionData`` by id from ``model_revisions``. - - Raises: - DeploymentRevisionNotFound: If the revision is not found. 
- """ - for revision in self.model_revisions: - if revision.id == revision_id: - return revision - raise DeploymentRevisionNotFound( - f"Revision {revision_id} not found in model_revisions of deployment {self.id}" - ) - @dataclass(frozen=True) class DeploymentLastHistory: @@ -817,6 +802,7 @@ class RouteInfo: created_at: datetime revision_id: UUID traffic_status: RouteTrafficStatus + health_check: ModelHealthCheck | None error_data: dict[str, Any] = field(default_factory=dict) @property diff --git a/src/ai/backend/manager/models/alembic/versions/c3d4e5f6a7b8_add_health_check_to_routings.py b/src/ai/backend/manager/models/alembic/versions/c3d4e5f6a7b8_add_health_check_to_routings.py new file mode 100644 index 00000000000..29a7b6d37c4 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/c3d4e5f6a7b8_add_health_check_to_routings.py @@ -0,0 +1,47 @@ +"""Add health_check column to routings table. + +Revision ID: c3d4e5f6a7b8 +Revises: b2d4f6e8c1a3 +Create Date: 2026-05-15 + +""" + +# Part of: 26.5.0 + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "c3d4e5f6a7b8" +down_revision = "b2d4f6e8c1a3" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + conn = op.get_bind() + conn.exec_driver_sql("ALTER TABLE routings ADD COLUMN IF NOT EXISTS health_check JSONB") + # Backfill health_check from revision's model_definition for existing routes. + # Applies ModelHealthCheck defaults for optional fields (interval, max_retries, etc.). + # Only populates when health_check.path exists (required field in ModelHealthCheck).
+ conn.exec_driver_sql(""" + UPDATE routings r + SET health_check = jsonb_build_object( + 'path', dr.model_definition->'health_check'->>'path', + 'interval', COALESCE((dr.model_definition->'health_check'->>'interval')::float, 10.0), + 'max_retries', COALESCE((dr.model_definition->'health_check'->>'max_retries')::int, 10), + 'max_wait_time', COALESCE((dr.model_definition->'health_check'->>'max_wait_time')::float, 15.0), + 'expected_status_code', COALESCE((dr.model_definition->'health_check'->>'expected_status_code')::int, 200), + 'initial_delay', COALESCE((dr.model_definition->'health_check'->>'initial_delay')::float, 60.0) + ) + FROM deployment_revisions dr + WHERE dr.id = r.revision + AND r.health_check IS NULL + AND dr.model_definition IS NOT NULL + AND dr.model_definition->'health_check' IS NOT NULL + AND dr.model_definition->'health_check'->>'path' IS NOT NULL + """) + + +def downgrade() -> None: + conn = op.get_bind() + conn.exec_driver_sql("ALTER TABLE routings DROP COLUMN IF EXISTS health_check") diff --git a/src/ai/backend/manager/models/endpoint/row.py b/src/ai/backend/manager/models/endpoint/row.py index a653c5b2cfe..2b7e5ed8af9 100644 --- a/src/ai/backend/manager/models/endpoint/row.py +++ b/src/ai/backend/manager/models/endpoint/row.py @@ -69,6 +69,7 @@ DeploymentMetadata, DeploymentNetworkData, DeploymentOptions, + DeploymentPolicyData, DeploymentState, DeploymentSummaryData, ModelDeploymentAutoScalingRuleData, @@ -132,6 +133,18 @@ def _get_endpoint_revisions_join_condition() -> Any: return EndpointRow.id == foreign(DeploymentRevisionRow.endpoint) +def _get_current_revision_row_join_condition() -> sa.ColumnElement[bool]: + from ai.backend.manager.models.deployment_revision import DeploymentRevisionRow + + return EndpointRow.current_revision == DeploymentRevisionRow.id + + +def _get_deploying_revision_row_join_condition() -> sa.ColumnElement[bool]: + from ai.backend.manager.models.deployment_revision import DeploymentRevisionRow + + return 
EndpointRow.deploying_revision == DeploymentRevisionRow.id + + def _get_endpoint_auto_scaling_policy_join_condition() -> Any: from ai.backend.manager.models.deployment_auto_scaling_policy import ( DeploymentAutoScalingPolicyRow, @@ -293,6 +306,20 @@ class EndpointRow(Base): # type: ignore[misc] primaryjoin=_get_endpoint_revisions_join_condition, order_by="DeploymentRevisionRow.revision_number.desc()", ) + current_revision_row: Mapped[DeploymentRevisionRow | None] = relationship( + "DeploymentRevisionRow", + primaryjoin=_get_current_revision_row_join_condition, + foreign_keys="EndpointRow.current_revision", + viewonly=True, + uselist=False, + ) + deploying_revision_row: Mapped[DeploymentRevisionRow | None] = relationship( + "DeploymentRevisionRow", + primaryjoin=_get_deploying_revision_row_join_condition, + foreign_keys="EndpointRow.deploying_revision", + viewonly=True, + uselist=False, + ) auto_scaling_policy: Mapped[DeploymentAutoScalingPolicyRow | None] = relationship( "DeploymentAutoScalingPolicyRow", @@ -729,24 +756,23 @@ def from_deployment_creator( def to_deployment_info(self) -> DeploymentInfo: """Convert EndpointRow to DeploymentInfo dataclass using revision data.""" - policy_data = None - if self.deployment_policy is not None: - policy_data = self.deployment_policy.to_data() - - model_revisions: list[ModelRevisionData] = [] - for rev_row in self.revisions: - if rev_row.id == self.current_revision or rev_row.id == self.deploying_revision: - model_revisions.append(rev_row.to_data()) - - info = self._to_deployment_info_with_revisions(model_revisions) - info.policy = policy_data - return info + return self._build_deployment_info( + current_revision=( + self.current_revision_row.to_data() if self.current_revision_row else None + ), + deploying_revision=( + self.deploying_revision_row.to_data() if self.deploying_revision_row else None + ), + policy=self.deployment_policy.to_data() if self.deployment_policy is not None else None, + ) - def 
_to_deployment_info_with_revisions( + def _build_deployment_info( self, - model_revisions: list[ModelRevisionData], + current_revision: ModelRevisionData | None, + deploying_revision: ModelRevisionData | None, + policy: DeploymentPolicyData | None = None, ) -> DeploymentInfo: - """Build DeploymentInfo with pre-built model_revisions dict.""" + """Build DeploymentInfo with current and deploying revision data.""" return DeploymentInfo( id=self.id, metadata=DeploymentMetadata( @@ -775,12 +801,11 @@ def _to_deployment_info_with_revisions( url=self.url, preferred_domain_name=None, ), - model_revisions=list(model_revisions), options=self.options, - current_revision_id=self.current_revision, - deploying_revision_id=self.deploying_revision, + current_revision=current_revision, + deploying_revision=deploying_revision, sub_step=self.sub_step, - policy=self.deployment_policy.to_data() if self.deployment_policy is not None else None, + policy=policy, ) diff --git a/src/ai/backend/manager/models/routing/row.py b/src/ai/backend/manager/models/routing/row.py index 537d4bf48fd..3b8e6795a89 100644 --- a/src/ai/backend/manager/models/routing/row.py +++ b/src/ai/backend/manager/models/routing/row.py @@ -12,6 +12,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Mapped, mapped_column, relationship, selectinload +from ai.backend.common.config import ModelHealthCheck from ai.backend.common.identifier.deployment import DeploymentID from ai.backend.common.identifier.replica import ReplicaID from ai.backend.common.types import SessionId @@ -27,6 +28,7 @@ from ai.backend.manager.models.base import ( GUID, Base, + PydanticColumn, StrEnumType, ) @@ -110,6 +112,11 @@ class RoutingRow(Base): # type: ignore[misc] ) replica_port: Mapped[int | None] = mapped_column("replica_port", sa.Integer, nullable=True) + # Health check config (copied from revision at creation; None = no health check) + health_check: Mapped[ModelHealthCheck | None] = mapped_column( + 
"health_check", PydanticColumn(ModelHealthCheck), nullable=True, default=None + ) + # Revision reference without FK (relationship only) revision: Mapped[uuid.UUID] = mapped_column("revision", GUID, nullable=False) traffic_status: Mapped[RouteTrafficStatus] = mapped_column( @@ -264,5 +271,6 @@ def to_route_info(self) -> RouteInfo: created_at=self.created_at, revision_id=self.revision, traffic_status=self.traffic_status, + health_check=self.health_check, error_data=self.error_data or {}, ) diff --git a/src/ai/backend/manager/repositories/deployment/creators/route.py b/src/ai/backend/manager/repositories/deployment/creators/route.py index d6a4641f766..7b2f1b72ce8 100644 --- a/src/ai/backend/manager/repositories/deployment/creators/route.py +++ b/src/ai/backend/manager/repositories/deployment/creators/route.py @@ -6,6 +6,7 @@ from dataclasses import dataclass, field from typing import Any, override +from ai.backend.common.config import ModelHealthCheck from ai.backend.common.identifier.deployment import DeploymentID from ai.backend.common.identifier.deployment_revision import DeploymentRevisionID from ai.backend.manager.data.deployment.types import ( @@ -33,6 +34,7 @@ class RouteCreatorSpec(CreatorSpec[RoutingRow]): domain: str project_id: uuid.UUID revision_id: DeploymentRevisionID + health_check: ModelHealthCheck | None traffic_ratio: float = 1.0 traffic_status: RouteTrafficStatus = RouteTrafficStatus.INACTIVE @@ -49,6 +51,7 @@ def build_row(self) -> RoutingRow: traffic_ratio=self.traffic_ratio, revision=self.revision_id, traffic_status=self.traffic_status, + health_check=self.health_check, ) diff --git a/src/ai/backend/manager/repositories/deployment/db_source/db_source.py b/src/ai/backend/manager/repositories/deployment/db_source/db_source.py index c5f364cf920..df3701dfe52 100644 --- a/src/ai/backend/manager/repositories/deployment/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/deployment/db_source/db_source.py @@ -332,13 +332,9 @@ async def 
create_endpoint( sa.select(EndpointRow) .where(EndpointRow.id == endpoint.id) .options( - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), + selectinload(EndpointRow.current_revision_row), + selectinload(EndpointRow.deploying_revision_row), selectinload(EndpointRow.deployment_policy), - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), ) ) result = await db_sess.execute(stmt) @@ -393,12 +389,8 @@ async def get_endpoint( sa.select(EndpointRow) .where(EndpointRow.id == endpoint_id) .options( - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), + selectinload(EndpointRow.current_revision_row), + selectinload(EndpointRow.deploying_revision_row), selectinload(EndpointRow.deployment_policy), ) ) @@ -428,12 +420,8 @@ async def get_deployments_by_ids( ) ) .options( - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), + selectinload(EndpointRow.current_revision_row), + selectinload(EndpointRow.deploying_revision_row), selectinload(EndpointRow.deployment_policy), ) ) @@ -518,7 +506,8 @@ async def search_deployments_with_last_history( """ async with self._begin_readonly_session_read_committed() as db_sess: query = sa.select(EndpointRow).options( - selectinload(EndpointRow.revisions).selectinload(DeploymentRevisionRow.image_row), + selectinload(EndpointRow.current_revision_row), + selectinload(EndpointRow.deploying_revision_row), selectinload(EndpointRow.deployment_policy), ) query_result = await execute_batch_querier(db_sess, query, querier) @@ -567,12 +556,8 @@ async def list_endpoints_by_name( EndpointRow.lifecycle_stage == EndpointLifecycle.CREATED, ) .options( - selectinload(EndpointRow.revisions).selectinload( - 
DeploymentRevisionRow.image_row - ), - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), + selectinload(EndpointRow.current_revision_row), + selectinload(EndpointRow.deploying_revision_row), ) ) @@ -635,9 +620,8 @@ async def get_modified_endpoint( sa.select(EndpointRow) .where(EndpointRow.id == endpoint_id) .options( - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), + selectinload(EndpointRow.current_revision_row), + selectinload(EndpointRow.deploying_revision_row), selectinload(EndpointRow.deployment_policy), ) ) @@ -674,13 +658,9 @@ async def update_endpoint_with_spec( sa.select(EndpointRow) .where(EndpointRow.id == updater.pk_value) .options( - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), + selectinload(EndpointRow.current_revision_row), + selectinload(EndpointRow.deploying_revision_row), selectinload(EndpointRow.deployment_policy), - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), ) ) query_result = await db_sess.execute(stmt) @@ -999,6 +979,7 @@ async def get_routes_by_endpoint( created_at=row.created_at, revision_id=DeploymentRevisionID(row.revision), traffic_status=row.traffic_status, + health_check=row.health_check, replica_host=row.replica_host, replica_port=row.replica_port, updated_at=row.updated_at, @@ -1141,8 +1122,8 @@ async def search_endpoints( """ async with self._begin_readonly_session_read_committed() as db_sess: query = sa.select(EndpointRow).options( - selectinload(EndpointRow.revisions).selectinload(DeploymentRevisionRow.image_row), - selectinload(EndpointRow.revisions).selectinload(DeploymentRevisionRow.image_row), + selectinload(EndpointRow.current_revision_row), + selectinload(EndpointRow.deploying_revision_row), selectinload(EndpointRow.deployment_policy), ) @@ -1356,9 +1337,8 @@ async def get_endpoints_with_autoscaling_rules( EndpointRow.id == 
EndpointAutoScalingRuleRow.endpoint_id, ) .options( - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), + selectinload(EndpointRow.current_revision_row), + selectinload(EndpointRow.deploying_revision_row), ) .distinct() ) @@ -1619,6 +1599,7 @@ async def search_route_datas( created_at=row.created_at, revision_id=DeploymentRevisionID(row.revision), traffic_status=row.traffic_status, + health_check=row.health_check, replica_host=row.replica_host, replica_port=row.replica_port, updated_at=row.updated_at, @@ -1664,6 +1645,7 @@ async def search_route_datas_with_last_history( created_at=row.created_at, revision_id=DeploymentRevisionID(row.revision), traffic_status=row.traffic_status, + health_check=row.health_check, replica_host=row.replica_host, replica_port=row.replica_port, updated_at=row.updated_at, @@ -2729,12 +2711,8 @@ async def update_endpoint( sa.select(EndpointRow) .where(EndpointRow.id == updater.pk_value) .options( - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), + selectinload(EndpointRow.current_revision_row), + selectinload(EndpointRow.deploying_revision_row), selectinload(EndpointRow.deployment_policy), ) ) diff --git a/src/ai/backend/manager/repositories/deployment/types/endpoint.py b/src/ai/backend/manager/repositories/deployment/types/endpoint.py index d8779fde3a1..15d4ed56a0c 100644 --- a/src/ai/backend/manager/repositories/deployment/types/endpoint.py +++ b/src/ai/backend/manager/repositories/deployment/types/endpoint.py @@ -11,6 +11,7 @@ import sqlalchemy as sa +from ai.backend.common.config import ModelHealthCheck from ai.backend.common.data.endpoint.types import EndpointLifecycle from ai.backend.common.identifier.deployment import DeploymentID from ai.backend.common.identifier.deployment_revision import DeploymentRevisionID @@ -78,6 +79,7 @@ class RouteData: created_at: datetime 
revision_id: DeploymentRevisionID traffic_status: RouteTrafficStatus + health_check: ModelHealthCheck | None replica_host: str | None = None replica_port: int | None = None updated_at: datetime | None = None diff --git a/src/ai/backend/manager/services/deployment/service.py b/src/ai/backend/manager/services/deployment/service.py index 930d588e57a..80ceffbd1b4 100644 --- a/src/ai/backend/manager/services/deployment/service.py +++ b/src/ai/backend/manager/services/deployment/service.py @@ -224,21 +224,7 @@ def _convert_deployment_info_to_data(info: DeploymentInfo) -> ModelDeploymentDat Note: Some fields are set to defaults as DeploymentInfo doesn't have all the data. """ - revision: ModelRevisionData | None = None - if info.current_revision_id is not None: - revision = next( - (r for r in info.model_revisions if r.id == info.current_revision_id), - None, - ) - if revision is None: - log.error( - "Deployment {} has current_revision_id {} but no matching " - "ModelRevisionData was found in DeploymentInfo.model_revisions; " - "current_revision will be reported as null. 
This usually means " - "EndpointRow.revisions was not eagerly loaded by the caller.", - info.id, - info.current_revision_id, - ) + revision: ModelRevisionData | None = info.current_revision desired_count = info.replica.desired_replica_count if desired_count is None: @@ -257,10 +243,14 @@ def _convert_deployment_info_to_data(info: DeploymentInfo) -> ModelDeploymentDat updated_at=info.metadata.created_at or datetime.now(UTC), ), network_access=info.network, - revision_history_ids=[info.current_revision_id] if info.current_revision_id else [], + revision_history_ids=[info.current_revision.id] + if info.current_revision is not None + else [], revision=revision, - current_revision_id=info.current_revision_id, - deploying_revision_id=info.deploying_revision_id, + current_revision_id=info.current_revision.id if info.current_revision is not None else None, + deploying_revision_id=info.deploying_revision.id + if info.deploying_revision is not None + else None, scaling_rule_ids=[], # Not available in DeploymentInfo replica_state=ReplicaStateData( desired_replica_count=desired_count, diff --git a/src/ai/backend/manager/sokovan/deployment/deployment_controller.py b/src/ai/backend/manager/sokovan/deployment/deployment_controller.py index 90b6ac9d067..c2d67feac22 100644 --- a/src/ai/backend/manager/sokovan/deployment/deployment_controller.py +++ b/src/ai/backend/manager/sokovan/deployment/deployment_controller.py @@ -320,12 +320,11 @@ async def update_deployment( modified_endpoint = await self._deployment_repository.get_modified_endpoint( endpoint_id=endpoint_id, updater=updater ) - if modified_endpoint.current_revision_id is not None: - current_revision = await self._deployment_repository.get_revision( - modified_endpoint.current_revision_id - ) + if modified_endpoint.current_revision is not None: await self._scheduling_controller.validate_session_spec( - SessionValidationSpec.from_revision(model_revision=current_revision) + SessionValidationSpec.from_revision( + 
model_revision=modified_endpoint.current_revision + ) ) res = await self._deployment_repository.update_endpoint_with_spec(updater) try: @@ -557,7 +556,10 @@ async def activate_revision( # 2. Validate deployment state deployment_info = await self._deployment_repository.get_endpoint_info(deployment_id) - if deployment_info.current_revision_id == revision_id: + if ( + deployment_info.current_revision is not None + and deployment_info.current_revision.id == revision_id + ): raise InvalidEndpointState( f"Revision {revision_id} is already the current revision " f"of deployment {deployment_id}." diff --git a/src/ai/backend/manager/sokovan/deployment/executor.py b/src/ai/backend/manager/sokovan/deployment/executor.py index f52a37eb6cc..6ea2dc8e4e4 100644 --- a/src/ai/backend/manager/sokovan/deployment/executor.py +++ b/src/ai/backend/manager/sokovan/deployment/executor.py @@ -40,6 +40,7 @@ from ai.backend.manager.data.deployment.scale import AutoScalingRule from ai.backend.manager.data.deployment.types import ( DeploymentInfo, + ModelRevisionData, RouteInfo, RouteStatus, RouteTrafficStatus, @@ -47,7 +48,7 @@ from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.data.prometheus_query_preset import PrometheusQueryPresetData from ai.backend.manager.data.resource.types import ScalingGroupProxyTarget -from ai.backend.manager.errors.deployment import ReplicaCountMismatch +from ai.backend.manager.errors.deployment import DeploymentRevisionNotFound, ReplicaCountMismatch from ai.backend.manager.models.routing import RoutingRow from ai.backend.manager.models.routing.conditions import RouteConditions from ai.backend.manager.repositories.base.rbac.entity_creator import RBACEntityCreator @@ -304,7 +305,7 @@ async def check_ready_deployments_that_need_scaling( # the handler to attempt scaling would permanently wedge the # deployment in SCALING because scale_deployment() would then # refuse to act on a None revision id. 
- if deployment.deployment_info.current_revision_id is None: + if deployment.deployment_info.current_revision is None: skipped.append(deployment) continue try: @@ -351,12 +352,12 @@ async def scale_deployment( # Phase 2: Evaluate scaling (per-deployment) for deployment in deployments: info = deployment.deployment_info - if info.current_revision_id is None: + if info.current_revision is None: skipped.append(deployment) continue try: out_creators, in_route_ids = self._evaluate_deployment_scaling( - info, route_map, info.current_revision_id + info, route_map, info.current_revision.id ) if out_creators or in_route_ids: scale_out_creators.extend(out_creators) @@ -442,7 +443,7 @@ async def calculate_desired_replicas( # the two stay in lock-step. deployments_to_calculate: list[DeploymentWithHistory] = [] for deployment in deployments: - if deployment.deployment_info.current_revision_id is None: + if deployment.deployment_info.current_revision is None: skipped.append(deployment) continue deployments_to_calculate.append(deployment) @@ -615,7 +616,20 @@ async def _build_endpoint_item( Resolves the runtime variant id to a name at this boundary since the AppProxy wire API still keys on the variant name string. 
""" - target_revision = deployment.resolve_revision_data(revision_id) + if ( + deployment.current_revision is not None + and deployment.current_revision.id == revision_id + ): + target_revision = deployment.current_revision + elif ( + deployment.deploying_revision is not None + and deployment.deploying_revision.id == revision_id + ): + target_revision = deployment.deploying_revision + else: + raise DeploymentRevisionNotFound( + f"Revision {revision_id} not found in deployment {deployment.id}" + ) health_check_config = None if target_revision.model_definition: @@ -687,6 +701,19 @@ def _evaluate_deployment_scaling( if len(routes) < target_count: # Build creators for scale out new_replica_count = target_count - len(routes) + revision_data: ModelRevisionData | None + if ( + deployment.current_revision is not None + and deployment.current_revision.id == revision_id + ): + revision_data = deployment.current_revision + else: + revision_data = deployment.deploying_revision + health_check = ( + revision_data.model_definition.health_check_config() + if revision_data is not None and revision_data.model_definition + else None + ) for _ in range(new_replica_count): creator_spec = RouteCreatorSpec( deployment_id=deployment.id, @@ -694,6 +721,7 @@ def _evaluate_deployment_scaling( domain=deployment.metadata.domain, project_id=deployment.metadata.project, revision_id=revision_id, + health_check=health_check, ) scale_out_creators.append( RBACEntityCreator( diff --git a/src/ai/backend/manager/sokovan/deployment/handlers/deploying.py b/src/ai/backend/manager/sokovan/deployment/handlers/deploying.py index 158588dfcbe..65a7de192e1 100644 --- a/src/ai/backend/manager/sokovan/deployment/handlers/deploying.py +++ b/src/ai/backend/manager/sokovan/deployment/handlers/deploying.py @@ -160,9 +160,9 @@ async def _ensure_endpoints_registered( info = deployment.deployment_info if info.network.url: continue - if info.deploying_revision_id is None: + if info.deploying_revision is None: continue - 
entries.append((deployment, info.deploying_revision_id)) + entries.append((deployment, info.deploying_revision.id)) if not entries: return set() @@ -186,7 +186,7 @@ async def execute( d.deployment_info.id for d in deployments if not d.deployment_info.network.url - and d.deployment_info.deploying_revision_id is not None + and d.deployment_info.deploying_revision is not None } if failed_registration_ids: deployments = [ diff --git a/src/ai/backend/manager/sokovan/deployment/handlers/replica.py b/src/ai/backend/manager/sokovan/deployment/handlers/replica.py index 4536437cf2b..7f40d91b026 100644 --- a/src/ai/backend/manager/sokovan/deployment/handlers/replica.py +++ b/src/ai/backend/manager/sokovan/deployment/handlers/replica.py @@ -96,7 +96,7 @@ async def execute( """ log.debug("Checking deployment replicas") - scalable = [d for d in deployments if d.deployment_info.current_revision_id is not None] + scalable = [d for d in deployments if d.deployment_info.current_revision is not None] if len(scalable) != len(deployments): skipped = len(deployments) - len(scalable) log.debug( diff --git a/src/ai/backend/manager/sokovan/deployment/route/coordinator.py b/src/ai/backend/manager/sokovan/deployment/route/coordinator.py index 7c4cba36905..874d7057e28 100644 --- a/src/ai/backend/manager/sokovan/deployment/route/coordinator.py +++ b/src/ai/backend/manager/sokovan/deployment/route/coordinator.py @@ -38,6 +38,7 @@ AppProxySyncRouteHandler, HealthCheckRouteHandler, ProvisioningRouteHandler, + ReplicaProbeTargetSyncHandler, RouteEvictionHandler, RouteHandler, RunningRouteHandler, @@ -184,6 +185,9 @@ def _init_handlers(self, executor: RouteExecutor) -> Mapping[RouteLifecycleType, route_executor=executor, event_producer=self._event_producer, ), + RouteLifecycleType.PROBE_TARGET_SYNC: ReplicaProbeTargetSyncHandler( + route_executor=executor, + ), } async def process_route_lifecycle( @@ -505,6 +509,13 @@ def _create_task_specs() -> list[RouteTaskSpec]: long_interval=30.0, 
initial_delay=15.0, ), + # Probe target sync - refresh ReplicaProbeTarget TTL and recover lost Valkey entries + RouteTaskSpec( + RouteLifecycleType.PROBE_TARGET_SYNC, + short_interval=None, + long_interval=60.0, + initial_delay=30.0, + ), ] def create_task_specs(self) -> list[EventTaskSpec]: diff --git a/src/ai/backend/manager/sokovan/deployment/route/executor.py b/src/ai/backend/manager/sokovan/deployment/route/executor.py index 6e3c54528cf..16a1c2e357f 100644 --- a/src/ai/backend/manager/sokovan/deployment/route/executor.py +++ b/src/ai/backend/manager/sokovan/deployment/route/executor.py @@ -7,7 +7,7 @@ from ai.backend.common.clients.http_client.client_pool import ClientPool from ai.backend.common.clients.valkey_client.valkey_schedule import ( - RouteHealthRecord, + ReplicaProbeTarget, ValkeyScheduleClient, ) from ai.backend.common.dto.appproxy_coordinator.v2.endpoint.request import ( @@ -242,7 +242,7 @@ async def check_starting_routes(self, routes: Sequence[RouteData]) -> RouteExecu if not routes: return RouteExecutionResult(successes=[], errors=[]) - route_ids = {ReplicaID(route.route_id) for route in routes} + route_ids = {route.route_id for route in routes} session_infos: dict[ReplicaID, RouteSessionInfo | None] = dict( await self._deployment_repo.fetch_route_session_kernel_infos(route_ids) ) @@ -252,7 +252,7 @@ async def check_starting_routes(self, routes: Sequence[RouteData]) -> RouteExecu updates: dict[ReplicaID, RouteSessionKernelInfo] = {} for route in routes: - replica_id = ReplicaID(route.route_id) + replica_id = route.route_id info = session_infos.get(replica_id) if info is None: @@ -293,7 +293,7 @@ async def check_starting_routes(self, routes: Sequence[RouteData]) -> RouteExecu if updates: await self._deployment_repo.update_route_replica_info(updates) - await self._initialize_health_records(successes, updates) + await self._register_route_probe_targets(successes, updates) return RouteExecutionResult( successes=successes, @@ -341,90 +341,99 @@ async 
def check_running_routes(self, routes: Sequence[RouteData]) -> RouteExecut errors=errors, ) - async def _initialize_health_records( + @staticmethod + def _build_probe_target(route: RouteData) -> ReplicaProbeTarget | None: + """Build a ReplicaProbeTarget from a route, or None if any required field is absent.""" + if route.health_check is None or route.replica_host is None or route.replica_port is None: + return None + return ReplicaProbeTarget( + replica_id=route.route_id, + health_path=route.health_check.path, + inference_port=route.replica_port, + replica_host=route.replica_host, + ) + + async def _register_route_probe_targets( self, routes: Sequence[RouteData], replica_info: Mapping[ReplicaID, RouteSessionKernelInfo], ) -> None: - """Create RouteHealthRecords in Valkey for routes that just got replica info.""" - revision_ids = {r.revision_id for r in routes} - health_configs = await self._deployment_repo.fetch_health_check_configs_by_revision_ids( - revision_ids - ) - route_id_strs = [str(r.route_id) for r in routes] - existing_running_at = await self._valkey_schedule.get_route_running_at_batch(route_id_strs) - current_time = await self._valkey_schedule.get_redis_time() - - records: list[RouteHealthRecord] = [] - for route in routes: - kernel = replica_info[route.route_id] - health_config = health_configs.get(route.revision_id) - - health_path = health_config.path if health_config else "/" - initial_delay = health_config.initial_delay if health_config else 60.0 - created_at = int(route.created_at.timestamp()) - - route_id_str = str(route.route_id) - running_at = existing_running_at.get(route_id_str) or current_time - initial_delay_until = running_at + int(initial_delay) - - records.append( - RouteHealthRecord( - route_id=route_id_str, - created_at=created_at, - initial_delay_until=initial_delay_until, - health_path=health_path, - inference_port=kernel.replica_port, - replica_host=kernel.replica_host, - running_at=running_at, - ) + """Register ReplicaProbeTargets 
in Valkey for routes that just got replica info.""" + targets: list[ReplicaProbeTarget] = [ + ReplicaProbeTarget( + replica_id=route.route_id, + health_path=route.health_check.path, + inference_port=replica_info[route.route_id].replica_port, + replica_host=replica_info[route.route_id].replica_host, ) + for route in routes + if route.health_check is not None + ] + + if targets: + await self._valkey_schedule.register_route_probe_targets_batch(targets) + log.debug("Registered {} ReplicaProbeTargets in Valkey", len(targets)) + + async def sync_route_probe_targets(self, routes: Sequence[RouteData]) -> RouteExecutionResult: + """Sync ReplicaProbeTargets to Valkey for routes with known replica info. + + Handles two cases: + - Valkey data lost (restart, eviction) → re-registers probe targets + - TTL refresh for long-running routes - if records: - await self._valkey_schedule.initialize_route_health_records_batch(records) - log.debug("Initialized {} RouteHealthRecords in Valkey", len(records)) + Routes without health_check/replica_host/replica_port are skipped silently. + """ + targets = [t for route in routes if (t := self._build_probe_target(route)) is not None] + + if targets: + with RouteRecorderContext.shared_phase( + "register_probe_targets", + entity_ids={t.replica_id for t in targets}, + ): + with RouteRecorderContext.shared_step("write_probe_targets"): + await self._valkey_schedule.register_route_probe_targets_batch(targets) + log.debug("Synced {} ReplicaProbeTargets to Valkey", len(targets)) + + return RouteExecutionResult(successes=[], errors=[]) async def check_warming_up_health(self, routes: Sequence[RouteData]) -> RouteExecutionResult: """Check health of PROVISIONING+WARMING_UP routes for initial activation. 
- success: health probe passed, or no health check configured → RUNNING+ACTIVE - - failure: initial_delay exceeded without a passing probe → TERMINATING - - (no transition): still within initial_delay → route stays WARMING_UP + - failure: last_transition_at + initial_delay exceeded without a passing probe → TERMINATING + - (no transition): still within initial_delay, or last_transition_at unknown → stay WARMING_UP """ - route_id_strs = [str(r.route_id) for r in routes] - revision_ids = {r.revision_id for r in routes} - - records = await self._valkey_schedule.get_route_health_records_batch(route_id_strs) - health_configs = await self._deployment_repo.fetch_health_check_configs_by_revision_ids( - revision_ids - ) - current_time = await self._valkey_schedule.get_redis_time() + statuses = await self._valkey_schedule.get_route_health_statuses_batch([ + route.route_id for route in routes + ]) + now = await self._deployment_repo.get_db_now() successes: list[RouteData] = [] errors: list[RouteExecutionError] = [] for route in routes: - health_config = health_configs.get(route.revision_id) - if health_config is None: + if route.health_check is None: successes.append(route) continue - route_id_str = str(route.route_id) - record = records.get(route_id_str) - - if record is None: + status = statuses.get(route.route_id) + if status is not None and status.healthy: + successes.append(route) continue - if record.last_check > 0 and not record.is_stale(current_time) and record.healthy: - successes.append(route) + if route.last_transition_at is None: continue - if current_time > record.initial_delay_until: + elapsed = (now - route.last_transition_at).total_seconds() + if elapsed > route.health_check.initial_delay: errors.append( RouteExecutionError( route_info=route, reason="Route warming-up timed out waiting for healthy probe", - error_detail=f"Elapsed {current_time - record.initial_delay_until}s after initial_delay", + error_detail=( + f"Elapsed {elapsed:.0f}s exceeds " + 
f"initial_delay {route.health_check.initial_delay}s" + ), ) ) @@ -432,25 +441,16 @@ async def check_warming_up_health(self, routes: Sequence[RouteData]) -> RouteExe async def check_route_health(self, routes: Sequence[RouteData]) -> RouteExecutionResult: """ - Check health status of routes and push newly-healthy ones to AppProxy. + Check health status of RUNNING routes and push newly-healthy ones to AppProxy. - Reads RouteHealthRecord and classifies based on computed healthy/stale: - - HEALTHY: record.healthy is True - - UNHEALTHY: record.healthy is False and not stale (only past initial_delay) - - DEGRADED: record is stale or missing (only past initial_delay) - - (no transition): within initial_delay without a successful probe yet — - the route keeps whatever health_status it already has so a transient - warmup failure does not downgrade it prematurely. + Reads RouteHealthStatus from Valkey and classifies: + - HEALTHY: status.healthy is True + - UNHEALTHY: status.healthy is False + - DEGRADED: status absent (key missing or TTL expired — no recent check) Routes whose pre-execute health_status was not HEALTHY but whose - probe just passed are pushed to AppProxy synchronously via - :meth:`register_routes_now` so traffic can flow without waiting - for the long-cycle ``AppProxySyncRouteHandler`` fallback. Push - failures are logged and swallowed — the fallback converges - state, and a stuck health-check tick would block all later - observations. The handler keeps ``post_process`` to a thin - logging shim because all the work that changes external state - belongs here in the executor. + probe just passed are pushed to AppProxy synchronously so traffic + can flow without waiting for the long-cycle fallback. 
Args: routes: Routes to check health for @@ -458,12 +458,12 @@ async def check_route_health(self, routes: Sequence[RouteData]) -> RouteExecutio Returns: Result with successes (healthy), errors (unhealthy), stale (degraded) """ - # Phase 1: Load RouteHealthRecords + # Phase 1: Load RouteHealthStatuses with RouteRecorderContext.shared_phase("load_health_status"): with RouteRecorderContext.shared_step("query_health_check_results"): - route_ids = [str(route.route_id) for route in routes] - records = await self._valkey_schedule.get_route_health_records_batch(route_ids) - current_time = await self._valkey_schedule.get_redis_time() + statuses = await self._valkey_schedule.get_route_health_statuses_batch([ + route.route_id for route in routes + ]) successes: list[RouteData] = [] errors: list[RouteExecutionError] = [] @@ -471,45 +471,24 @@ async def check_route_health(self, routes: Sequence[RouteData]) -> RouteExecutio # Phase 2: Classify health state (per-route) for route in routes: - route_id_str = str(route.route_id) - record = records.get(route_id_str) + status = statuses.get(route.route_id) - if record is None: - # No RouteHealthRecord — not yet initialized + if status is None: + # Key absent or TTL expired → DEGRADED stale.append(route) continue - within_initial_delay = current_time < record.initial_delay_until - - # Success path always wins — a healthy probe transitions the route - # to HEALTHY even if initial_delay has not elapsed, so the kernel - # can start receiving traffic as soon as it is ready. - if record.last_check > 0 and not record.is_stale(current_time) and record.healthy: + if status.healthy: successes.append(route) - continue - - # Within initial_delay and no successful probe yet — hold the - # existing health_status (NOT_CHECKED/HEALTHY/UNHEALTHY/DEGRADED) - # by skipping classification entirely. 
- if within_initial_delay: - continue - - if record.last_check == 0: - stale.append(route) - continue - - if record.is_stale(current_time): - stale.append(route) - continue - - errors.append( - RouteExecutionError( - route_info=route, - reason="Route health check failed", - error_detail="RouteHealthRecord reports unhealthy", - error_code=None, + else: + errors.append( + RouteExecutionError( + route_info=route, + reason="Route health check failed", + error_detail="RouteHealthStatus reports unhealthy", + error_code=None, + ) ) - ) # Phase 3: Push newly-healthy routes to AppProxy. # ``successes`` carries the pre-execute RouteData snapshot; routes @@ -1056,10 +1035,10 @@ async def cleanup_routes_by_config(self, routes: Sequence[RouteData]) -> RouteEx else: deployment_cleanup_config[deployment.id] = set() valid_revisions: set[DeploymentRevisionID] = set() - if deployment.current_revision_id is not None: - valid_revisions.add(deployment.current_revision_id) - if deployment.deploying_revision_id is not None: - valid_revisions.add(deployment.deploying_revision_id) + if deployment.current_revision is not None: + valid_revisions.add(deployment.current_revision.id) + if deployment.deploying_revision is not None: + valid_revisions.add(deployment.deploying_revision.id) deployment_valid_revisions[deployment.id] = valid_revisions successes: list[RouteData] = [] diff --git a/src/ai/backend/manager/sokovan/deployment/route/handlers/__init__.py b/src/ai/backend/manager/sokovan/deployment/route/handlers/__init__.py index 237517fb9d0..9eb8233ebe9 100644 --- a/src/ai/backend/manager/sokovan/deployment/route/handlers/__init__.py +++ b/src/ai/backend/manager/sokovan/deployment/route/handlers/__init__.py @@ -3,6 +3,7 @@ from .appproxy_sync import AppProxySyncRouteHandler from .base import RouteHandler from .health_check import HealthCheckRouteHandler +from .probe_target_sync import ReplicaProbeTargetSyncHandler from .provisioning import ProvisioningRouteHandler from .route_eviction import 
RouteEvictionHandler from .running import RunningRouteHandler @@ -14,6 +15,7 @@ __all__ = [ "AppProxySyncRouteHandler", "HealthCheckRouteHandler", + "ReplicaProbeTargetSyncHandler", "ProvisioningRouteHandler", "RouteEvictionHandler", "RouteHandler", diff --git a/src/ai/backend/manager/sokovan/deployment/route/handlers/appproxy_sync.py b/src/ai/backend/manager/sokovan/deployment/route/handlers/appproxy_sync.py index bda30a65756..0ebbd140c7a 100644 --- a/src/ai/backend/manager/sokovan/deployment/route/handlers/appproxy_sync.py +++ b/src/ai/backend/manager/sokovan/deployment/route/handlers/appproxy_sync.py @@ -7,10 +7,10 @@ from ai.backend.logging import BraceStyleAdapter from ai.backend.manager.data.deployment.types import ( RouteHandlerCategory, - RouteHealthStatus, RouteStatus, RouteStatusTransitions, RouteTargetStatuses, + RouteTrafficStatus, ) from ai.backend.manager.defs import LockID from ai.backend.manager.repositories.deployment.types import RouteData @@ -58,7 +58,7 @@ def category(cls) -> RouteHandlerCategory: def target_statuses(cls) -> RouteTargetStatuses: return RouteTargetStatuses( lifecycle=[RouteStatus.RUNNING], - health=[RouteHealthStatus.HEALTHY], + traffic=[RouteTrafficStatus.ACTIVE], ) @classmethod diff --git a/src/ai/backend/manager/sokovan/deployment/route/handlers/observer/health_check.py b/src/ai/backend/manager/sokovan/deployment/route/handlers/observer/health_check.py index 6a8af7221d0..0cef5584561 100644 --- a/src/ai/backend/manager/sokovan/deployment/route/handlers/observer/health_check.py +++ b/src/ai/backend/manager/sokovan/deployment/route/handlers/observer/health_check.py @@ -1,9 +1,8 @@ """Observer for performing HTTP health checks on routes. -Reads RouteHealthRecord from Valkey, performs HTTP health checks in parallel, -and writes manager_healthy/manager_last_check back to Valkey. -During initial_delay, failures are ignored (not written). -The HealthCheckRouteHandler reads from Valkey and performs DB transitions. 
+Reads ReplicaProbeTarget from Valkey to get probe config (health_path, host, port), +performs HTTP health checks in parallel, and writes RouteHealthStatus back to Valkey. +The short TTL on RouteHealthStatus automatically signals DEGRADED on expiry. """ from __future__ import annotations @@ -15,7 +14,7 @@ import aiohttp from ai.backend.common.clients.valkey_client.valkey_schedule import ( - RouteHealthRecord, + ReplicaHealthResult, ValkeyScheduleClient, ) from ai.backend.logging import BraceStyleAdapter @@ -30,12 +29,11 @@ class RouteHealthObserver(RouteObserver): - """Performs HTTP health checks on routes using RouteHealthRecord. + """Performs HTTP health checks on routes using ReplicaProbeTarget from Valkey. - Reads RouteHealthRecord from Valkey to get health_path, replica_host, inference_port. + Reads ReplicaProbeTarget (health_path, replica_host, inference_port) from Valkey. HTTP checks run in parallel via asyncio.gather. - During initial_delay period, health check failures are ignored (not written to Valkey). - After initial_delay, both success and failure results are written. + Writes RouteHealthStatus back to Valkey; TTL expiry signals DEGRADED automatically. 
""" def __init__( @@ -51,59 +49,46 @@ def name(cls) -> str: return "route-health-observer" async def observe(self, routes: Sequence[RouteData]) -> RouteObservationResult: - """Check health for routes using RouteHealthRecord from Valkey.""" + """Check health for routes using ReplicaProbeTarget from Valkey.""" if not routes: return RouteObservationResult(observed_count=0) - # Filter routes that have replica connection info - checkable = [r for r in routes if r.replica_host and r.replica_port] - if not checkable: - return RouteObservationResult(observed_count=0) + # Load ReplicaProbeTargets from Valkey (keyed by ReplicaID) + replica_ids = [r.route_id for r in routes] + probe_targets = await self._valkey_schedule.get_route_probe_targets_batch(replica_ids) - # Load RouteHealthRecords from Valkey - route_ids = [str(r.route_id) for r in checkable] - records = await self._valkey_schedule.get_route_health_records_batch(route_ids) - - # Collect routes that have records - targets: list[tuple[str, RouteHealthRecord]] = [] - for route in checkable: - route_id_str = str(route.route_id) - record = records.get(route_id_str) - if record is not None: - targets.append((route_id_str, record)) + # Collect routes that have probe targets + checkable = [ + (route, target) + for route in routes + if (target := probe_targets.get(route.route_id)) is not None + ] - if not targets: - if checkable: + if not checkable: + if routes: log.warning( - "Health observer: {} checkable routes but 0 have records in Valkey", - len(checkable), + "Health observer: {} routes but 0 have probe targets in Valkey", + len(routes), ) return RouteObservationResult(observed_count=0) # Perform HTTP health checks in parallel check_tasks = [ - self._http_health_check(record.replica_host, record.inference_port, record.health_path) - for _, record in targets + self._http_health_check(target.replica_host, target.inference_port, target.health_path) + for _, target in checkable ] results = await asyncio.gather(*check_tasks) 
- # Write results to Valkey - current_time = await self._valkey_schedule.get_redis_time() - for (route_id_str, record), is_healthy in zip(targets, results, strict=False): - within_initial_delay = current_time < record.initial_delay_until - - # Always refresh TTL to prevent key expiry - await self._valkey_schedule.refresh_route_health_ttl(route_id_str) - - if is_healthy: - await self._valkey_schedule.update_route_manager_health(route_id_str, True) - elif not within_initial_delay: - await self._valkey_schedule.update_route_manager_health(route_id_str, False) - # else: failure within initial_delay → ignore (don't write) + # Write results to Valkey in a single batch (TTL refreshed on every call) + health_results = [ + ReplicaHealthResult(replica_id=route.route_id, healthy=is_healthy) + for (route, _target), is_healthy in zip(checkable, results, strict=False) + ] + await self._valkey_schedule.record_route_health_statuses_batch(health_results) - if targets: - log.debug("Health observer: checked {} routes", len(targets)) - return RouteObservationResult(observed_count=len(targets)) + if checkable: + log.debug("Health observer: checked {} routes", len(checkable)) + return RouteObservationResult(observed_count=len(checkable)) @staticmethod async def _http_health_check(host: str, port: int, path: str) -> bool: diff --git a/src/ai/backend/manager/sokovan/deployment/route/handlers/probe_target_sync.py b/src/ai/backend/manager/sokovan/deployment/route/handlers/probe_target_sync.py new file mode 100644 index 00000000000..c79ff5963df --- /dev/null +++ b/src/ai/backend/manager/sokovan/deployment/route/handlers/probe_target_sync.py @@ -0,0 +1,80 @@ +"""Handler for syncing ReplicaProbeTargets to Valkey for active routes.""" + +import logging +from collections.abc import Sequence + +from ai.backend.logging import BraceStyleAdapter +from ai.backend.manager.data.deployment.types import ( + RouteHandlerCategory, + RouteHealthStatus, + RouteStatus, + RouteStatusTransitions, + 
RouteTargetStatuses, +) +from ai.backend.manager.defs import LockID +from ai.backend.manager.repositories.deployment.types import RouteData +from ai.backend.manager.sokovan.deployment.route.executor import RouteExecutor +from ai.backend.manager.sokovan.deployment.route.types import RouteExecutionResult + +from .base import RouteHandler + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +class ReplicaProbeTargetSyncHandler(RouteHandler): + """Periodically syncs ReplicaProbeTarget entries to Valkey for active routes. + + Targets RUNNING routes that have already been health-checked (HEALTHY, + UNHEALTHY, DEGRADED). NOT_CHECKED routes are excluded intentionally — + their probe targets are registered when they first transition to WARMING_UP. + + Covers two cases: + - Valkey data lost (restart, eviction) → re-registers probe targets + - TTL refresh for long-running routes that would otherwise expire + """ + + def __init__(self, route_executor: RouteExecutor) -> None: + self._route_executor = route_executor + + @classmethod + def name(cls) -> str: + return "replica-probe-target-sync" + + @property + def lock_id(self) -> LockID | None: + return None + + @classmethod + def category(cls) -> RouteHandlerCategory: + return RouteHandlerCategory.SYNC + + @classmethod + def target_statuses(cls) -> RouteTargetStatuses: + return RouteTargetStatuses( + lifecycle=[RouteStatus.RUNNING], + health=[ + RouteHealthStatus.HEALTHY, + RouteHealthStatus.UNHEALTHY, + RouteHealthStatus.DEGRADED, + ], + ) + + @classmethod + def status_transitions(cls) -> RouteStatusTransitions: + return RouteStatusTransitions( + success=None, + failure=None, + stale=None, + ) + + async def execute(self, routes: Sequence[RouteData]) -> RouteExecutionResult: + """Sync probe targets for routes that have health_check and replica info.""" + return await self._route_executor.sync_route_probe_targets(routes) + + async def post_process(self, result: RouteExecutionResult) -> None: + if result.errors: + 
log.warning( + "Probe target sync: {} succeeded, {} failed", + len(result.successes), + len(result.errors), + ) diff --git a/src/ai/backend/manager/sokovan/deployment/route/types.py b/src/ai/backend/manager/sokovan/deployment/route/types.py index 3cbd0dbd981..b661ab4cb7a 100644 --- a/src/ai/backend/manager/sokovan/deployment/route/types.py +++ b/src/ai/backend/manager/sokovan/deployment/route/types.py @@ -19,6 +19,7 @@ class RouteLifecycleType(StrEnum): SERVICE_DISCOVERY_SYNC = "service_discovery_sync" APPPROXY_SYNC = "appproxy_sync" OBSERVE_HEALTH = "observe_health" + PROBE_TARGET_SYNC = "probe_target_sync" @dataclass diff --git a/src/ai/backend/manager/sokovan/deployment/strategy/evaluator.py b/src/ai/backend/manager/sokovan/deployment/strategy/evaluator.py index f3a669ceb0a..6606dd486f6 100644 --- a/src/ai/backend/manager/sokovan/deployment/strategy/evaluator.py +++ b/src/ai/backend/manager/sokovan/deployment/strategy/evaluator.py @@ -91,9 +91,9 @@ async def evaluate( # to count accumulated failures for rollback detection, but old # terminated routes are irrelevant and would bloat the result set. 
deploying_revision_ids = { - deployment.deploying_revision_id + deployment.deploying_revision.id for deployment in deployments - if deployment.deploying_revision_id is not None + if deployment.deploying_revision is not None } route_conditions: list[QueryCondition] = [ RouteConditions.by_endpoint_ids(endpoint_ids), diff --git a/src/ai/backend/manager/sokovan/deployment/strategy/rolling_update.py b/src/ai/backend/manager/sokovan/deployment/strategy/rolling_update.py index bc1311d558b..9cf8cea86f9 100644 --- a/src/ai/backend/manager/sokovan/deployment/strategy/rolling_update.py +++ b/src/ai/backend/manager/sokovan/deployment/strategy/rolling_update.py @@ -87,13 +87,13 @@ def evaluate_cycle( if not isinstance(spec, RollingUpdateSpec): raise TypeError(f"Expected RollingUpdateSpec, got {type(spec).__name__}") desired = deployment.replica.target_replica_count - deploying_revision_id = deployment.deploying_revision_id - if deploying_revision_id is None: + deploying_revision = deployment.deploying_revision + if deploying_revision is None: raise InvalidEndpointState( f"Deployment {deployment.id} has DEPLOYING lifecycle but deploying_revision_id is None. " "This indicates an inconsistent state — the deployment will be skipped." 
) - classified = self._classify_routes(routes, deploying_revision_id) + classified = self._classify_routes(routes, deploying_revision.id) log.info( "deployment {}: sub_step={}, routes total={}, " "old_active={}, new_prov={}, new_healthy={}, new_unhealthy={}, new_failed={}", @@ -129,7 +129,11 @@ def _classify_routes( elif route.status.is_inactive(): classified.new_failed_count += 1 elif route.status == RouteStatus.RUNNING: - if route.health_status == RouteHealthStatus.HEALTHY: + if route.health_status == RouteHealthStatus.HEALTHY or route.health_check is None: + # Routes without a health check have no probe data, so we treat them + # as ready once their process is RUNNING (health_status stays DEGRADED + # in DB because no Valkey entry exists — that is correct behaviour, but + # it must not block the DEPLOYING → READY transition). classified.new_healthy_count += 1 else: # UNHEALTHY / DEGRADED / NOT_CHECKED all count here: @@ -267,10 +271,16 @@ def _build_route_creators( count: int, ) -> list[RBACEntityCreator[RoutingRow]]: """Build route creator specs for new revision routes.""" - if deployment.deploying_revision_id is None: + if deployment.deploying_revision is None: raise DeploymentHasNoTargetRevision( f"Cannot create routes: deployment {deployment.id} has no deploying revision" ) + revision_data = deployment.deploying_revision + health_check = ( + revision_data.model_definition.health_check_config() + if revision_data.model_definition + else None + ) creators: list[RBACEntityCreator[RoutingRow]] = [] for _ in range(count): spec = RouteCreatorSpec( @@ -278,7 +288,8 @@ def _build_route_creators( session_owner_id=deployment.metadata.session_owner, domain=deployment.metadata.domain, project_id=deployment.metadata.project, - revision_id=deployment.deploying_revision_id, + revision_id=deployment.deploying_revision.id, + health_check=health_check, ) creators.append( RBACEntityCreator( diff --git a/tests/unit/common/clients/valkey_client/test_valkey_schedule_client.py 
b/tests/unit/common/clients/valkey_client/test_valkey_schedule_client.py index 65b50dc1ec5..f61d397ca8c 100644 --- a/tests/unit/common/clients/valkey_client/test_valkey_schedule_client.py +++ b/tests/unit/common/clients/valkey_client/test_valkey_schedule_client.py @@ -22,7 +22,12 @@ HealthCheckStatus, ValkeyScheduleClient, ) +from ai.backend.common.clients.valkey_client.valkey_schedule.types import ( + ReplicaHealthResult, + ReplicaProbeTarget, +) from ai.backend.common.defs import REDIS_LIVE_DB +from ai.backend.common.identifier.replica import ReplicaID from ai.backend.common.typed_validators import HostPortPair as HostPortPairModel from ai.backend.common.types import AgentId, KernelId, SessionId, ValkeyTarget @@ -898,3 +903,208 @@ async def test_remove_deletes_only_specified_sessions( result = await valkey_schedule_client.get_force_terminated_sessions() assert result == [sid_keep] + + +class TestReplicaProbeTargetClient: + """Test ValkeyScheduleClient methods for ReplicaProbeTarget.""" + + @pytest.fixture + async def valkey_schedule_client( + self, + redis_container: tuple[str, HostPortPairModel], + ) -> AsyncGenerator[ValkeyScheduleClient, None]: + _, hostport_pair = redis_container + client = await ValkeyScheduleClient.create( + valkey_target=ValkeyTarget(addr=hostport_pair.address), + db_id=REDIS_LIVE_DB, + human_readable_name="test-route-probe-target", + ) + try: + yield client + finally: + await client.close() + + @pytest.fixture + def replica_id(self) -> ReplicaID: + return ReplicaID(uuid4()) + + def _make_target(self, replica_id: ReplicaID) -> ReplicaProbeTarget: + return ReplicaProbeTarget( + replica_id=replica_id, + health_path="/health", + inference_port=8080, + replica_host="10.0.0.1", + ) + + async def test_register_and_get( + self, + valkey_schedule_client: ValkeyScheduleClient, + replica_id: ReplicaID, + ) -> None: + target = self._make_target(replica_id) + await valkey_schedule_client.register_route_probe_targets_batch([target]) + results = await 
valkey_schedule_client.get_route_probe_targets_batch([replica_id]) + assert results[replica_id] == target + + async def test_register_batch_multiple( + self, + valkey_schedule_client: ValkeyScheduleClient, + ) -> None: + targets = [self._make_target(ReplicaID(uuid4())) for _ in range(3)] + await valkey_schedule_client.register_route_probe_targets_batch(targets) + replica_ids = [t.replica_id for t in targets] + results = await valkey_schedule_client.get_route_probe_targets_batch(replica_ids) + for target in targets: + assert results[target.replica_id] == target + + async def test_get_missing_returns_none( + self, + valkey_schedule_client: ValkeyScheduleClient, + replica_id: ReplicaID, + ) -> None: + results = await valkey_schedule_client.get_route_probe_targets_batch([replica_id]) + assert results[replica_id] is None + + async def test_register_empty_does_nothing( + self, + valkey_schedule_client: ValkeyScheduleClient, + ) -> None: + await valkey_schedule_client.register_route_probe_targets_batch([]) + + async def test_get_empty_returns_empty_dict( + self, + valkey_schedule_client: ValkeyScheduleClient, + ) -> None: + results = await valkey_schedule_client.get_route_probe_targets_batch([]) + assert results == {} + + async def test_register_overwrites_existing( + self, + valkey_schedule_client: ValkeyScheduleClient, + replica_id: ReplicaID, + ) -> None: + await valkey_schedule_client.register_route_probe_targets_batch([ + self._make_target(replica_id) + ]) + updated = ReplicaProbeTarget( + replica_id=replica_id, + health_path="/healthz", + inference_port=9000, + replica_host="10.0.0.2", + ) + await valkey_schedule_client.register_route_probe_targets_batch([updated]) + results = await valkey_schedule_client.get_route_probe_targets_batch([replica_id]) + assert results[replica_id] == updated + + +class TestReplicaHealthStatusClient: + """Test ValkeyScheduleClient methods for ReplicaHealthStatus.""" + + @pytest.fixture + async def valkey_schedule_client( + self, + 
redis_container: tuple[str, HostPortPairModel], + ) -> AsyncGenerator[ValkeyScheduleClient, None]: + _, hostport_pair = redis_container + client = await ValkeyScheduleClient.create( + valkey_target=ValkeyTarget(addr=hostport_pair.address), + db_id=REDIS_LIVE_DB, + human_readable_name="test-route-health-status", + ) + try: + yield client + finally: + await client.close() + + @pytest.fixture + def replica_id(self) -> ReplicaID: + return ReplicaID(uuid4()) + + async def test_record_healthy_and_get( + self, + valkey_schedule_client: ValkeyScheduleClient, + replica_id: ReplicaID, + ) -> None: + await valkey_schedule_client.record_route_health_statuses_batch([ + ReplicaHealthResult(replica_id=replica_id, healthy=True) + ]) + results = await valkey_schedule_client.get_route_health_statuses_batch([replica_id]) + status = results[replica_id] + assert status is not None + assert status.healthy is True + assert status.last_check > 0 + + async def test_record_unhealthy_and_get( + self, + valkey_schedule_client: ValkeyScheduleClient, + replica_id: ReplicaID, + ) -> None: + await valkey_schedule_client.record_route_health_statuses_batch([ + ReplicaHealthResult(replica_id=replica_id, healthy=False) + ]) + results = await valkey_schedule_client.get_route_health_statuses_batch([replica_id]) + status = results[replica_id] + assert status is not None + assert status.healthy is False + + async def test_get_missing_returns_none( + self, + valkey_schedule_client: ValkeyScheduleClient, + replica_id: ReplicaID, + ) -> None: + results = await valkey_schedule_client.get_route_health_statuses_batch([replica_id]) + assert results[replica_id] is None + + async def test_get_empty_returns_empty_dict( + self, + valkey_schedule_client: ValkeyScheduleClient, + ) -> None: + results = await valkey_schedule_client.get_route_health_statuses_batch([]) + assert results == {} + + async def test_key_deletion_simulates_ttl_expiry( + self, + valkey_schedule_client: ValkeyScheduleClient, + replica_id: 
ReplicaID, + ) -> None: + """Deleting the key (simulating TTL expiry) results in None → DEGRADED.""" + await valkey_schedule_client.record_route_health_statuses_batch([ + ReplicaHealthResult(replica_id=replica_id, healthy=True) + ]) + key = valkey_schedule_client._get_route_health_status_key(replica_id) + async with valkey_schedule_client._client.client() as conn: + await conn.delete([key]) + results = await valkey_schedule_client.get_route_health_statuses_batch([replica_id]) + assert results[replica_id] is None + + async def test_record_batch_multiple( + self, + valkey_schedule_client: ValkeyScheduleClient, + ) -> None: + replica_ids = [ReplicaID(uuid4()) for _ in range(3)] + for rid in replica_ids: + await valkey_schedule_client.record_route_health_statuses_batch([ + ReplicaHealthResult(replica_id=rid, healthy=True) + ]) + results = await valkey_schedule_client.get_route_health_statuses_batch(replica_ids) + assert len(results) == 3 + for rid in replica_ids: + status = results[rid] + assert status is not None + assert status.healthy is True + + async def test_record_overwrites_previous( + self, + valkey_schedule_client: ValkeyScheduleClient, + replica_id: ReplicaID, + ) -> None: + await valkey_schedule_client.record_route_health_statuses_batch([ + ReplicaHealthResult(replica_id=replica_id, healthy=True) + ]) + await valkey_schedule_client.record_route_health_statuses_batch([ + ReplicaHealthResult(replica_id=replica_id, healthy=False) + ]) + results = await valkey_schedule_client.get_route_health_statuses_batch([replica_id]) + status = results[replica_id] + assert status is not None + assert status.healthy is False diff --git a/tests/unit/common/clients/valkey_client/test_valkey_schedule_types.py b/tests/unit/common/clients/valkey_client/test_valkey_schedule_types.py new file mode 100644 index 00000000000..c55105c5831 --- /dev/null +++ b/tests/unit/common/clients/valkey_client/test_valkey_schedule_types.py @@ -0,0 +1,101 @@ +"""Unit tests for ReplicaProbeTarget and 
ReplicaHealthStatus Valkey type serialization.""" + +from __future__ import annotations + +from uuid import uuid4 + +import pytest + +from ai.backend.common.clients.valkey_client.valkey_schedule.types import ( + ReplicaHealthStatus, + ReplicaProbeTarget, +) +from ai.backend.common.identifier.replica import ReplicaID + + +class TestReplicaProbeTargetSerialization: + @pytest.fixture + def replica_id(self) -> ReplicaID: + return ReplicaID(uuid4()) + + @pytest.fixture + def target(self, replica_id: ReplicaID) -> ReplicaProbeTarget: + return ReplicaProbeTarget( + replica_id=replica_id, + health_path="/health", + inference_port=8080, + replica_host="10.0.0.1", + ) + + def test_to_valkey_hash(self, target: ReplicaProbeTarget) -> None: + h = target.to_valkey_hash() + assert h["replica_id"] == str(target.replica_id) + assert h["health_path"] == "/health" + assert h["inference_port"] == "8080" + assert h["replica_host"] == "10.0.0.1" + + def test_from_valkey_hash(self, replica_id: ReplicaID) -> None: + data = { + "replica_id": str(replica_id), + "health_path": "/healthz", + "inference_port": "9000", + "replica_host": "192.168.1.100", + } + target = ReplicaProbeTarget.from_valkey_hash(data) + assert target.replica_id == replica_id + assert target.health_path == "/healthz" + assert target.inference_port == 9000 + assert target.replica_host == "192.168.1.100" + + def test_round_trip(self, target: ReplicaProbeTarget) -> None: + restored = ReplicaProbeTarget.from_valkey_hash(target.to_valkey_hash()) + assert restored == target + + +class TestReplicaHealthStatusSerialization: + @pytest.fixture + def replica_id(self) -> ReplicaID: + return ReplicaID(uuid4()) + + @pytest.fixture + def healthy_status(self, replica_id: ReplicaID) -> ReplicaHealthStatus: + return ReplicaHealthStatus( + replica_id=replica_id, + healthy=True, + last_check=1700000000, + ) + + def test_to_valkey_hash_healthy(self, healthy_status: ReplicaHealthStatus) -> None: + h = healthy_status.to_valkey_hash() + assert 
h["replica_id"] == str(healthy_status.replica_id) + assert h["healthy"] == "1" + assert h["last_check"] == "1700000000" + + def test_to_valkey_hash_unhealthy(self, replica_id: ReplicaID) -> None: + status = ReplicaHealthStatus( + replica_id=replica_id, + healthy=False, + last_check=1700000000, + ) + assert status.to_valkey_hash()["healthy"] == "0" + + def test_from_valkey_hash(self, replica_id: ReplicaID) -> None: + data = { + "replica_id": str(replica_id), + "healthy": "1", + "last_check": "1700000000", + } + status = ReplicaHealthStatus.from_valkey_hash(data) + assert status.replica_id == replica_id + assert status.healthy is True + assert status.last_check == 1700000000 + + def test_from_valkey_hash_missing_optional_fields(self, replica_id: ReplicaID) -> None: + """Missing healthy/last_check fields default to safe values.""" + status = ReplicaHealthStatus.from_valkey_hash({"replica_id": str(replica_id)}) + assert status.healthy is False + assert status.last_check == 0 + + def test_round_trip(self, healthy_status: ReplicaHealthStatus) -> None: + restored = ReplicaHealthStatus.from_valkey_hash(healthy_status.to_valkey_hash()) + assert restored == healthy_status diff --git a/tests/unit/manager/repositories/deployment/test_deployment_repository.py b/tests/unit/manager/repositories/deployment/test_deployment_repository.py index 875f206a39f..17d43027744 100644 --- a/tests/unit/manager/repositories/deployment/test_deployment_repository.py +++ b/tests/unit/manager/repositories/deployment/test_deployment_repository.py @@ -3359,6 +3359,7 @@ async def test_create_route( domain=test_domain_name, project_id=test_group_id, revision_id=DeploymentRevisionID(uuid.uuid4()), + health_check=None, traffic_ratio=1.0, traffic_status=RouteTrafficStatus.ACTIVE, ) @@ -3393,6 +3394,7 @@ async def test_update_route_status( domain=test_domain_name, project_id=test_group_id, revision_id=DeploymentRevisionID(uuid.uuid4()), + health_check=None, ) creator = RBACEntityCreator( spec=spec, @@ 
-3441,6 +3443,7 @@ async def test_update_route_with_unified_spec( domain=test_domain_name, project_id=test_group_id, revision_id=DeploymentRevisionID(uuid.uuid4()), + health_check=None, ) creator = RBACEntityCreator( spec=spec, diff --git a/tests/unit/manager/services/deployment/test_deployment_crud_actions.py b/tests/unit/manager/services/deployment/test_deployment_crud_actions.py index f71c2a80b46..fec1f73b085 100644 --- a/tests/unit/manager/services/deployment/test_deployment_crud_actions.py +++ b/tests/unit/manager/services/deployment/test_deployment_crud_actions.py @@ -155,7 +155,6 @@ def endpoint_info(self, endpoint_id: uuid.UUID) -> DeploymentInfo: network=DeploymentNetworkData( open_to_public=False, access_token_ids=None, url=None, preferred_domain_name=None ), - model_revisions=[], options=DeploymentOptions(), ) @@ -315,6 +314,7 @@ def route_info(self, endpoint_id: uuid.UUID) -> RouteInfo: created_at=datetime(2024, 1, 1, tzinfo=UTC), revision_id=uuid.uuid4(), traffic_status=RouteTrafficStatus.ACTIVE, + health_check=None, ) async def test_existing_replica_returns_data( @@ -365,6 +365,7 @@ async def test_zero_weight_traffic_inactive( created_at=datetime(2024, 1, 1, tzinfo=UTC), revision_id=uuid.uuid4(), traffic_status=RouteTrafficStatus.INACTIVE, + health_check=None, ) mock_deployment_repository.get_route = AsyncMock(return_value=inactive_route) @@ -392,6 +393,7 @@ async def test_unassigned_session_id_is_none( created_at=datetime(2024, 1, 1, tzinfo=UTC), revision_id=uuid.uuid4(), traffic_status=RouteTrafficStatus.ACTIVE, + health_check=None, ) mock_deployment_repository.get_route = AsyncMock(return_value=route) @@ -417,6 +419,7 @@ def route_info(self, endpoint_id: uuid.UUID) -> RouteInfo: created_at=datetime(2024, 1, 1, tzinfo=UTC), revision_id=uuid.uuid4(), traffic_status=RouteTrafficStatus.ACTIVE, + health_check=None, ) async def test_default_pagination( diff --git a/tests/unit/manager/services/deployment/test_deployment_service.py 
b/tests/unit/manager/services/deployment/test_deployment_service.py index 6c32576718f..0b31a0cf890 100644 --- a/tests/unit/manager/services/deployment/test_deployment_service.py +++ b/tests/unit/manager/services/deployment/test_deployment_service.py @@ -384,7 +384,6 @@ def deployment_info(self, deployment_id: uuid.UUID) -> DeploymentInfo: network=DeploymentNetworkData( open_to_public=False, access_token_ids=None, url=None, preferred_domain_name=None ), - model_revisions=[], options=DeploymentOptions(), ) @@ -560,7 +559,6 @@ def deployment_info( network=DeploymentNetworkData( open_to_public=False, access_token_ids=None, url=None, preferred_domain_name=None ), - model_revisions=[], options=DeploymentOptions(), ) @@ -726,10 +724,9 @@ def test_current_revision_resolved_by_id_match_not_list_order( network=DeploymentNetworkData( open_to_public=False, access_token_ids=None, url=None, preferred_domain_name=None ), - model_revisions=[deploying_data, current_data], options=DeploymentOptions(), - current_revision_id=DeploymentRevisionID(current_data.id), - deploying_revision_id=DeploymentRevisionID(deploying_data.id), + current_revision=current_data, + deploying_revision=deploying_data, ) deployment_data = _convert_deployment_info_to_data(deployment_info) diff --git a/tests/unit/manager/sokovan/deployment/executor/conftest.py b/tests/unit/manager/sokovan/deployment/executor/conftest.py index 1e142fd480d..6dc6abb243e 100644 --- a/tests/unit/manager/sokovan/deployment/executor/conftest.py +++ b/tests/unit/manager/sokovan/deployment/executor/conftest.py @@ -165,9 +165,8 @@ def _create_deployment_info( url=None, preferred_domain_name=None, ), - model_revisions=[cast(ModelRevisionData, revision)] if has_revision else [], - current_revision_id=DeploymentRevisionID(rev_id) if has_revision else None, options=DeploymentOptions(), + current_revision=cast(ModelRevisionData, revision) if has_revision else None, ) @@ -188,6 +187,7 @@ def _create_route_data( 
created_at=datetime.now(tzutc()), revision_id=DeploymentRevisionID(uuid4()), traffic_status=RouteTrafficStatus.INACTIVE, + health_check=None, ) diff --git a/tests/unit/manager/sokovan/deployment/handlers/test_deploying_handler.py b/tests/unit/manager/sokovan/deployment/handlers/test_deploying_handler.py index 97afe20929a..c18cbfd0825 100644 --- a/tests/unit/manager/sokovan/deployment/handlers/test_deploying_handler.py +++ b/tests/unit/manager/sokovan/deployment/handlers/test_deploying_handler.py @@ -10,6 +10,7 @@ import dataclasses from datetime import datetime +from typing import cast from unittest.mock import AsyncMock, MagicMock from uuid import uuid4 @@ -26,6 +27,7 @@ DeploymentNetworkData, DeploymentOptions, DeploymentState, + ModelRevisionData, ReplicaData, ) from ai.backend.manager.data.resource.types import ScalingGroupProxyTarget @@ -133,9 +135,8 @@ def deployment_created_without_revision(self) -> DeploymentWithHistory: url=None, preferred_domain_name=None, ), - model_revisions=[revision], - current_revision_id=None, - deploying_revision_id=deploying_rev_id, + current_revision=None, + deploying_revision=cast(ModelRevisionData, revision), sub_step=DeploymentLifecycleSubStep.DEPLOYING_PROVISIONING, options=DeploymentOptions(), ), @@ -158,7 +159,7 @@ async def test_registers_endpoint_for_deployment_created_without_revision( dep, revision_id = call_args[0] info = deployment_created_without_revision.deployment_info assert dep is deployment_created_without_revision - assert revision_id == info.deploying_revision_id + assert revision_id == info.deploying_revision.id # type: ignore[union-attr] async def test_deployment_already_with_url_is_not_reregistered( self, @@ -186,7 +187,7 @@ async def test_deployment_without_deploying_revision_is_filtered( ) -> None: """Deployments with no deploying_revision_id must be filtered out.""" info = deployment_created_without_revision.deployment_info - info_no_rev = dataclasses.replace(info, deploying_revision_id=None) + 
info_no_rev = dataclasses.replace(info, deploying_revision=None) deployment = DeploymentWithHistory(deployment_info=info_no_rev, last_history=None) await handler.execute([deployment]) diff --git a/tests/unit/manager/sokovan/deployment/route/executor/conftest.py b/tests/unit/manager/sokovan/deployment/route/executor/conftest.py index f7dd03c678d..10644d000dc 100644 --- a/tests/unit/manager/sokovan/deployment/route/executor/conftest.py +++ b/tests/unit/manager/sokovan/deployment/route/executor/conftest.py @@ -72,7 +72,8 @@ def mock_valkey_schedule() -> AsyncMock: """Mock ValkeyScheduleClient.""" client = AsyncMock() client.get_route_health_records_batch = AsyncMock(return_value={}) - client.get_redis_time = AsyncMock(return_value=1000) + client.get_route_health_statuses_batch = AsyncMock(return_value={}) + client.get_route_probe_targets_batch = AsyncMock(return_value={}) return client @@ -173,9 +174,7 @@ def _create_deployment_info( url="http://test.endpoint", preferred_domain_name=None, ), - model_revisions=[], options=DeploymentOptions(), - current_revision_id=DeploymentRevisionID(uuid4()), ) @@ -203,6 +202,7 @@ def _create_route_data( created_at=datetime.now(tzutc()), revision_id=revision_id or DeploymentRevisionID(uuid4()), traffic_status=RouteTrafficStatus.INACTIVE, + health_check=None, ) diff --git a/tests/unit/manager/sokovan/deployment/route/executor/test_check_route_health_register.py b/tests/unit/manager/sokovan/deployment/route/executor/test_check_route_health_register.py index c05810bf3ae..c1c37265ffe 100644 --- a/tests/unit/manager/sokovan/deployment/route/executor/test_check_route_health_register.py +++ b/tests/unit/manager/sokovan/deployment/route/executor/test_check_route_health_register.py @@ -11,11 +11,14 @@ from __future__ import annotations from datetime import datetime -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch from uuid import uuid4 from dateutil.tz import tzutc +from 
ai.backend.common.clients.valkey_client.valkey_schedule import ( + ReplicaHealthStatus as ValkeyReplicaHealthStatus, +) from ai.backend.common.identifier.deployment import DeploymentID from ai.backend.common.identifier.deployment_revision import DeploymentRevisionID from ai.backend.common.identifier.replica import ReplicaID @@ -41,20 +44,27 @@ def _route(health_status: RouteHealthStatus) -> RouteData: traffic_ratio=1.0, revision_id=DeploymentRevisionID(uuid4()), traffic_status=RouteTrafficStatus.ACTIVE, + health_check=None, replica_host="10.0.0.1", replica_port=8000, created_at=datetime.now(tzutc()), ) -def _healthy_record(redis_now: int) -> MagicMock: - """RouteHealthRecord stub that signals 'just probed and healthy'.""" - record = MagicMock() - record.healthy = True - record.last_check = redis_now - 1 - record.initial_delay_until = redis_now - 10 - record.is_stale = MagicMock(return_value=False) - return record +def _healthy_status(route: RouteData) -> ValkeyReplicaHealthStatus: + return ValkeyReplicaHealthStatus( + replica_id=route.route_id, + healthy=True, + last_check=999, + ) + + +def _unhealthy_status(route: RouteData) -> ValkeyReplicaHealthStatus: + return ValkeyReplicaHealthStatus( + replica_id=route.route_id, + healthy=False, + last_check=999, + ) class TestCheckRouteHealthRegister: @@ -67,10 +77,8 @@ async def test_first_time_healthy_triggers_register( ) -> None: """RR-EXEC-HC-001: NOT_CHECKED → HEALTHY route is pushed to AppProxy.""" not_yet_healthy = _route(RouteHealthStatus.NOT_CHECKED) - redis_now = 1_000_000 - mock_valkey_schedule.get_redis_time = AsyncMock(return_value=redis_now) - mock_valkey_schedule.get_route_health_records_batch = AsyncMock( - return_value={str(not_yet_healthy.route_id): _healthy_record(redis_now)} + mock_valkey_schedule.get_route_health_statuses_batch = AsyncMock( + return_value={not_yet_healthy.route_id: _healthy_status(not_yet_healthy)} ) with patch.object( @@ -95,10 +103,8 @@ async def test_already_healthy_route_is_skipped( 
) -> None: """RR-EXEC-HC-002: already-HEALTHY routes do not trigger fresh register.""" already_healthy = _route(RouteHealthStatus.HEALTHY) - redis_now = 1_000_000 - mock_valkey_schedule.get_redis_time = AsyncMock(return_value=redis_now) - mock_valkey_schedule.get_route_health_records_batch = AsyncMock( - return_value={str(already_healthy.route_id): _healthy_record(redis_now)} + mock_valkey_schedule.get_route_health_statuses_batch = AsyncMock( + return_value={already_healthy.route_id: _healthy_status(already_healthy)} ) with patch.object(route_executor, "register_routes_now") as mock_register: @@ -115,10 +121,8 @@ async def test_unhealthy_to_healthy_triggers_register( ) -> None: """RR-EXEC-HC-003: UNHEALTHY → HEALTHY recovery is treated as fresh transition.""" recovered = _route(RouteHealthStatus.UNHEALTHY) - redis_now = 1_000_000 - mock_valkey_schedule.get_redis_time = AsyncMock(return_value=redis_now) - mock_valkey_schedule.get_route_health_records_batch = AsyncMock( - return_value={str(recovered.route_id): _healthy_record(redis_now)} + mock_valkey_schedule.get_route_health_statuses_batch = AsyncMock( + return_value={recovered.route_id: _healthy_status(recovered)} ) with patch.object( @@ -138,15 +142,8 @@ async def test_no_successes_skips_register( ) -> None: """RR-EXEC-HC-004: routes whose probe failed do not trigger register.""" unhealthy_route = _route(RouteHealthStatus.UNHEALTHY) - redis_now = 1_000_000 - record = MagicMock() - record.healthy = False - record.last_check = redis_now - 1 - record.initial_delay_until = redis_now - 10 - record.is_stale = MagicMock(return_value=False) - mock_valkey_schedule.get_redis_time = AsyncMock(return_value=redis_now) - mock_valkey_schedule.get_route_health_records_batch = AsyncMock( - return_value={str(unhealthy_route.route_id): record} + mock_valkey_schedule.get_route_health_statuses_batch = AsyncMock( + return_value={unhealthy_route.route_id: _unhealthy_status(unhealthy_route)} ) with patch.object(route_executor, 
"register_routes_now") as mock_register: @@ -169,10 +166,8 @@ async def test_register_failure_is_swallowed( cycle would block all later observations. """ not_yet_healthy = _route(RouteHealthStatus.NOT_CHECKED) - redis_now = 1_000_000 - mock_valkey_schedule.get_redis_time = AsyncMock(return_value=redis_now) - mock_valkey_schedule.get_route_health_records_batch = AsyncMock( - return_value={str(not_yet_healthy.route_id): _healthy_record(redis_now)} + mock_valkey_schedule.get_route_health_statuses_batch = AsyncMock( + return_value={not_yet_healthy.route_id: _healthy_status(not_yet_healthy)} ) with patch.object( diff --git a/tests/unit/manager/sokovan/deployment/route/executor/test_initial_delay.py b/tests/unit/manager/sokovan/deployment/route/executor/test_initial_delay.py index 27dee701f0e..3820122cbc4 100644 --- a/tests/unit/manager/sokovan/deployment/route/executor/test_initial_delay.py +++ b/tests/unit/manager/sokovan/deployment/route/executor/test_initial_delay.py @@ -1,23 +1,17 @@ -"""Tests for initial_delay calculation based on running_at. 
- -Test scenarios: -- ID-001: running_at is set → initial_delay_until based on running_at -- ID-002: running_at is None → fallback to redis_time -- ID-003: created_at has expired but running_at has not → still within initial_delay -- ID-004: running_at has also expired → initial_delay over -- ID-005: observer ignores failure within initial_delay (running_at based) -- ID-006: observer writes failure after initial_delay expires (running_at based) -""" +"""Tests for ReplicaProbeTarget registration and initial_delay behavior.""" from __future__ import annotations from datetime import datetime -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock from uuid import uuid4 from dateutil.tz import tzutc -from ai.backend.common.clients.valkey_client.valkey_schedule import RouteHealthRecord +from ai.backend.common.clients.valkey_client.valkey_schedule.types import ( + ReplicaHealthResult, + ReplicaProbeTarget, +) from ai.backend.common.config import ModelHealthCheck from ai.backend.common.identifier.deployment import DeploymentID from ai.backend.common.identifier.deployment_revision import DeploymentRevisionID @@ -33,11 +27,13 @@ from ai.backend.manager.sokovan.deployment.route.handlers.observer.health_check import ( RouteHealthObserver, ) +from ai.backend.manager.sokovan.deployment.route.recorder.context import RouteRecorderContext def _make_route( created_at_ts: int = 1000, session_id: SessionId | None = None, + health_check: ModelHealthCheck | None = None, ) -> RouteData: return RouteData( route_id=ReplicaID(uuid4()), @@ -49,203 +45,221 @@ def _make_route( created_at=datetime.fromtimestamp(created_at_ts, tz=tzutc()), revision_id=DeploymentRevisionID(uuid4()), traffic_status=RouteTrafficStatus.INACTIVE, + health_check=health_check, replica_host="10.0.0.1", replica_port=8000, ) # ============================================================================= -# _initialize_health_records: initial_delay_until calculation +# 
_register_route_probe_targets # ============================================================================= -class TestInitializeHealthRecordsInitialDelay: - """Tests for initial_delay_until calculation in _initialize_health_records.""" +class TestRegisterReplicaProbeTargets: + """Tests for _register_route_probe_targets.""" - async def test_running_at_present_uses_running_at( + async def test_registers_probe_target_with_health_config( self, route_executor: RouteExecutor, - mock_deployment_repo: AsyncMock, mock_valkey_schedule: AsyncMock, ) -> None: - """ID-001: When running_at exists in Valkey, initial_delay_until is based on running_at. + """health_path comes from route.health_check when present.""" + route = _make_route(health_check=ModelHealthCheck(path="/healthz", initial_delay=60.0)) + replica_id = ReplicaID(route.route_id) - Given: Route with running_at=5000 in Valkey, initial_delay=720 - When: _initialize_health_records - Then: initial_delay_until = 5000 + 720 = 5720 - """ - route = _make_route(created_at_ts=1000) - route_id_str = str(route.route_id) + await route_executor._register_route_probe_targets( + [route], + {replica_id: RouteSessionKernelInfo(replica_host="10.0.0.2", replica_port=9000)}, + ) - mock_valkey_schedule.get_route_running_at_batch.return_value = { - route_id_str: 5000, - } - mock_valkey_schedule.get_redis_time.return_value = 5100 - mock_deployment_repo.fetch_health_check_configs_by_revision_ids.return_value = { - route.revision_id: ModelHealthCheck(path="/health", initial_delay=720.0), - } + call_args = mock_valkey_schedule.register_route_probe_targets_batch.call_args + targets: list[ReplicaProbeTarget] = call_args[0][0] + assert len(targets) == 1 + assert targets[0].replica_id == replica_id + assert targets[0].health_path == "/healthz" + assert targets[0].inference_port == 9000 + assert targets[0].replica_host == "10.0.0.2" + + async def test_skips_route_without_health_check( + self, + route_executor: RouteExecutor, + 
mock_valkey_schedule: AsyncMock, + ) -> None: + """Route with health_check=None is not registered.""" + route = _make_route(health_check=None) + replica_id = ReplicaID(route.route_id) - await route_executor._initialize_health_records( + await route_executor._register_route_probe_targets( [route], - {route.route_id: RouteSessionKernelInfo(replica_host="10.0.0.1", replica_port=8000)}, + {replica_id: RouteSessionKernelInfo(replica_host="10.0.0.1", replica_port=8000)}, ) - call_args = mock_valkey_schedule.initialize_route_health_records_batch.call_args - records: list[RouteHealthRecord] = call_args[0][0] - assert len(records) == 1 - assert records[0].running_at == 5000 - assert records[0].initial_delay_until == 5000 + 720 + mock_valkey_schedule.register_route_probe_targets_batch.assert_not_awaited() - async def test_running_at_none_falls_back_to_redis_time( + async def test_registers_multiple_routes( self, route_executor: RouteExecutor, - mock_deployment_repo: AsyncMock, mock_valkey_schedule: AsyncMock, ) -> None: - """ID-002: When running_at is None, fallback to current redis_time. 
+ """Multiple routes with health_check produce multiple ReplicaProbeTarget entries.""" + health_check = ModelHealthCheck(path="/health", initial_delay=60.0) + routes = [_make_route(health_check=health_check) for _ in range(3)] + replica_infos = { + ReplicaID(r.route_id): RouteSessionKernelInfo( + replica_host="10.0.0.1", replica_port=8000 + i + ) + for i, r in enumerate(routes) + } - Given: No existing record in Valkey (running_at not set), redis_time=6000 - When: _initialize_health_records - Then: initial_delay_until = 6000 + 720 = 6720, running_at = 6000 - """ - route = _make_route(created_at_ts=1000) + await route_executor._register_route_probe_targets(routes, replica_infos) - mock_valkey_schedule.get_route_running_at_batch.return_value = {} - mock_valkey_schedule.get_redis_time.return_value = 6000 - mock_deployment_repo.fetch_health_check_configs_by_revision_ids.return_value = { - route.revision_id: ModelHealthCheck(path="/health", initial_delay=720.0), - } + call_args = mock_valkey_schedule.register_route_probe_targets_batch.call_args + targets: list[ReplicaProbeTarget] = call_args[0][0] + assert len(targets) == 3 - await route_executor._initialize_health_records( - [route], - {route.route_id: RouteSessionKernelInfo(replica_host="10.0.0.1", replica_port=8000)}, - ) - call_args = mock_valkey_schedule.initialize_route_health_records_batch.call_args - records: list[RouteHealthRecord] = call_args[0][0] - assert len(records) == 1 - assert records[0].running_at == 6000 - assert records[0].initial_delay_until == 6000 + 720 +class TestSyncReplicaProbeTargets: + """Tests for sync_route_probe_targets.""" - async def test_created_at_expired_but_running_at_not_expired( + async def test_syncs_routes_with_replica_info( self, route_executor: RouteExecutor, - mock_deployment_repo: AsyncMock, mock_valkey_schedule: AsyncMock, ) -> None: - """ID-003: created_at-based delay would have expired, but running_at-based has not. 
+ """Routes with health_check, replica_host and replica_port are synced.""" + route = _make_route(health_check=ModelHealthCheck(path="/health", initial_delay=60.0)) - Given: created_at=1000, running_at=5000, initial_delay=720, current_time=1800 - created_at + 720 = 1720 < 1800 (expired if created_at based) - running_at + 720 = 5720 > 1800 (NOT expired with running_at based) - When: _initialize_health_records - Then: initial_delay_until = 5720 (based on running_at, not created_at) - """ - route = _make_route(created_at_ts=1000) - route_id_str = str(route.route_id) + with RouteRecorderContext.scope("test", entity_ids=[route.route_id]): + result = await route_executor.sync_route_probe_targets([route]) - mock_valkey_schedule.get_route_running_at_batch.return_value = { - route_id_str: 5000, - } - mock_valkey_schedule.get_redis_time.return_value = 1800 - mock_deployment_repo.fetch_health_check_configs_by_revision_ids.return_value = { - route.revision_id: ModelHealthCheck(path="/health", initial_delay=720.0), - } + mock_valkey_schedule.register_route_probe_targets_batch.assert_awaited_once() + assert result.successes == [] + assert result.errors == [] - await route_executor._initialize_health_records( - [route], - {route.route_id: RouteSessionKernelInfo(replica_host="10.0.0.1", replica_port=8000)}, + async def test_skips_routes_without_health_check_or_replica_info( + self, + route_executor: RouteExecutor, + mock_valkey_schedule: AsyncMock, + ) -> None: + """Routes missing health_check or replica info are silently skipped.""" + route = _make_route(health_check=ModelHealthCheck(path="/health", initial_delay=60.0)) + route_no_health_check = _make_route(health_check=None) + route_no_replica = RouteData( + route_id=ReplicaID(uuid4()), + deployment_id=DeploymentID(uuid4()), + session_id=SessionId(uuid4()), + status=RouteStatus.RUNNING, + health_status=RouteHealthStatus.NOT_CHECKED, + traffic_ratio=1.0, + created_at=datetime.fromtimestamp(1000, tz=tzutc()), + 
revision_id=DeploymentRevisionID(uuid4()), + traffic_status=RouteTrafficStatus.INACTIVE, + health_check=ModelHealthCheck(path="/health", initial_delay=60.0), + replica_host=None, + replica_port=None, + ) + + with RouteRecorderContext.scope( + "test", + entity_ids=[route.route_id, route_no_health_check.route_id, route_no_replica.route_id], + ): + await route_executor.sync_route_probe_targets([ + route, + route_no_health_check, + route_no_replica, + ]) + + call_args = mock_valkey_schedule.register_route_probe_targets_batch.call_args + targets: list[ReplicaProbeTarget] = call_args[0][0] + assert len(targets) == 1 + assert targets[0].replica_id == ReplicaID(route.route_id) + + async def test_all_routes_missing_replica_info_returns_empty( + self, + route_executor: RouteExecutor, + mock_valkey_schedule: AsyncMock, + ) -> None: + """No routes with replica info → no Valkey call, empty result.""" + route_no_info = RouteData( + route_id=ReplicaID(uuid4()), + deployment_id=DeploymentID(uuid4()), + session_id=SessionId(uuid4()), + status=RouteStatus.RUNNING, + health_status=RouteHealthStatus.NOT_CHECKED, + traffic_ratio=1.0, + created_at=datetime.fromtimestamp(1000, tz=tzutc()), + revision_id=DeploymentRevisionID(uuid4()), + traffic_status=RouteTrafficStatus.INACTIVE, + health_check=None, + replica_host=None, + replica_port=None, ) - call_args = mock_valkey_schedule.initialize_route_health_records_batch.call_args - records: list[RouteHealthRecord] = call_args[0][0] - assert records[0].initial_delay_until == 5720 - assert records[0].created_at == 1000 + result = await route_executor.sync_route_probe_targets([route_no_info]) + + mock_valkey_schedule.register_route_probe_targets_batch.assert_not_awaited() + assert result.successes == [] + assert result.errors == [] # ============================================================================= -# RouteHealthObserver: within_initial_delay based on running_at +# RouteHealthObserver: probe target based observation # 
============================================================================= -class TestObserverInitialDelay: - """Tests for observer's initial_delay behavior with running_at-based records.""" +class TestObserverSetsHealthStatus: + """Tests for RouteHealthObserver writing RouteHealthStatus to Valkey.""" - async def test_observer_ignores_failure_within_initial_delay(self) -> None: - """ID-005: Observer does not write failure during initial_delay period. + async def test_observer_writes_success_result(self) -> None: + """ID-005: Observer writes healthy=True on successful probe. - Given: running_at=5000, initial_delay=720 → initial_delay_until=5720 - current redis_time=5500 (within initial_delay) - health check fails + Given: ReplicaProbeTarget exists in Valkey, HTTP check succeeds When: Observer runs - Then: update_route_manager_health NOT called for failure + Then: record_route_health_status called with (replica_id, True) """ mock_deployment_repo = AsyncMock() mock_valkey = AsyncMock() - route = _make_route(created_at_ts=1000) - route_id_str = str(route.route_id) - route_data = MagicMock() - route_data.route_id = route.route_id - route_data.replica_host = "10.0.0.1" - route_data.replica_port = 8000 - - record = RouteHealthRecord( - route_id=route_id_str, - created_at=1000, - initial_delay_until=5720, + route = _make_route() + probe_target = ReplicaProbeTarget( + replica_id=route.route_id, health_path="/health", inference_port=8000, replica_host="10.0.0.1", - running_at=5000, ) - mock_valkey.get_route_health_records_batch.return_value = {route_id_str: record} - mock_valkey.get_redis_time.return_value = 5500 # Within initial_delay + mock_valkey.get_route_probe_targets_batch.return_value = {route.route_id: probe_target} observer = RouteHealthObserver( deployment_repository=mock_deployment_repo, valkey_schedule=mock_valkey, ) - # Patch HTTP check to always fail - observer._http_health_check = AsyncMock(return_value=False) # type: ignore[method-assign] + 
observer._http_health_check = AsyncMock(return_value=True) # type: ignore[method-assign] - await observer.observe([route_data]) + await observer.observe([route]) - # refresh_route_health_ttl should be called (always) - mock_valkey.refresh_route_health_ttl.assert_awaited_once_with(route_id_str) - # update_route_manager_health should NOT be called (failure ignored during initial_delay) - mock_valkey.update_route_manager_health.assert_not_awaited() + mock_valkey.record_route_health_statuses_batch.assert_awaited_once_with([ + ReplicaHealthResult(replica_id=route.route_id, healthy=True) + ]) - async def test_observer_writes_failure_after_initial_delay(self) -> None: - """ID-006: Observer writes failure after initial_delay expires. + async def test_observer_writes_failure_result(self) -> None: + """ID-006: Observer writes healthy=False on failed probe. - Given: running_at=5000, initial_delay=720 → initial_delay_until=5720 - current redis_time=5800 (past initial_delay) - health check fails + Given: ReplicaProbeTarget exists in Valkey, HTTP check fails When: Observer runs - Then: update_route_manager_health called with False + Then: record_route_health_status called with (replica_id, False) """ mock_deployment_repo = AsyncMock() mock_valkey = AsyncMock() - route = _make_route(created_at_ts=1000) - route_id_str = str(route.route_id) - route_data = MagicMock() - route_data.route_id = route.route_id - route_data.replica_host = "10.0.0.1" - route_data.replica_port = 8000 - - record = RouteHealthRecord( - route_id=route_id_str, - created_at=1000, - initial_delay_until=5720, + route = _make_route() + probe_target = ReplicaProbeTarget( + replica_id=route.route_id, health_path="/health", inference_port=8000, replica_host="10.0.0.1", - running_at=5000, ) - mock_valkey.get_route_health_records_batch.return_value = {route_id_str: record} - mock_valkey.get_redis_time.return_value = 5800 # Past initial_delay + mock_valkey.get_route_probe_targets_batch.return_value = {route.route_id: 
probe_target} observer = RouteHealthObserver( deployment_repository=mock_deployment_repo, @@ -253,38 +267,24 @@ async def test_observer_writes_failure_after_initial_delay(self) -> None: ) observer._http_health_check = AsyncMock(return_value=False) # type: ignore[method-assign] - await observer.observe([route_data]) + await observer.observe([route]) - mock_valkey.update_route_manager_health.assert_awaited_once_with(route_id_str, False) + mock_valkey.record_route_health_statuses_batch.assert_awaited_once_with([ + ReplicaHealthResult(replica_id=route.route_id, healthy=False) + ]) - async def test_observer_writes_success_within_initial_delay(self) -> None: - """ID-007: Observer writes success even during initial_delay. + async def test_observer_skips_route_without_probe_target(self) -> None: + """ID-007: Observer skips route when no probe target in Valkey. - Given: Within initial_delay, health check succeeds + Given: No ReplicaProbeTarget in Valkey When: Observer runs - Then: update_route_manager_health called with True + Then: record_route_health_status NOT called, observed_count=0 """ mock_deployment_repo = AsyncMock() mock_valkey = AsyncMock() - route = _make_route(created_at_ts=1000) - route_id_str = str(route.route_id) - route_data = MagicMock() - route_data.route_id = route.route_id - route_data.replica_host = "10.0.0.1" - route_data.replica_port = 8000 - - record = RouteHealthRecord( - route_id=route_id_str, - created_at=1000, - initial_delay_until=5720, - health_path="/health", - inference_port=8000, - replica_host="10.0.0.1", - running_at=5000, - ) - mock_valkey.get_route_health_records_batch.return_value = {route_id_str: record} - mock_valkey.get_redis_time.return_value = 5500 # Within initial_delay + route = _make_route() + mock_valkey.get_route_probe_targets_batch.return_value = {route.route_id: None} observer = RouteHealthObserver( deployment_repository=mock_deployment_repo, @@ -292,113 +292,7 @@ async def 
test_observer_writes_success_within_initial_delay(self) -> None: ) observer._http_health_check = AsyncMock(return_value=True) # type: ignore[method-assign] - await observer.observe([route_data]) - - mock_valkey.update_route_manager_health.assert_awaited_once_with(route_id_str, True) - + result = await observer.observe([route]) -# ============================================================================= -# RouteHealthRecord serialization: running_at -# ============================================================================= - - -class TestRouteHealthRecordRunningAt: - """Tests for RouteHealthRecord running_at serialization.""" - - def test_running_at_none_not_in_hash(self) -> None: - """running_at=None should not appear in serialized hash.""" - record = RouteHealthRecord( - route_id="r1", - created_at=1000, - initial_delay_until=1720, - health_path="/health", - inference_port=8000, - replica_host="10.0.0.1", - running_at=None, - ) - h = record.to_valkey_hash() - assert "running_at" not in h - - def test_running_at_present_in_hash(self) -> None: - """running_at with value should appear in serialized hash.""" - record = RouteHealthRecord( - route_id="r1", - created_at=1000, - initial_delay_until=5720, - health_path="/health", - inference_port=8000, - replica_host="10.0.0.1", - running_at=5000, - ) - h = record.to_valkey_hash() - assert h["running_at"] == "5000" - - def test_from_hash_missing_running_at_is_none(self) -> None: - """Deserializing hash without running_at field yields None.""" - data = { - "route_id": "r1", - "created_at": "1000", - "initial_delay_until": "1720", - "health_path": "/health", - "inference_port": "8000", - "replica_host": "10.0.0.1", - } - record = RouteHealthRecord.from_valkey_hash(data) - assert record.running_at is None - - def test_from_hash_zero_running_at_is_none(self) -> None: - """Deserializing hash with running_at=0 yields None (backward compat).""" - data = { - "route_id": "r1", - "created_at": "1000", - 
"initial_delay_until": "1720", - "health_path": "/health", - "inference_port": "8000", - "replica_host": "10.0.0.1", - "running_at": "0", - } - record = RouteHealthRecord.from_valkey_hash(data) - assert record.running_at is None - - def test_from_hash_valid_running_at(self) -> None: - """Deserializing hash with valid running_at yields int.""" - data = { - "route_id": "r1", - "created_at": "1000", - "initial_delay_until": "5720", - "health_path": "/health", - "inference_port": "8000", - "replica_host": "10.0.0.1", - "running_at": "5000", - } - record = RouteHealthRecord.from_valkey_hash(data) - assert record.running_at == 5000 - - def test_roundtrip_with_running_at(self) -> None: - """Serialize → deserialize preserves running_at.""" - original = RouteHealthRecord( - route_id="r1", - created_at=1000, - initial_delay_until=5720, - health_path="/health", - inference_port=8000, - replica_host="10.0.0.1", - running_at=5000, - ) - restored = RouteHealthRecord.from_valkey_hash(original.to_valkey_hash()) - assert restored.running_at == 5000 - assert restored.initial_delay_until == 5720 - - def test_roundtrip_without_running_at(self) -> None: - """Serialize → deserialize preserves running_at=None.""" - original = RouteHealthRecord( - route_id="r1", - created_at=1000, - initial_delay_until=1720, - health_path="/health", - inference_port=8000, - replica_host="10.0.0.1", - running_at=None, - ) - restored = RouteHealthRecord.from_valkey_hash(original.to_valkey_hash()) - assert restored.running_at is None + mock_valkey.record_route_health_statuses_batch.assert_not_awaited() + assert result.observed_count == 0 diff --git a/tests/unit/manager/sokovan/deployment/route/executor/test_register_unregister_routes.py b/tests/unit/manager/sokovan/deployment/route/executor/test_register_unregister_routes.py index 123cb2c432e..a4df7545614 100644 --- a/tests/unit/manager/sokovan/deployment/route/executor/test_register_unregister_routes.py +++ 
b/tests/unit/manager/sokovan/deployment/route/executor/test_register_unregister_routes.py @@ -62,6 +62,7 @@ def _route_with_replica( traffic_ratio=1.0, revision_id=DeploymentRevisionID(uuid4()), traffic_status=RouteTrafficStatus.ACTIVE, + health_check=None, replica_host=replica_host, replica_port=replica_port, created_at=datetime.now(tzutc()), diff --git a/tests/unit/manager/sokovan/deployment/route/executor/test_route_executor.py b/tests/unit/manager/sokovan/deployment/route/executor/test_route_executor.py index 6b0f31bc3e1..b47c56a181b 100644 --- a/tests/unit/manager/sokovan/deployment/route/executor/test_route_executor.py +++ b/tests/unit/manager/sokovan/deployment/route/executor/test_route_executor.py @@ -22,7 +22,9 @@ import pytest from dateutil.tz import tzutc -from ai.backend.common.clients.valkey_client.valkey_schedule import RouteHealthRecord +from ai.backend.common.clients.valkey_client.valkey_schedule import ( + ReplicaHealthStatus as ValkeyReplicaHealthStatus, +) from ai.backend.common.dto.appproxy_coordinator.v2.endpoint.response import ( BulkUpdateRoutesResponse, ) @@ -196,7 +198,7 @@ class TestCheckRouteHealth: """Tests for check_route_health functionality. Verifies the executor correctly checks route health via Valkey - using RouteHealthRecord-based classification. + using RouteHealthStatus TTL-based classification. """ async def test_healthy_route_in_successes( @@ -207,34 +209,23 @@ async def test_healthy_route_in_successes( ) -> None: """RH-001: Healthy route is in successes. 
- Given: Route with healthy RouteHealthRecord in Valkey + Given: Route with RouteHealthStatus(healthy=True) in Valkey When: Check route health Then: Route in successes list """ - # Arrange - route_id_str = str(healthy_route.route_id) - current_time = 1000 - record = RouteHealthRecord( - route_id=route_id_str, - created_at=900, - initial_delay_until=930, - health_path="/health", - inference_port=8000, - replica_host="10.0.0.1", - agent_healthy=True, - agent_last_check=current_time - 5, + status = ValkeyReplicaHealthStatus( + replica_id=healthy_route.route_id, + healthy=True, + last_check=995, ) - mock_valkey_schedule.get_route_health_records_batch.return_value = { - route_id_str: record, + mock_valkey_schedule.get_route_health_statuses_batch.return_value = { + healthy_route.route_id: status, } - mock_valkey_schedule.get_redis_time.return_value = current_time entity_ids = [healthy_route.route_id] with RouteRecorderContext.scope("test", entity_ids=entity_ids): - # Act result = await route_executor.check_route_health([healthy_route]) - # Assert assert len(result.successes) == 1 assert len(result.errors) == 0 assert len(result.stale) == 0 @@ -247,34 +238,23 @@ async def test_unhealthy_route_in_errors( ) -> None: """RH-002: Unhealthy route is in errors. 
- Given: Route with unhealthy RouteHealthRecord in Valkey + Given: Route with RouteHealthStatus(healthy=False) in Valkey When: Check route health Then: Route in errors list """ - # Arrange - route_id_str = str(healthy_route.route_id) - current_time = 1000 - record = RouteHealthRecord( - route_id=route_id_str, - created_at=900, - initial_delay_until=930, - health_path="/health", - inference_port=8000, - replica_host="10.0.0.1", - agent_healthy=False, - agent_last_check=current_time - 5, + status = ValkeyReplicaHealthStatus( + replica_id=healthy_route.route_id, + healthy=False, + last_check=995, ) - mock_valkey_schedule.get_route_health_records_batch.return_value = { - route_id_str: record, + mock_valkey_schedule.get_route_health_statuses_batch.return_value = { + healthy_route.route_id: status, } - mock_valkey_schedule.get_redis_time.return_value = current_time entity_ids = [healthy_route.route_id] with RouteRecorderContext.scope("test", entity_ids=entity_ids): - # Act result = await route_executor.check_route_health([healthy_route]) - # Assert assert len(result.successes) == 0 assert len(result.errors) == 1 assert len(result.stale) == 0 @@ -285,36 +265,18 @@ async def test_stale_route_in_stale_list( mock_valkey_schedule: AsyncMock, healthy_route: RouteData, ) -> None: - """RH-003: Stale route is in stale list. + """RH-003: Route with expired TTL (no key) is in stale list. 
- Given: Route with stale RouteHealthRecord in Valkey (old last_check) + Given: Route whose RouteHealthStatus TTL has expired (key absent) When: Check route health - Then: Route in stale list + Then: Route in stale list (DEGRADED) """ - # Arrange - route_id_str = str(healthy_route.route_id) - current_time = 1000 - record = RouteHealthRecord( - route_id=route_id_str, - created_at=100, - initial_delay_until=130, - health_path="/health", - inference_port=8000, - replica_host="10.0.0.1", - agent_healthy=True, - agent_last_check=100, # Very old check, stale - ) - mock_valkey_schedule.get_route_health_records_batch.return_value = { - route_id_str: record, - } - mock_valkey_schedule.get_redis_time.return_value = current_time + mock_valkey_schedule.get_route_health_statuses_batch.return_value = {} entity_ids = [healthy_route.route_id] with RouteRecorderContext.scope("test", entity_ids=entity_ids): - # Act result = await route_executor.check_route_health([healthy_route]) - # Assert assert len(result.successes) == 0 assert len(result.errors) == 0 assert len(result.stale) == 1 @@ -327,20 +289,16 @@ async def test_missing_health_data_treated_as_stale( ) -> None: """RH-004: Missing health data is treated as stale. 
- Given: Route with no RouteHealthRecord in Valkey + Given: Route with no RouteHealthStatus in Valkey When: Check route health Then: Route in stale list """ - # Arrange - Empty records response - mock_valkey_schedule.get_route_health_records_batch.return_value = {} - mock_valkey_schedule.get_redis_time.return_value = 1000 + mock_valkey_schedule.get_route_health_statuses_batch.return_value = {} entity_ids = [healthy_route.route_id] with RouteRecorderContext.scope("test", entity_ids=entity_ids): - # Act result = await route_executor.check_route_health([healthy_route]) - # Assert assert len(result.successes) == 0 assert len(result.errors) == 0 assert len(result.stale) == 1 @@ -491,12 +449,15 @@ async def test_unhealthy_route_marked_for_cleanup( Then: Route in successes (marked for cleanup) """ # Arrange + current_revision_mock = MagicMock() + current_revision_mock.id = unhealthy_route.revision_id + deployment = MagicMock() deployment.id = unhealthy_route.deployment_id deployment.metadata = MagicMock() deployment.metadata.resource_group = "default" - deployment.current_revision_id = unhealthy_route.revision_id - deployment.deploying_revision_id = None + deployment.current_revision = current_revision_mock + deployment.deploying_revision = None mock_deployment_repo.get_deployments_by_ids.return_value = [deployment] mock_deployment_repo.get_scaling_group_cleanup_configs.return_value = { "default": cleanup_config_unhealthy_only @@ -524,12 +485,15 @@ async def test_healthy_route_not_marked_for_cleanup( Then: Route not in successes """ # Arrange + current_revision_mock = MagicMock() + current_revision_mock.id = healthy_route.revision_id + deployment = MagicMock() deployment.id = healthy_route.deployment_id deployment.metadata = MagicMock() deployment.metadata.resource_group = "default" - deployment.current_revision_id = healthy_route.revision_id - deployment.deploying_revision_id = None + deployment.current_revision = current_revision_mock + deployment.deploying_revision = 
None mock_deployment_repo.get_deployments_by_ids.return_value = [deployment] mock_deployment_repo.get_scaling_group_cleanup_configs.return_value = { "default": cleanup_config_unhealthy_only @@ -588,14 +552,20 @@ async def test_orphan_revision_route_marked_for_cleanup( created_at=datetime.now(tzutc()), revision_id=DeploymentRevisionID(uuid4()), # neither current nor deploying traffic_status=RouteTrafficStatus.ACTIVE, + health_check=None, ) + current_revision_mock = MagicMock() + current_revision_mock.id = current_revision_id + deploying_revision_mock = MagicMock() + deploying_revision_mock.id = deploying_revision_id + deployment = MagicMock() deployment.id = deployment_id deployment.metadata = MagicMock() deployment.metadata.resource_group = "default" - deployment.current_revision_id = current_revision_id - deployment.deploying_revision_id = deploying_revision_id + deployment.current_revision = current_revision_mock + deployment.deploying_revision = deploying_revision_mock mock_deployment_repo.get_deployments_by_ids.return_value = [deployment] mock_deployment_repo.get_scaling_group_cleanup_configs.return_value = { "default": cleanup_config_unhealthy_only @@ -632,14 +602,18 @@ async def test_provisioning_route_for_deploying_revision_kept( created_at=datetime.now(tzutc()), revision_id=deploying_revision_id, traffic_status=RouteTrafficStatus.INACTIVE, + health_check=None, ) + deploying_revision_mock = MagicMock() + deploying_revision_mock.id = deploying_revision_id + deployment = MagicMock() deployment.id = deployment_id deployment.metadata = MagicMock() deployment.metadata.resource_group = "default" - deployment.current_revision_id = None - deployment.deploying_revision_id = deploying_revision_id + deployment.current_revision = None + deployment.deploying_revision = deploying_revision_mock mock_deployment_repo.get_deployments_by_ids.return_value = [deployment] mock_deployment_repo.get_scaling_group_cleanup_configs.return_value = { "default": 
cleanup_config_unhealthy_only @@ -676,14 +650,15 @@ async def test_orphan_check_skipped_when_no_known_revisions( created_at=datetime.now(tzutc()), revision_id=DeploymentRevisionID(uuid4()), traffic_status=RouteTrafficStatus.ACTIVE, + health_check=None, ) deployment = MagicMock() deployment.id = deployment_id deployment.metadata = MagicMock() deployment.metadata.resource_group = "default" - deployment.current_revision_id = None - deployment.deploying_revision_id = None + deployment.current_revision = None + deployment.deploying_revision = None mock_deployment_repo.get_deployments_by_ids.return_value = [deployment] mock_deployment_repo.get_scaling_group_cleanup_configs.return_value = { "default": cleanup_config_unhealthy_only @@ -930,6 +905,7 @@ def _route_for_endpoint(endpoint_id: DeploymentID) -> RouteData: revision_id=DeploymentRevisionID(uuid4()), created_at=datetime.now(tzutc()), traffic_status=RouteTrafficStatus.ACTIVE, + health_check=None, ) diff --git a/tests/unit/manager/sokovan/deployment/route/executor/test_terminate_routes_drain.py b/tests/unit/manager/sokovan/deployment/route/executor/test_terminate_routes_drain.py index 100ca19f379..fa9ab8638f9 100644 --- a/tests/unit/manager/sokovan/deployment/route/executor/test_terminate_routes_drain.py +++ b/tests/unit/manager/sokovan/deployment/route/executor/test_terminate_routes_drain.py @@ -44,6 +44,7 @@ def _terminating_route() -> RouteData: traffic_ratio=1.0, revision_id=DeploymentRevisionID(uuid4()), traffic_status=RouteTrafficStatus.INACTIVE, + health_check=None, replica_host="10.0.0.1", replica_port=8000, created_at=datetime.now(tzutc()), @@ -162,6 +163,7 @@ async def test_routes_without_session_still_unregister( traffic_ratio=1.0, revision_id=DeploymentRevisionID(uuid4()), traffic_status=RouteTrafficStatus.INACTIVE, + health_check=None, replica_host=None, replica_port=None, created_at=datetime.now(tzutc()), diff --git a/tests/unit/manager/sokovan/deployment/route/handlers/test_health_check_handler.py 
b/tests/unit/manager/sokovan/deployment/route/handlers/test_health_check_handler.py index 0602c89e260..2f85556fce7 100644 --- a/tests/unit/manager/sokovan/deployment/route/handlers/test_health_check_handler.py +++ b/tests/unit/manager/sokovan/deployment/route/handlers/test_health_check_handler.py @@ -41,6 +41,7 @@ def _route(health_status: RouteHealthStatus) -> RouteData: traffic_ratio=1.0, revision_id=DeploymentRevisionID(uuid4()), traffic_status=RouteTrafficStatus.ACTIVE, + health_check=None, replica_host="10.0.0.1", replica_port=8000, created_at=datetime.now(tzutc()), diff --git a/tests/unit/manager/sokovan/deployment/route/handlers/test_terminating_handler.py b/tests/unit/manager/sokovan/deployment/route/handlers/test_terminating_handler.py index 9ac3e535b82..2a9255a232b 100644 --- a/tests/unit/manager/sokovan/deployment/route/handlers/test_terminating_handler.py +++ b/tests/unit/manager/sokovan/deployment/route/handlers/test_terminating_handler.py @@ -40,6 +40,7 @@ def _terminating_route() -> RouteData: traffic_ratio=1.0, revision_id=DeploymentRevisionID(uuid4()), traffic_status=RouteTrafficStatus.INACTIVE, + health_check=None, replica_host="10.0.0.1", replica_port=8000, created_at=datetime.now(tzutc()), diff --git a/tests/unit/manager/sokovan/deployment/route/test_coordinator_history.py b/tests/unit/manager/sokovan/deployment/route/test_coordinator_history.py index 59a05cceae1..aad4bff89a9 100644 --- a/tests/unit/manager/sokovan/deployment/route/test_coordinator_history.py +++ b/tests/unit/manager/sokovan/deployment/route/test_coordinator_history.py @@ -54,6 +54,7 @@ def sample_route_data() -> RouteData: created_at=datetime.now(tzutc()), revision_id=DeploymentRevisionID(uuid4()), traffic_status=RouteTrafficStatus.INACTIVE, + health_check=None, ) @@ -399,6 +400,7 @@ async def test_records_history_on_stale( created_at=datetime.now(tzutc()), revision_id=DeploymentRevisionID(uuid4()), traffic_status=RouteTrafficStatus.ACTIVE, + health_check=None, ) ] ) diff --git 
a/tests/unit/manager/sokovan/deployment/strategy/test_rolling_update.py b/tests/unit/manager/sokovan/deployment/strategy/test_rolling_update.py index e25cc3ac94a..98f4c7f8f94 100644 --- a/tests/unit/manager/sokovan/deployment/strategy/test_rolling_update.py +++ b/tests/unit/manager/sokovan/deployment/strategy/test_rolling_update.py @@ -13,13 +13,17 @@ from __future__ import annotations +import dataclasses from dataclasses import dataclass from datetime import UTC, datetime +from typing import cast +from unittest.mock import MagicMock from uuid import UUID, uuid4 import pytest from pydantic import ValidationError +from ai.backend.common.config import ModelHealthCheck from ai.backend.common.data.endpoint.types import EndpointLifecycle, ScalingState from ai.backend.common.dto.manager.v2.deployment.types import IntOrPercent from ai.backend.common.exception import BackendAISchemaValidationFailed @@ -33,6 +37,7 @@ DeploymentNetworkData, DeploymentOptions, DeploymentState, + ModelRevisionData, ReplicaData, RouteHealthStatus, RouteInfo, @@ -57,6 +62,7 @@ def make_int_or_percent(value: int | float) -> IntOrPercent: OLD_REV = UUID("11111111-1111-1111-1111-111111111111") +_STUB_HEALTH_CHECK = ModelHealthCheck(path="/health", interval=10.0, initial_delay=30.0) NEW_REV = UUID("22222222-2222-2222-2222-222222222222") PROJECT_ID = UUID("cccccccc-cccc-cccc-cccc-cccccccccccc") USER_ID = UUID("dddddddd-dddd-dddd-dddd-dddddddddddd") @@ -101,6 +107,10 @@ def make_deployment( current_revision_id: UUID = OLD_REV, endpoint_id: UUID = ENDPOINT_ID, ) -> DeploymentInfo: + deploying_mock = MagicMock() + deploying_mock.id = DeploymentRevisionID(deploying_revision_id) + current_mock = MagicMock() + current_mock.id = DeploymentRevisionID(current_revision_id) return DeploymentInfo( id=DeploymentID(endpoint_id), metadata=DeploymentMetadata( @@ -125,10 +135,9 @@ def make_deployment( network=DeploymentNetworkData( open_to_public=False, access_token_ids=None, url=None, preferred_domain_name=None ), - 
model_revisions=[], options=DeploymentOptions(), - current_revision_id=DeploymentRevisionID(current_revision_id), - deploying_revision_id=DeploymentRevisionID(deploying_revision_id), + current_revision=cast(ModelRevisionData, current_mock), + deploying_revision=cast(ModelRevisionData, deploying_mock), ) @@ -139,6 +148,7 @@ def make_route( health_status: RouteHealthStatus = RouteHealthStatus.HEALTHY, endpoint_id: UUID = ENDPOINT_ID, route_id: UUID | None = None, + health_check: ModelHealthCheck | None = _STUB_HEALTH_CHECK, ) -> RouteInfo: return RouteInfo( route_id=route_id or uuid4(), @@ -152,6 +162,7 @@ def make_route( traffic_status=RouteTrafficStatus.ACTIVE if status.is_active() else RouteTrafficStatus.INACTIVE, + health_check=health_check, ) @@ -745,7 +756,7 @@ def test_all_old_inactive_no_new_creates_desired(self) -> None: def test_deploying_rev_none_rejected(self) -> None: """If deploying_revision_id is None, evaluate_cycle raises.""" - deployment = make_deployment(desired=1, deploying_revision_id=None) # type: ignore[arg-type] + deployment = dataclasses.replace(make_deployment(desired=1), deploying_revision=None) spec = RollingUpdateSpec( max_surge=make_int_or_percent(1), max_unavailable=make_int_or_percent(0) ) @@ -1125,3 +1136,70 @@ def test_small_fraction_with_few_replicas(self) -> None: result = RollingUpdateStrategy().evaluate_cycle(deployment, routes, spec) assert len(result.route_changes.rollout_specs) == 1 + + +# =========================================================================== +# No-health-check scenario +# =========================================================================== + + +class TestNoHealthCheck: + """Routes without health_check stay DEGRADED in DB but must allow READY transition.""" + + def test_running_degraded_no_health_check_completes(self) -> None: + """RUNNING + DEGRADED + no health_check → counts as healthy → COMPLETED.""" + deployment = make_deployment(desired=1) + spec = RollingUpdateSpec( + 
max_surge=make_int_or_percent(1), max_unavailable=make_int_or_percent(0) + ) + routes = [ + make_route( + revision_id=NEW_REV, + status=RouteStatus.RUNNING, + health_status=RouteHealthStatus.DEGRADED, + health_check=None, + ) + ] + + result = RollingUpdateStrategy().evaluate_cycle(deployment, routes, spec) + + assert result.sub_step == DeploymentLifecycleSubStep.DEPLOYING_COMPLETED + + def test_running_degraded_with_health_check_does_not_complete(self) -> None: + """RUNNING + DEGRADED + has health_check → still unhealthy → PROVISIONING.""" + deployment = make_deployment(desired=1) + spec = RollingUpdateSpec( + max_surge=make_int_or_percent(1), max_unavailable=make_int_or_percent(0) + ) + routes = [ + make_route( + revision_id=NEW_REV, + status=RouteStatus.RUNNING, + health_status=RouteHealthStatus.DEGRADED, + health_check=_STUB_HEALTH_CHECK, + ) + ] + + result = RollingUpdateStrategy().evaluate_cycle(deployment, routes, spec) + + assert result.sub_step == DeploymentLifecycleSubStep.DEPLOYING_PROVISIONING + + def test_multiple_replicas_no_health_check_completes(self) -> None: + """2 desired, 2 RUNNING DEGRADED no-health-check → COMPLETED.""" + deployment = make_deployment(desired=2) + spec = RollingUpdateSpec( + max_surge=make_int_or_percent(1), max_unavailable=make_int_or_percent(0) + ) + routes = [ + make_route( + revision_id=NEW_REV, + status=RouteStatus.RUNNING, + health_status=RouteHealthStatus.DEGRADED, + health_check=None, + ) + for _ in range(2) + ] + + result = RollingUpdateStrategy().evaluate_cycle(deployment, routes, spec) + + assert result.sub_step == DeploymentLifecycleSubStep.DEPLOYING_COMPLETED diff --git a/tests/unit/manager/sokovan/deployment/test_coordinator_history.py b/tests/unit/manager/sokovan/deployment/test_coordinator_history.py index 3fdf5862279..261ba81575a 100644 --- a/tests/unit/manager/sokovan/deployment/test_coordinator_history.py +++ b/tests/unit/manager/sokovan/deployment/test_coordinator_history.py @@ -71,7 +71,6 @@ def 
sample_deployment_info() -> DeploymentInfo: url=None, preferred_domain_name=None, ), - model_revisions=[], options=DeploymentOptions(), )