From 5e1a1abf48a5f634685284464464d64d3f95be68 Mon Sep 17 00:00:00 2001 From: HyeockJinKim Date: Thu, 14 May 2026 21:39:31 +0900 Subject: [PATCH 01/13] feat(BA-6035): define RouteProbeTarget and RouteHealthStatus Valkey types Split route health Valkey data into two separate types: - RouteProbeTarget: probe config (health_path, inference_port, replica_host) stored at route_probe:{replica_id}, TTL 3600s, written by coordinator - RouteHealthStatus: check result (healthy, last_check) stored at route_health:{replica_id}, TTL 120s, written by observer Both types use ReplicaID typed identifier for replica_id field. Co-Authored-By: Claude Sonnet 4.6 --- .../valkey_client/valkey_schedule/__init__.py | 3 + .../valkey_client/valkey_schedule/types.py | 70 +++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 src/ai/backend/common/clients/valkey_client/valkey_schedule/types.py diff --git a/src/ai/backend/common/clients/valkey_client/valkey_schedule/__init__.py b/src/ai/backend/common/clients/valkey_client/valkey_schedule/__init__.py index 64fae3d692e..76ad8931cde 100644 --- a/src/ai/backend/common/clients/valkey_client/valkey_schedule/__init__.py +++ b/src/ai/backend/common/clients/valkey_client/valkey_schedule/__init__.py @@ -5,11 +5,14 @@ RouteHealthRecord, ValkeyScheduleClient, ) +from .types import RouteHealthStatus, RouteProbeTarget __all__ = [ "HealthCheckStatus", "HealthStatus", "KernelStatus", "RouteHealthRecord", + "RouteHealthStatus", + "RouteProbeTarget", "ValkeyScheduleClient", ] diff --git a/src/ai/backend/common/clients/valkey_client/valkey_schedule/types.py b/src/ai/backend/common/clients/valkey_client/valkey_schedule/types.py new file mode 100644 index 00000000000..fdaf250493e --- /dev/null +++ b/src/ai/backend/common/clients/valkey_client/valkey_schedule/types.py @@ -0,0 +1,70 @@ +"""Valkey data types for route health management.""" + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from uuid import UUID + +from ai.backend.common.identifier.replica import ReplicaID + + +@dataclass +class RouteProbeTarget: + """Probe configuration for a route stored in Valkey. + + Stored as a hash at key `route_probe:{route_id}`. + Written once by the coordinator when the route enters WARMING_UP (host/port available). + Read by RouteHealthObserver to know what endpoint to probe. + """ + + route_id: ReplicaID + health_path: str + inference_port: int + replica_host: str + + def to_valkey_hash(self) -> Mapping[str, str]: + return { + "route_id": str(self.route_id), + "health_path": self.health_path, + "inference_port": str(self.inference_port), + "replica_host": self.replica_host, + } + + @classmethod + def from_valkey_hash(cls, data: Mapping[str, str]) -> RouteProbeTarget: + return cls( + route_id=ReplicaID(UUID(data["route_id"])), + health_path=data["health_path"], + inference_port=int(data["inference_port"]), + replica_host=data["replica_host"], + ) + + +@dataclass +class RouteHealthStatus: + """Health check result for a route stored in Valkey. + + Stored as a hash at key `route_health:{replica_id}`. + Written by RouteHealthObserver after each HTTP probe. + Short TTL — key expiry signals DEGRADED (no recent check). + """ + + replica_id: ReplicaID + healthy: bool + last_check: int # Unix timestamp (Redis time) + + def to_valkey_hash(self) -> Mapping[str, str]: + return { + "replica_id": str(self.replica_id), + "healthy": "1" if self.healthy else "0", + "last_check": str(self.last_check), + } + + @classmethod + def from_valkey_hash(cls, data: Mapping[str, str]) -> RouteHealthStatus: + return cls( + replica_id=ReplicaID(UUID(data["replica_id"])), + healthy=data.get("healthy", "0") == "1", + last_check=int(data.get("last_check", "0")), + ) From 24ad9e23ef9ec982900ecff61606af611f0cdb55 Mon Sep 17 00:00:00 2001 From: HyeockJinKim Date: Thu, 14 May 2026 21:53:32 +0900 Subject: [PATCH 02/13] feat(BA-6035): add RouteProbeTarget/RouteHealthStatus Valkey client methods with tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add four new ValkeyScheduleClient methods for the split route health design: - register_route_probe_targets_batch: writes probe config at route_probe:{id}, TTL 3600s - get_route_probe_targets_batch: batch-reads probe targets - record_route_health_status: writes health result at route_health:{id}, TTL 120s (expiry = DEGRADED) - get_route_health_statuses_batch: batch-reads health statuses Also rename RouteProbeTarget.route_id → replica_id for consistency with RouteHealthStatus. Tests: pure serialization tests in test_valkey_schedule_types.py, Redis integration tests for all four methods in test_valkey_schedule_client.py. Co-Authored-By: Claude Sonnet 4.6 --- .../valkey_client/valkey_schedule/client.py | 134 ++++++++++++ .../valkey_client/valkey_schedule/types.py | 8 +- .../test_valkey_schedule_client.py | 195 ++++++++++++++++++ .../test_valkey_schedule_types.py | 101 +++++++++ 4 files changed, 434 insertions(+), 4 deletions(-) create mode 100644 tests/unit/common/clients/valkey_client/test_valkey_schedule_types.py diff --git a/src/ai/backend/common/clients/valkey_client/valkey_schedule/client.py b/src/ai/backend/common/clients/valkey_client/valkey_schedule/client.py index 4d8ae23d5a1..a21dc742325 100644 --- a/src/ai/backend/common/clients/valkey_client/valkey_schedule/client.py +++ b/src/ai/backend/common/clients/valkey_client/valkey_schedule/client.py @@ -12,7 +12,12 @@ AbstractValkeyClient, create_valkey_client, ) +from ai.backend.common.clients.valkey_client.valkey_schedule.types import ( + RouteHealthStatus, + RouteProbeTarget, +) from ai.backend.common.exception import BackendAIError +from ai.backend.common.identifier.replica import ReplicaID from ai.backend.common.json import dump_json_str, load_json from ai.backend.common.metrics.metric import DomainType, LayerType from ai.backend.common.resilience import ( @@ -34,6 +39,8 @@ AGENT_LAST_CHECK_TTL_SEC = 1200 # 20 minutes - TTL for agent last check timestamp ORPHAN_KERNEL_THRESHOLD_SEC = 600 # 10 minutes - threshold for orphan kernel detection FORCE_TERMINATED_CLEANUP_TTL_SEC = 1200 # 20 minutes - TTL for force-terminated cleanup queue +ROUTE_PROBE_TTL_SEC = 3600 # 1 hour - TTL for route probe targets +ROUTE_HEALTH_STATUS_TTL_SEC = 120 # 2 minutes - TTL for route health status (expiry = DEGRADED) class HealthCheckStatus(enum.StrEnum): @@ -260,6 +267,12 @@ def _get_kernel_presence_key(self, kernel_id: KernelId) -> str: """ return f"kernel:presence:{kernel_id}" + def _get_route_probe_key(self, replica_id: ReplicaID) -> str: + return f"route_probe:{replica_id}" + + def _get_route_health_status_key(self, replica_id: ReplicaID) -> str: + return f"route_health:{replica_id}" + def _get_agent_last_check_key(self, agent_id: AgentId) -> str: """ Generate the Redis key for agent last check timestamp. @@ -929,6 +942,127 @@ async def get_route_health_records_batch( return records + # ==================== RouteProbeTarget / RouteHealthStatus Methods ==================== + + @valkey_schedule_resilience.apply() + async def register_route_probe_targets_batch(self, targets: Sequence[RouteProbeTarget]) -> None: + """ + Batch register RouteProbeTarget entries in Valkey. + Called by coordinator when route enters WARMING_UP and replica host/port are known. + + :param targets: RouteProbeTarget instances to store + """ + if not targets: + return + + batch = Batch(is_atomic=False) + for target in targets: + key = self._get_route_probe_key(target.replica_id) + batch.hset(key, target.to_valkey_hash()) + batch.expire(key, ROUTE_PROBE_TTL_SEC) + + async with self._client.client() as conn: + await conn.exec(batch, raise_on_error=True) + + @valkey_schedule_resilience.apply() + async def get_route_probe_targets_batch( + self, replica_ids: Sequence[ReplicaID] + ) -> Mapping[ReplicaID, RouteProbeTarget | None]: + """ + Batch get RouteProbeTargets from Valkey. + + :param replica_ids: Replica IDs to look up + :return: Mapping of replica_id to RouteProbeTarget (None if missing or expired) + """ + if not replica_ids: + return {} + + batch = Batch(is_atomic=False) + for replica_id in replica_ids: + batch.hgetall(self._get_route_probe_key(replica_id)) + + async with self._client.client() as conn: + results = await conn.exec(batch, raise_on_error=False) + if results is None: + return dict.fromkeys(replica_ids) + + targets: dict[ReplicaID, RouteProbeTarget | None] = {} + for i, replica_id in enumerate(replica_ids): + hgetall_result = results[i] if len(results) > i else None + if not hgetall_result: + targets[replica_id] = None + continue + raw = cast(dict[bytes, bytes], hgetall_result) + if not raw or b"replica_id" not in raw: + targets[replica_id] = None + continue + data = {k.decode(): v.decode() for k, v in raw.items()} + targets[replica_id] = RouteProbeTarget.from_valkey_hash(data) + + return targets + + @valkey_schedule_resilience.apply() + async def record_route_health_status(self, replica_id: ReplicaID, healthy: bool) -> None: + """ + Record health check result for a route. + Called by RouteHealthObserver after each HTTP probe. + Refreshes TTL on every call; key expiry signals DEGRADED. + + :param replica_id: The replica ID to update + :param healthy: Whether the route passed the health check + """ + key = self._get_route_health_status_key(replica_id) + current_time = str(await self._get_redis_time()) + data: Mapping[str | bytes, str | bytes] = { + "replica_id": str(replica_id), + "healthy": "1" if healthy else "0", + "last_check": current_time, + } + batch = Batch(is_atomic=False) + batch.hset(key, data) + batch.expire(key, ROUTE_HEALTH_STATUS_TTL_SEC) + + async with self._client.client() as conn: + await conn.exec(batch, raise_on_error=True) + + @valkey_schedule_resilience.apply() + async def get_route_health_statuses_batch( + self, replica_ids: Sequence[ReplicaID] + ) -> Mapping[ReplicaID, RouteHealthStatus | None]: + """ + Batch get RouteHealthStatus from Valkey. + None means no recent health check (key missing or TTL expired) → DEGRADED. + + :param replica_ids: Replica IDs to look up + :return: Mapping of replica_id to RouteHealthStatus (None if missing or expired) + """ + if not replica_ids: + return {} + + batch = Batch(is_atomic=False) + for replica_id in replica_ids: + batch.hgetall(self._get_route_health_status_key(replica_id)) + + async with self._client.client() as conn: + results = await conn.exec(batch, raise_on_error=False) + if results is None: + return dict.fromkeys(replica_ids) + + statuses: dict[ReplicaID, RouteHealthStatus | None] = {} + for i, replica_id in enumerate(replica_ids): + hgetall_result = results[i] if len(results) > i else None + if not hgetall_result: + statuses[replica_id] = None + continue + raw = cast(dict[bytes, bytes], hgetall_result) + if not raw or b"replica_id" not in raw: + statuses[replica_id] = None + continue + data = {k.decode(): v.decode() for k, v in raw.items()} + statuses[replica_id] = RouteHealthStatus.from_valkey_hash(data) + + return statuses + @valkey_schedule_resilience.apply() async def close(self) -> None: """ diff --git a/src/ai/backend/common/clients/valkey_client/valkey_schedule/types.py b/src/ai/backend/common/clients/valkey_client/valkey_schedule/types.py index fdaf250493e..268bf3eabe9 100644 --- a/src/ai/backend/common/clients/valkey_client/valkey_schedule/types.py +++ b/src/ai/backend/common/clients/valkey_client/valkey_schedule/types.py @@ -13,19 +13,19 @@ class RouteProbeTarget: """Probe configuration for a route stored in Valkey. - Stored as a hash at key `route_probe:{route_id}`. + Stored as a hash at key `route_probe:{replica_id}`. Written once by the coordinator when the route enters WARMING_UP (host/port available). Read by RouteHealthObserver to know what endpoint to probe. """ - route_id: ReplicaID + replica_id: ReplicaID health_path: str inference_port: int replica_host: str def to_valkey_hash(self) -> Mapping[str, str]: return { - "route_id": str(self.route_id), + "replica_id": str(self.replica_id), "health_path": self.health_path, "inference_port": str(self.inference_port), "replica_host": self.replica_host, @@ -34,7 +34,7 @@ def to_valkey_hash(self) -> Mapping[str, str]: @classmethod def from_valkey_hash(cls, data: Mapping[str, str]) -> RouteProbeTarget: return cls( - route_id=ReplicaID(UUID(data["route_id"])), + replica_id=ReplicaID(UUID(data["replica_id"])), health_path=data["health_path"], inference_port=int(data["inference_port"]), replica_host=data["replica_host"], diff --git a/tests/unit/common/clients/valkey_client/test_valkey_schedule_client.py b/tests/unit/common/clients/valkey_client/test_valkey_schedule_client.py index 65b50dc1ec5..df79a68408e 100644 --- a/tests/unit/common/clients/valkey_client/test_valkey_schedule_client.py +++ b/tests/unit/common/clients/valkey_client/test_valkey_schedule_client.py @@ -22,7 +22,9 @@ HealthCheckStatus, ValkeyScheduleClient, ) +from ai.backend.common.clients.valkey_client.valkey_schedule.types import RouteProbeTarget from ai.backend.common.defs import REDIS_LIVE_DB +from ai.backend.common.identifier.replica import ReplicaID from ai.backend.common.typed_validators import HostPortPair as HostPortPairModel from ai.backend.common.types import AgentId, KernelId, SessionId, ValkeyTarget @@ -898,3 +900,196 @@ async def test_remove_deletes_only_specified_sessions( result = await valkey_schedule_client.get_force_terminated_sessions() assert result == [sid_keep] + + +class TestRouteProbeTargetClient: + """Test ValkeyScheduleClient methods for RouteProbeTarget.""" + + @pytest.fixture + async def valkey_schedule_client( + self, + redis_container: tuple[str, HostPortPairModel], + ) -> AsyncGenerator[ValkeyScheduleClient, None]: + _, hostport_pair = redis_container + client = await ValkeyScheduleClient.create( + valkey_target=ValkeyTarget(addr=hostport_pair.address), + db_id=REDIS_LIVE_DB, + human_readable_name="test-route-probe-target", + ) + try: + yield client + finally: + await client.close() + + @pytest.fixture + def replica_id(self) -> ReplicaID: + return ReplicaID(uuid4()) + + def _make_target(self, replica_id: ReplicaID) -> RouteProbeTarget: + return RouteProbeTarget( + replica_id=replica_id, + health_path="/health", + inference_port=8080, + replica_host="10.0.0.1", + ) + + async def test_register_and_get( + self, + valkey_schedule_client: ValkeyScheduleClient, + replica_id: ReplicaID, + ) -> None: + target = self._make_target(replica_id) + await valkey_schedule_client.register_route_probe_targets_batch([target]) + results = await valkey_schedule_client.get_route_probe_targets_batch([replica_id]) + assert results[replica_id] == target + + async def test_register_batch_multiple( + self, + valkey_schedule_client: ValkeyScheduleClient, + ) -> None: + targets = [self._make_target(ReplicaID(uuid4())) for _ in range(3)] + await valkey_schedule_client.register_route_probe_targets_batch(targets) + replica_ids = [t.replica_id for t in targets] + results = await valkey_schedule_client.get_route_probe_targets_batch(replica_ids) + for target in targets: + assert results[target.replica_id] == target + + async def test_get_missing_returns_none( + self, + valkey_schedule_client: ValkeyScheduleClient, + replica_id: ReplicaID, + ) -> None: + results = await valkey_schedule_client.get_route_probe_targets_batch([replica_id]) + assert results[replica_id] is None + + async def test_register_empty_does_nothing( + self, + valkey_schedule_client: ValkeyScheduleClient, + ) -> None: + await valkey_schedule_client.register_route_probe_targets_batch([]) + + async def test_get_empty_returns_empty_dict( + self, + valkey_schedule_client: ValkeyScheduleClient, + ) -> None: + results = await valkey_schedule_client.get_route_probe_targets_batch([]) + assert results == {} + + async def test_register_overwrites_existing( + self, + valkey_schedule_client: ValkeyScheduleClient, + replica_id: ReplicaID, + ) -> None: + await valkey_schedule_client.register_route_probe_targets_batch([ + self._make_target(replica_id) + ]) + updated = RouteProbeTarget( + replica_id=replica_id, + health_path="/healthz", + inference_port=9000, + replica_host="10.0.0.2", + ) + await valkey_schedule_client.register_route_probe_targets_batch([updated]) + results = await valkey_schedule_client.get_route_probe_targets_batch([replica_id]) + assert results[replica_id] == updated + + +class TestRouteHealthStatusClient: + """Test ValkeyScheduleClient methods for RouteHealthStatus.""" + + @pytest.fixture + async def valkey_schedule_client( + self, + redis_container: tuple[str, HostPortPairModel], + ) -> AsyncGenerator[ValkeyScheduleClient, None]: + _, hostport_pair = redis_container + client = await ValkeyScheduleClient.create( + valkey_target=ValkeyTarget(addr=hostport_pair.address), + db_id=REDIS_LIVE_DB, + human_readable_name="test-route-health-status", + ) + try: + yield client + finally: + await client.close() + + @pytest.fixture + def replica_id(self) -> ReplicaID: + return ReplicaID(uuid4()) + + async def test_record_healthy_and_get( + self, + valkey_schedule_client: ValkeyScheduleClient, + replica_id: ReplicaID, + ) -> None: + await valkey_schedule_client.record_route_health_status(replica_id, healthy=True) + results = await valkey_schedule_client.get_route_health_statuses_batch([replica_id]) + status = results[replica_id] + assert status is not None + assert status.healthy is True + assert status.last_check > 0 + + async def test_record_unhealthy_and_get( + self, + valkey_schedule_client: ValkeyScheduleClient, + replica_id: ReplicaID, + ) -> None: + await valkey_schedule_client.record_route_health_status(replica_id, healthy=False) + results = await valkey_schedule_client.get_route_health_statuses_batch([replica_id]) + status = results[replica_id] + assert status is not None + assert status.healthy is False + + async def test_get_missing_returns_none( + self, + valkey_schedule_client: ValkeyScheduleClient, + replica_id: ReplicaID, + ) -> None: + results = await valkey_schedule_client.get_route_health_statuses_batch([replica_id]) + assert results[replica_id] is None + + async def test_get_empty_returns_empty_dict( + self, + valkey_schedule_client: ValkeyScheduleClient, + ) -> None: + results = await valkey_schedule_client.get_route_health_statuses_batch([]) + assert results == {} + + async def test_key_deletion_simulates_ttl_expiry( + self, + valkey_schedule_client: ValkeyScheduleClient, + replica_id: ReplicaID, + ) -> None: + """Deleting the key (simulating TTL expiry) results in None → DEGRADED.""" + await valkey_schedule_client.record_route_health_status(replica_id, healthy=True) + key = valkey_schedule_client._get_route_health_status_key(replica_id) + async with valkey_schedule_client._client.client() as conn: + await conn.delete([key]) + results = await valkey_schedule_client.get_route_health_statuses_batch([replica_id]) + assert results[replica_id] is None + + async def test_record_batch_multiple( + self, + valkey_schedule_client: ValkeyScheduleClient, + ) -> None: + replica_ids = [ReplicaID(uuid4()) for _ in range(3)] + for rid in replica_ids: + await valkey_schedule_client.record_route_health_status(rid, healthy=True) + results = await valkey_schedule_client.get_route_health_statuses_batch(replica_ids) + assert len(results) == 3 + for rid in replica_ids: + status = results[rid] + assert status is not None + assert status.healthy is True + + async def test_record_overwrites_previous( + self, + valkey_schedule_client: ValkeyScheduleClient, + replica_id: ReplicaID, + ) -> None: + await valkey_schedule_client.record_route_health_status(replica_id, healthy=True) + await valkey_schedule_client.record_route_health_status(replica_id, healthy=False) + results = await valkey_schedule_client.get_route_health_statuses_batch([replica_id]) + status = results[replica_id] + assert status is not None + assert status.healthy is False diff --git a/tests/unit/common/clients/valkey_client/test_valkey_schedule_types.py b/tests/unit/common/clients/valkey_client/test_valkey_schedule_types.py new file mode 100644 index 00000000000..8b274e09caf --- /dev/null +++ b/tests/unit/common/clients/valkey_client/test_valkey_schedule_types.py @@ -0,0 +1,101 @@ +"""Unit tests for RouteProbeTarget and RouteHealthStatus Valkey type serialization.""" + +from __future__ import annotations + +from uuid import uuid4 + +import pytest + +from ai.backend.common.clients.valkey_client.valkey_schedule.types import ( + RouteHealthStatus, + RouteProbeTarget, +) +from ai.backend.common.identifier.replica import ReplicaID + + +class TestRouteProbeTargetSerialization: + @pytest.fixture + def replica_id(self) -> ReplicaID: + return ReplicaID(uuid4()) + + @pytest.fixture + def target(self, replica_id: ReplicaID) -> RouteProbeTarget: + return RouteProbeTarget( + replica_id=replica_id, + health_path="/health", + inference_port=8080, + replica_host="10.0.0.1", + ) + + def test_to_valkey_hash(self, target: RouteProbeTarget) -> None: + h = target.to_valkey_hash() + assert h["replica_id"] == str(target.replica_id) + assert h["health_path"] == "/health" + assert h["inference_port"] == "8080" + assert h["replica_host"] == "10.0.0.1" + + def test_from_valkey_hash(self, replica_id: ReplicaID) -> None: + data = { + "replica_id": str(replica_id), + "health_path": "/healthz", + "inference_port": "9000", + "replica_host": "192.168.1.100", + } + target = RouteProbeTarget.from_valkey_hash(data) + assert target.replica_id == replica_id + assert target.health_path == "/healthz" + assert target.inference_port == 9000 + assert target.replica_host == "192.168.1.100" + + def test_round_trip(self, target: RouteProbeTarget) -> None: + restored = RouteProbeTarget.from_valkey_hash(target.to_valkey_hash()) + assert restored == target + + +class TestRouteHealthStatusSerialization: + @pytest.fixture + def replica_id(self) -> ReplicaID: + return ReplicaID(uuid4()) + + @pytest.fixture + def healthy_status(self, replica_id: ReplicaID) -> RouteHealthStatus: + return RouteHealthStatus( + replica_id=replica_id, + healthy=True, + last_check=1700000000, + ) + + def test_to_valkey_hash_healthy(self, healthy_status: RouteHealthStatus) -> None: + h = healthy_status.to_valkey_hash() + assert h["replica_id"] == str(healthy_status.replica_id) + assert h["healthy"] == "1" + assert h["last_check"] == "1700000000" + + def test_to_valkey_hash_unhealthy(self, replica_id: ReplicaID) -> None: + status = RouteHealthStatus( + replica_id=replica_id, + healthy=False, + last_check=1700000000, + ) + assert status.to_valkey_hash()["healthy"] == "0" + + def test_from_valkey_hash(self, replica_id: ReplicaID) -> None: + data = { + "replica_id": str(replica_id), + "healthy": "1", + "last_check": "1700000000", + } + status = RouteHealthStatus.from_valkey_hash(data) + assert status.replica_id == replica_id + assert status.healthy is True + assert status.last_check == 1700000000 + + def test_from_valkey_hash_missing_optional_fields(self, replica_id: ReplicaID) -> None: + """Missing healthy/last_check fields default to safe values.""" + status = RouteHealthStatus.from_valkey_hash({"replica_id": str(replica_id)}) + assert status.healthy is False + assert status.last_check == 0 + + def test_round_trip(self, healthy_status: RouteHealthStatus) -> None: + restored = RouteHealthStatus.from_valkey_hash(healthy_status.to_valkey_hash()) + assert restored == healthy_status From 2fa5c0db6432961408cbee84a03e010ba4e0e990 Mon Sep 17 00:00:00 2001 From: HyeockJinKim Date: Fri, 15 May 2026 12:31:50 +0900 Subject: [PATCH 03/13] feat(BA-6035): add health_check to RoutingRow and refactor DeploymentInfo revisions Key changes: - Add health_check: PydanticColumn(ModelHealthCheck) to RoutingRow (nullable); set from revision at route creation, eliminating per-cycle revision fetches in _register_route_probe_targets and sync_route_probe_targets - Add health_check field to RouteCreatorSpec (required) and RouteData (required, no default); both callers now resolve health config from revision at creation - Refactor DeploymentInfo: replace model_revisions list + current_revision_id + deploying_revision_id with direct current_revision and deploying_revision fields - Add current_revision_row / deploying_revision_row 1:1 relationships on EndpointRow; to_deployment_info() uses these instead of iterating revisions - Replace all selectinload(EndpointRow.revisions) in db_source with the two 1:1 relationships; remove unnecessary image_row eager loads and duplicates - Remove redundant get_revision() call in deployment_controller.update_deployment - Remove unnecessary ReplicaID() casts on route.route_id (already ReplicaID) - Alembic migration: c3d4e5f6a7b8 adds health_check JSONB column to routings Co-Authored-By: Claude Sonnet 4.6 --- .../manager/api/rest/service/handler.py | 7 +- .../backend/manager/data/deployment/types.py | 19 +- ...d4e5f6a7b8_add_health_check_to_routings.py | 27 +++ src/ai/backend/manager/models/endpoint/row.py | 63 ++++-- src/ai/backend/manager/models/routing/row.py | 7 + .../repositories/deployment/creators/route.py | 3 + .../deployment/db_source/db_source.py | 68 ++---- .../repositories/deployment/types/endpoint.py | 2 + .../manager/services/deployment/service.py | 26 +-- .../deployment/deployment_controller.py | 14 +- .../manager/sokovan/deployment/executor.py | 40 +++- .../sokovan/deployment/handlers/deploying.py | 6 +- .../sokovan/deployment/handlers/replica.py | 2 +- .../sokovan/deployment/route/executor.py | 96 +++++---- .../sokovan/deployment/strategy/evaluator.py | 4 +- .../deployment/strategy/rolling_update.py | 17 +- .../sokovan/deployment/executor/conftest.py | 4 +- .../deployment/route/executor/conftest.py | 3 +- .../test_check_route_health_register.py | 1 + .../route/executor/test_initial_delay.py | 199 ++++++++++-------- .../test_register_unregister_routes.py | 1 + .../route/executor/test_route_executor.py | 38 +++- .../executor/test_terminate_routes_drain.py | 2 + .../handlers/test_health_check_handler.py | 1 + .../handlers/test_terminating_handler.py | 1 + .../route/test_coordinator_history.py | 2 + 26 files changed, 390 insertions(+), 263 deletions(-) create mode 100644 src/ai/backend/manager/models/alembic/versions/c3d4e5f6a7b8_add_health_check_to_routings.py diff --git a/src/ai/backend/manager/api/rest/service/handler.py b/src/ai/backend/manager/api/rest/service/handler.py index c5529faaa1c..6868a977316 100644 --- a/src/ai/backend/manager/api/rest/service/handler.py +++ b/src/ai/backend/manager/api/rest/service/handler.py @@ -167,11 +167,8 @@ def _serve_info_from_dto(dto: ServiceInfo, runtime_variant_name: RuntimeVariant) def _resolve_target_revision_data(info: DeploymentInfo) -> ModelRevisionData | None: - """Resolve the target revision data by id (current first, then deploying).""" - target_id = info.current_revision_id or info.deploying_revision_id - if target_id is None: - return None - return next((r for r in info.model_revisions if r.id == target_id), None) + """Resolve the target revision data (current first, then deploying).""" + return info.current_revision or info.deploying_revision def _serve_info_from_deployment_info( diff --git a/src/ai/backend/manager/data/deployment/types.py b/src/ai/backend/manager/data/deployment/types.py index f164cb412f7..1e94c790dde 100644 --- a/src/ai/backend/manager/data/deployment/types.py +++ b/src/ai/backend/manager/data/deployment/types.py @@ -29,7 +29,6 @@ from ai.backend.common.identifier.runtime_variant import RuntimeVariantID from ai.backend.common.identifier.vfolder import VFolderUUID from ai.backend.manager.data.session.options import HandlerOptions -from ai.backend.manager.errors.deployment import DeploymentRevisionNotFound if TYPE_CHECKING: from ai.backend.manager.data.session.types import SchedulingResult, SubStepResult @@ -713,26 +712,12 @@ class DeploymentInfo: state: DeploymentState replica: ReplicaData network: DeploymentNetworkData - model_revisions: list[ModelRevisionData] options: DeploymentOptions - current_revision_id: DeploymentRevisionID | None = None + current_revision: ModelRevisionData | None = None policy: DeploymentPolicyData | None = None - deploying_revision_id: DeploymentRevisionID | None = None + deploying_revision: ModelRevisionData | None = None sub_step: DeploymentLifecycleSubStep | None = None - def resolve_revision_data(self, revision_id: DeploymentRevisionID) -> ModelRevisionData: - """Find a ``ModelRevisionData`` by id from ``model_revisions``. - - Raises: - DeploymentRevisionNotFound: If the revision is not found. - """ - for revision in self.model_revisions: - if revision.id == revision_id: - return revision - raise DeploymentRevisionNotFound( - f"Revision {revision_id} not found in model_revisions of deployment {self.id}" - ) - @dataclass(frozen=True) class DeploymentLastHistory: diff --git a/src/ai/backend/manager/models/alembic/versions/c3d4e5f6a7b8_add_health_check_to_routings.py b/src/ai/backend/manager/models/alembic/versions/c3d4e5f6a7b8_add_health_check_to_routings.py new file mode 100644 index 00000000000..ffe75d4deb4 --- /dev/null +++ b/src/ai/backend/manager/models/alembic/versions/c3d4e5f6a7b8_add_health_check_to_routings.py @@ -0,0 +1,27 @@ +"""Add health_check column to routings table. + +Revision ID: c3d4e5f6a7b8 +Revises: b2c3d4e5f6a7 +Create Date: 2026-05-15 + +""" + +# Part of: 26.5.0 + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "c3d4e5f6a7b8" +down_revision = "b2c3d4e5f6a7" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + conn = op.get_bind() + conn.exec_driver_sql("ALTER TABLE routings ADD COLUMN IF NOT EXISTS health_check JSONB") + + +def downgrade() -> None: + conn = op.get_bind() + conn.exec_driver_sql("ALTER TABLE routings DROP COLUMN IF EXISTS health_check") diff --git a/src/ai/backend/manager/models/endpoint/row.py b/src/ai/backend/manager/models/endpoint/row.py index a653c5b2cfe..2b7e5ed8af9 100644 --- a/src/ai/backend/manager/models/endpoint/row.py +++ b/src/ai/backend/manager/models/endpoint/row.py @@ -69,6 +69,7 @@ DeploymentMetadata, DeploymentNetworkData, DeploymentOptions, + DeploymentPolicyData, DeploymentState, DeploymentSummaryData, ModelDeploymentAutoScalingRuleData, @@ -132,6 +133,18 @@ def _get_endpoint_revisions_join_condition() -> Any: return EndpointRow.id == foreign(DeploymentRevisionRow.endpoint) +def _get_current_revision_row_join_condition() -> sa.ColumnElement[bool]: + from ai.backend.manager.models.deployment_revision import DeploymentRevisionRow + + return EndpointRow.current_revision == DeploymentRevisionRow.id + + +def _get_deploying_revision_row_join_condition() -> sa.ColumnElement[bool]: + from ai.backend.manager.models.deployment_revision import DeploymentRevisionRow + + return EndpointRow.deploying_revision == DeploymentRevisionRow.id + + def _get_endpoint_auto_scaling_policy_join_condition() -> Any: from ai.backend.manager.models.deployment_auto_scaling_policy import ( DeploymentAutoScalingPolicyRow, @@ -293,6 +306,20 @@ class EndpointRow(Base): # type: ignore[misc] primaryjoin=_get_endpoint_revisions_join_condition, order_by="DeploymentRevisionRow.revision_number.desc()", ) + current_revision_row: Mapped[DeploymentRevisionRow | None] = relationship( + "DeploymentRevisionRow", + primaryjoin=_get_current_revision_row_join_condition, + foreign_keys="EndpointRow.current_revision", + viewonly=True, + uselist=False, + ) + deploying_revision_row: Mapped[DeploymentRevisionRow | None] = relationship( + "DeploymentRevisionRow", + primaryjoin=_get_deploying_revision_row_join_condition, + foreign_keys="EndpointRow.deploying_revision", + viewonly=True, + uselist=False, + ) auto_scaling_policy: Mapped[DeploymentAutoScalingPolicyRow | None] = relationship( "DeploymentAutoScalingPolicyRow", @@ -729,24 +756,23 @@ def from_deployment_creator( def to_deployment_info(self) -> DeploymentInfo: """Convert EndpointRow to DeploymentInfo dataclass using revision data.""" - policy_data = None - if self.deployment_policy is not None: - policy_data = self.deployment_policy.to_data() - - model_revisions: list[ModelRevisionData] = [] - for rev_row in self.revisions: - if rev_row.id == self.current_revision or rev_row.id == self.deploying_revision: - model_revisions.append(rev_row.to_data()) - - info = self._to_deployment_info_with_revisions(model_revisions) - info.policy = policy_data - return info + return self._build_deployment_info( + current_revision=( + self.current_revision_row.to_data() if self.current_revision_row else None + ), + deploying_revision=( + self.deploying_revision_row.to_data() if self.deploying_revision_row else None + ), + policy=self.deployment_policy.to_data() if self.deployment_policy is not None else None, + ) - def _to_deployment_info_with_revisions( + def _build_deployment_info( self, - model_revisions: list[ModelRevisionData], + current_revision: ModelRevisionData | None, + deploying_revision: ModelRevisionData | None, + policy: DeploymentPolicyData | None = None, ) -> DeploymentInfo: - """Build DeploymentInfo with pre-built model_revisions dict.""" + """Build DeploymentInfo with current and deploying revision data.""" return DeploymentInfo( id=self.id, metadata=DeploymentMetadata( @@ -775,12 +801,11 @@ def _to_deployment_info_with_revisions( url=self.url, preferred_domain_name=None, ), - model_revisions=list(model_revisions), options=self.options, - current_revision_id=self.current_revision, - deploying_revision_id=self.deploying_revision, + current_revision=current_revision, + deploying_revision=deploying_revision, sub_step=self.sub_step, - policy=self.deployment_policy.to_data() if self.deployment_policy is not None else None, + policy=policy, ) diff --git a/src/ai/backend/manager/models/routing/row.py b/src/ai/backend/manager/models/routing/row.py index 537d4bf48fd..ba4e3e7a825 100644 --- a/src/ai/backend/manager/models/routing/row.py +++ b/src/ai/backend/manager/models/routing/row.py @@ -12,6 +12,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Mapped, mapped_column, relationship, selectinload +from ai.backend.common.config import ModelHealthCheck from ai.backend.common.identifier.deployment import DeploymentID from ai.backend.common.identifier.replica import ReplicaID from ai.backend.common.types import SessionId @@ -27,6 +28,7 @@ from ai.backend.manager.models.base import ( GUID, Base, + PydanticColumn, StrEnumType, ) @@ -110,6 +112,11 @@ class RoutingRow(Base): # type: ignore[misc] ) replica_port: Mapped[int | None] = mapped_column("replica_port", sa.Integer, nullable=True) + # Health check config (copied from revision at creation; None = no health check) + health_check: Mapped[ModelHealthCheck | None] = mapped_column( + "health_check", PydanticColumn(ModelHealthCheck), nullable=True, default=None + ) + # Revision reference without FK (relationship only) revision: Mapped[uuid.UUID] = mapped_column("revision", GUID, nullable=False) traffic_status: Mapped[RouteTrafficStatus] = mapped_column( diff --git a/src/ai/backend/manager/repositories/deployment/creators/route.py b/src/ai/backend/manager/repositories/deployment/creators/route.py index d6a4641f766..7b2f1b72ce8 100644 --- a/src/ai/backend/manager/repositories/deployment/creators/route.py +++ b/src/ai/backend/manager/repositories/deployment/creators/route.py @@ -6,6 +6,7 @@ from dataclasses import dataclass, field from typing import Any, override +from ai.backend.common.config import ModelHealthCheck from ai.backend.common.identifier.deployment import DeploymentID from ai.backend.common.identifier.deployment_revision import DeploymentRevisionID from ai.backend.manager.data.deployment.types import ( @@ -33,6 +34,7 @@ class RouteCreatorSpec(CreatorSpec[RoutingRow]): domain: str project_id: uuid.UUID revision_id: DeploymentRevisionID + health_check: ModelHealthCheck | None traffic_ratio: float = 1.0 traffic_status: RouteTrafficStatus = RouteTrafficStatus.INACTIVE @@ -49,6 +51,7 @@ def build_row(self) -> RoutingRow: traffic_ratio=self.traffic_ratio, revision=self.revision_id, traffic_status=self.traffic_status, + health_check=self.health_check, ) diff --git a/src/ai/backend/manager/repositories/deployment/db_source/db_source.py b/src/ai/backend/manager/repositories/deployment/db_source/db_source.py index c5f364cf920..df3701dfe52 100644 --- a/src/ai/backend/manager/repositories/deployment/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/deployment/db_source/db_source.py @@ -332,13 +332,9 @@ async def create_endpoint( sa.select(EndpointRow) .where(EndpointRow.id == endpoint.id) .options( - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), + selectinload(EndpointRow.current_revision_row), + selectinload(EndpointRow.deploying_revision_row), selectinload(EndpointRow.deployment_policy), - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), ) ) result = await db_sess.execute(stmt) @@ -393,12 +389,8 @@ async def get_endpoint( sa.select(EndpointRow) .where(EndpointRow.id == endpoint_id) .options( - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), + selectinload(EndpointRow.current_revision_row), + selectinload(EndpointRow.deploying_revision_row), selectinload(EndpointRow.deployment_policy), ) ) @@ -428,12 +420,8 @@ async def get_deployments_by_ids( ) ) .options( - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), + selectinload(EndpointRow.current_revision_row), + selectinload(EndpointRow.deploying_revision_row), selectinload(EndpointRow.deployment_policy), ) ) @@ -518,7 +506,8 @@ async def search_deployments_with_last_history( """ async with self._begin_readonly_session_read_committed() as db_sess: query = sa.select(EndpointRow).options( - selectinload(EndpointRow.revisions).selectinload(DeploymentRevisionRow.image_row), + selectinload(EndpointRow.current_revision_row), + selectinload(EndpointRow.deploying_revision_row), selectinload(EndpointRow.deployment_policy), ) query_result = await execute_batch_querier(db_sess, query, querier) @@ -567,12 +556,8 @@ async def list_endpoints_by_name( EndpointRow.lifecycle_stage == EndpointLifecycle.CREATED, ) .options( - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), + selectinload(EndpointRow.current_revision_row), + selectinload(EndpointRow.deploying_revision_row), ) ) @@ -635,9 +620,8 @@ async def get_modified_endpoint( sa.select(EndpointRow) .where(EndpointRow.id == endpoint_id) .options( - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), + selectinload(EndpointRow.current_revision_row), + selectinload(EndpointRow.deploying_revision_row), selectinload(EndpointRow.deployment_policy), ) ) @@ -674,13 +658,9 @@ async def update_endpoint_with_spec( sa.select(EndpointRow) .where(EndpointRow.id == updater.pk_value) .options( - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), + selectinload(EndpointRow.current_revision_row), + selectinload(EndpointRow.deploying_revision_row), selectinload(EndpointRow.deployment_policy), - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), ) ) query_result = await db_sess.execute(stmt) @@ -999,6 +979,7 @@ async def get_routes_by_endpoint( created_at=row.created_at, revision_id=DeploymentRevisionID(row.revision), traffic_status=row.traffic_status, + health_check=row.health_check, replica_host=row.replica_host, replica_port=row.replica_port, updated_at=row.updated_at, @@ -1141,8 +1122,8 @@ async def search_endpoints( """ async with self._begin_readonly_session_read_committed() as db_sess: query = sa.select(EndpointRow).options( - selectinload(EndpointRow.revisions).selectinload(DeploymentRevisionRow.image_row), - selectinload(EndpointRow.revisions).selectinload(DeploymentRevisionRow.image_row), + selectinload(EndpointRow.current_revision_row), + selectinload(EndpointRow.deploying_revision_row), selectinload(EndpointRow.deployment_policy), ) @@ -1356,9 +1337,8 @@ async def get_endpoints_with_autoscaling_rules( EndpointRow.id == EndpointAutoScalingRuleRow.endpoint_id, ) .options( - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), + selectinload(EndpointRow.current_revision_row), + selectinload(EndpointRow.deploying_revision_row), ) .distinct() ) @@ -1619,6 +1599,7 @@ async def search_route_datas( created_at=row.created_at, revision_id=DeploymentRevisionID(row.revision), traffic_status=row.traffic_status, + health_check=row.health_check, replica_host=row.replica_host, replica_port=row.replica_port, updated_at=row.updated_at, @@ -1664,6 +1645,7 @@ async def search_route_datas_with_last_history( created_at=row.created_at, revision_id=DeploymentRevisionID(row.revision), traffic_status=row.traffic_status, + health_check=row.health_check, replica_host=row.replica_host, replica_port=row.replica_port, updated_at=row.updated_at, @@ -2729,12 +2711,8 @@ async def update_endpoint( sa.select(EndpointRow) .where(EndpointRow.id == updater.pk_value) .options( - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), - selectinload(EndpointRow.revisions).selectinload( - DeploymentRevisionRow.image_row - ), + selectinload(EndpointRow.current_revision_row), + selectinload(EndpointRow.deploying_revision_row), selectinload(EndpointRow.deployment_policy), ) ) diff --git a/src/ai/backend/manager/repositories/deployment/types/endpoint.py b/src/ai/backend/manager/repositories/deployment/types/endpoint.py index d8779fde3a1..15d4ed56a0c 100644 --- a/src/ai/backend/manager/repositories/deployment/types/endpoint.py +++ b/src/ai/backend/manager/repositories/deployment/types/endpoint.py @@ -11,6 +11,7 @@ import sqlalchemy as sa +from ai.backend.common.config import ModelHealthCheck from ai.backend.common.data.endpoint.types import EndpointLifecycle from ai.backend.common.identifier.deployment import DeploymentID from ai.backend.common.identifier.deployment_revision import DeploymentRevisionID @@ -78,6 +79,7 @@ class RouteData: created_at: datetime revision_id: DeploymentRevisionID traffic_status: RouteTrafficStatus + health_check: ModelHealthCheck | None replica_host: str | None = None replica_port: int | None = None updated_at: datetime | None = None diff --git a/src/ai/backend/manager/services/deployment/service.py b/src/ai/backend/manager/services/deployment/service.py index 930d588e57a..80ceffbd1b4 100644 --- a/src/ai/backend/manager/services/deployment/service.py +++ b/src/ai/backend/manager/services/deployment/service.py @@ -224,21 +224,7 @@ def _convert_deployment_info_to_data(info: DeploymentInfo) -> ModelDeploymentDat Note: Some fields are set to defaults as DeploymentInfo doesn't have all the data. """ - revision: ModelRevisionData | None = None - if info.current_revision_id is not None: - revision = next( - (r for r in info.model_revisions if r.id == info.current_revision_id), - None, - ) - if revision is None: - log.error( - "Deployment {} has current_revision_id {} but no matching " - "ModelRevisionData was found in DeploymentInfo.model_revisions; " - "current_revision will be reported as null. This usually means " - "EndpointRow.revisions was not eagerly loaded by the caller.", - info.id, - info.current_revision_id, - ) + revision: ModelRevisionData | None = info.current_revision desired_count = info.replica.desired_replica_count if desired_count is None: @@ -257,10 +243,14 @@ def _convert_deployment_info_to_data(info: DeploymentInfo) -> ModelDeploymentDat updated_at=info.metadata.created_at or datetime.now(UTC), ), network_access=info.network, - revision_history_ids=[info.current_revision_id] if info.current_revision_id else [], + revision_history_ids=[info.current_revision.id] + if info.current_revision is not None + else [], revision=revision, - current_revision_id=info.current_revision_id, - deploying_revision_id=info.deploying_revision_id, + current_revision_id=info.current_revision.id if info.current_revision is not None else None, + deploying_revision_id=info.deploying_revision.id + if info.deploying_revision is not None + else None, scaling_rule_ids=[], # Not available in DeploymentInfo replica_state=ReplicaStateData( desired_replica_count=desired_count, diff --git a/src/ai/backend/manager/sokovan/deployment/deployment_controller.py b/src/ai/backend/manager/sokovan/deployment/deployment_controller.py index 90b6ac9d067..c2d67feac22 100644 --- a/src/ai/backend/manager/sokovan/deployment/deployment_controller.py +++ b/src/ai/backend/manager/sokovan/deployment/deployment_controller.py @@ -320,12 +320,11 @@ async def update_deployment( modified_endpoint = await self._deployment_repository.get_modified_endpoint( endpoint_id=endpoint_id, updater=updater ) - if modified_endpoint.current_revision_id is not None: - current_revision = await self._deployment_repository.get_revision( - modified_endpoint.current_revision_id - ) + if modified_endpoint.current_revision is not None: await self._scheduling_controller.validate_session_spec( - SessionValidationSpec.from_revision(model_revision=current_revision) + SessionValidationSpec.from_revision( + model_revision=modified_endpoint.current_revision + ) ) res = await self._deployment_repository.update_endpoint_with_spec(updater) try: @@ -557,7 +556,10 @@ async def activate_revision( # 2. Validate deployment state deployment_info = await self._deployment_repository.get_endpoint_info(deployment_id) - if deployment_info.current_revision_id == revision_id: + if ( + deployment_info.current_revision is not None + and deployment_info.current_revision.id == revision_id + ): raise InvalidEndpointState( f"Revision {revision_id} is already the current revision " f"of deployment {deployment_id}." diff --git a/src/ai/backend/manager/sokovan/deployment/executor.py b/src/ai/backend/manager/sokovan/deployment/executor.py index f52a37eb6cc..6ea2dc8e4e4 100644 --- a/src/ai/backend/manager/sokovan/deployment/executor.py +++ b/src/ai/backend/manager/sokovan/deployment/executor.py @@ -40,6 +40,7 @@ from ai.backend.manager.data.deployment.scale import AutoScalingRule from ai.backend.manager.data.deployment.types import ( DeploymentInfo, + ModelRevisionData, RouteInfo, RouteStatus, RouteTrafficStatus, @@ -47,7 +48,7 @@ from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.data.prometheus_query_preset import PrometheusQueryPresetData from ai.backend.manager.data.resource.types import ScalingGroupProxyTarget -from ai.backend.manager.errors.deployment import ReplicaCountMismatch +from ai.backend.manager.errors.deployment import DeploymentRevisionNotFound, ReplicaCountMismatch from ai.backend.manager.models.routing import RoutingRow from ai.backend.manager.models.routing.conditions import RouteConditions from ai.backend.manager.repositories.base.rbac.entity_creator import RBACEntityCreator @@ -304,7 +305,7 @@ async def check_ready_deployments_that_need_scaling( # the handler to attempt scaling would permanently wedge the # deployment in SCALING because scale_deployment() would then # refuse to act on a None revision id. - if deployment.deployment_info.current_revision_id is None: + if deployment.deployment_info.current_revision is None: skipped.append(deployment) continue try: @@ -351,12 +352,12 @@ async def scale_deployment( # Phase 2: Evaluate scaling (per-deployment) for deployment in deployments: info = deployment.deployment_info - if info.current_revision_id is None: + if info.current_revision is None: skipped.append(deployment) continue try: out_creators, in_route_ids = self._evaluate_deployment_scaling( - info, route_map, info.current_revision_id + info, route_map, info.current_revision.id ) if out_creators or in_route_ids: scale_out_creators.extend(out_creators) @@ -442,7 +443,7 @@ async def calculate_desired_replicas( # the two stay in lock-step. deployments_to_calculate: list[DeploymentWithHistory] = [] for deployment in deployments: - if deployment.deployment_info.current_revision_id is None: + if deployment.deployment_info.current_revision is None: skipped.append(deployment) continue deployments_to_calculate.append(deployment) @@ -615,7 +616,20 @@ async def _build_endpoint_item( Resolves the runtime variant id to a name at this boundary since the AppProxy wire API still keys on the variant name string. """ - target_revision = deployment.resolve_revision_data(revision_id) + if ( + deployment.current_revision is not None + and deployment.current_revision.id == revision_id + ): + target_revision = deployment.current_revision + elif ( + deployment.deploying_revision is not None + and deployment.deploying_revision.id == revision_id + ): + target_revision = deployment.deploying_revision + else: + raise DeploymentRevisionNotFound( + f"Revision {revision_id} not found in deployment {deployment.id}" + ) health_check_config = None if target_revision.model_definition: @@ -687,6 +701,19 @@ def _evaluate_deployment_scaling( if len(routes) < target_count: # Build creators for scale out new_replica_count = target_count - len(routes) + revision_data: ModelRevisionData | None + if ( + deployment.current_revision is not None + and deployment.current_revision.id == revision_id + ): + revision_data = deployment.current_revision + else: + revision_data = deployment.deploying_revision + health_check = ( + revision_data.model_definition.health_check_config() + if revision_data is not None and revision_data.model_definition + else None + ) for _ in range(new_replica_count): creator_spec = RouteCreatorSpec( deployment_id=deployment.id, @@ -694,6 +721,7 @@ def _evaluate_deployment_scaling( domain=deployment.metadata.domain, project_id=deployment.metadata.project, revision_id=revision_id, + health_check=health_check, ) scale_out_creators.append( RBACEntityCreator( diff --git a/src/ai/backend/manager/sokovan/deployment/handlers/deploying.py b/src/ai/backend/manager/sokovan/deployment/handlers/deploying.py index 158588dfcbe..65a7de192e1 100644 --- a/src/ai/backend/manager/sokovan/deployment/handlers/deploying.py +++ b/src/ai/backend/manager/sokovan/deployment/handlers/deploying.py @@ -160,9 +160,9 @@ async def _ensure_endpoints_registered( info = deployment.deployment_info if info.network.url: continue - if info.deploying_revision_id is None: + if info.deploying_revision is None: continue - entries.append((deployment, info.deploying_revision_id)) + entries.append((deployment, info.deploying_revision.id)) if not entries: return set() @@ -186,7 +186,7 @@ async def execute( d.deployment_info.id for d in deployments if not d.deployment_info.network.url - and d.deployment_info.deploying_revision_id is not None + and d.deployment_info.deploying_revision is not None } if failed_registration_ids: deployments = [ diff --git a/src/ai/backend/manager/sokovan/deployment/handlers/replica.py b/src/ai/backend/manager/sokovan/deployment/handlers/replica.py index 4536437cf2b..7f40d91b026 100644 --- a/src/ai/backend/manager/sokovan/deployment/handlers/replica.py +++ b/src/ai/backend/manager/sokovan/deployment/handlers/replica.py @@ -96,7 +96,7 @@ async def execute( """ log.debug("Checking deployment replicas") - scalable = [d for d in deployments if d.deployment_info.current_revision_id is not None] + scalable = [d for d in deployments if d.deployment_info.current_revision is not None] if len(scalable) != len(deployments): skipped = len(deployments) - len(scalable) log.debug( diff --git a/src/ai/backend/manager/sokovan/deployment/route/executor.py b/src/ai/backend/manager/sokovan/deployment/route/executor.py index 6e3c54528cf..53e032e144e 100644 --- a/src/ai/backend/manager/sokovan/deployment/route/executor.py +++ b/src/ai/backend/manager/sokovan/deployment/route/executor.py @@ -7,7 +7,7 @@ from ai.backend.common.clients.http_client.client_pool import ClientPool from ai.backend.common.clients.valkey_client.valkey_schedule import ( - RouteHealthRecord, + RouteProbeTarget, ValkeyScheduleClient, ) from ai.backend.common.dto.appproxy_coordinator.v2.endpoint.request import ( @@ -242,7 +242,7 @@ async def check_starting_routes(self, routes: Sequence[RouteData]) -> RouteExecu if not routes: return RouteExecutionResult(successes=[], errors=[]) - route_ids = {ReplicaID(route.route_id) for route in routes} + route_ids = {route.route_id for route in routes} session_infos: dict[ReplicaID, RouteSessionInfo | None] = dict( await self._deployment_repo.fetch_route_session_kernel_infos(route_ids) ) @@ -252,7 +252,7 @@ async def check_starting_routes(self, routes: Sequence[RouteData]) -> RouteExecu updates: dict[ReplicaID, RouteSessionKernelInfo] = {} for route in routes: - replica_id = ReplicaID(route.route_id) + replica_id = route.route_id info = session_infos.get(replica_id) if info is None: @@ -293,7 +293,7 @@ async def check_starting_routes(self, routes: Sequence[RouteData]) -> RouteExecu if updates: await self._deployment_repo.update_route_replica_info(updates) - await self._initialize_health_records(successes, updates) + await self._register_route_probe_targets(successes, updates) return RouteExecutionResult( successes=successes, @@ -341,48 +341,66 @@ async def check_running_routes(self, routes: Sequence[RouteData]) -> RouteExecut errors=errors, ) - async def _initialize_health_records( + async def _register_route_probe_targets( self, routes: Sequence[RouteData], replica_info: Mapping[ReplicaID, RouteSessionKernelInfo], ) -> None: - """Create RouteHealthRecords in Valkey for routes that just got replica info.""" - revision_ids = {r.revision_id for r in routes} - health_configs = await self._deployment_repo.fetch_health_check_configs_by_revision_ids( - revision_ids - ) - route_id_strs = [str(r.route_id) for r in routes] - existing_running_at = await self._valkey_schedule.get_route_running_at_batch(route_id_strs) - current_time = await self._valkey_schedule.get_redis_time() - - records: list[RouteHealthRecord] = [] + """Register RouteProbeTargets in Valkey for routes that just got replica info.""" + targets: list[RouteProbeTarget] = [] for route in routes: - kernel = replica_info[route.route_id] - health_config = health_configs.get(route.revision_id) - - health_path = health_config.path if health_config else "/" - initial_delay = health_config.initial_delay if health_config else 60.0 - created_at = int(route.created_at.timestamp()) - - route_id_str = str(route.route_id) - running_at = existing_running_at.get(route_id_str) or current_time - initial_delay_until = running_at + int(initial_delay) - - records.append( - RouteHealthRecord( - route_id=route_id_str, - created_at=created_at, - initial_delay_until=initial_delay_until, + replica_id = route.route_id + kernel = replica_info[replica_id] + health_path = route.health_check.path if route.health_check else "/" + targets.append( + RouteProbeTarget( + replica_id=replica_id, health_path=health_path, inference_port=kernel.replica_port, replica_host=kernel.replica_host, - running_at=running_at, ) ) - if records: - await self._valkey_schedule.initialize_route_health_records_batch(records) - log.debug("Initialized {} RouteHealthRecords in Valkey", len(records)) + if targets: + await self._valkey_schedule.register_route_probe_targets_batch(targets) + log.debug("Registered {} RouteProbeTargets in Valkey", len(targets)) + + async def sync_route_probe_targets(self, routes: Sequence[RouteData]) -> RouteExecutionResult: + """Sync RouteProbeTargets to Valkey for routes with known replica info. + + Handles two cases: + - Valkey data lost (restart, eviction) → re-registers probe targets + - TTL refresh for long-running routes + + Routes without replica_host/replica_port are skipped silently. + """ + routes_with_info = [route for route in routes if route.replica_host and route.replica_port] + if not routes_with_info: + return RouteExecutionResult(successes=[], errors=[]) + + # Build probe targets (no phase — health config comes from RouteData) + targets: list[RouteProbeTarget] = [] + for route in routes_with_info: + health_path = route.health_check.path if route.health_check else "/" + targets.append( + RouteProbeTarget( + replica_id=route.route_id, + health_path=health_path, + inference_port=route.replica_port, # type: ignore[arg-type] + replica_host=route.replica_host, # type: ignore[arg-type] + ) + ) + + if targets: + with RouteRecorderContext.shared_phase( + "register_probe_targets", + entity_ids={t.replica_id for t in targets}, + ): + with RouteRecorderContext.shared_step("write_probe_targets"): + await self._valkey_schedule.register_route_probe_targets_batch(targets) + log.debug("Synced {} RouteProbeTargets to Valkey", len(targets)) + + return RouteExecutionResult(successes=[], errors=[]) async def check_warming_up_health(self, routes: Sequence[RouteData]) -> RouteExecutionResult: """Check health of PROVISIONING+WARMING_UP routes for initial activation. @@ -1056,10 +1074,10 @@ async def cleanup_routes_by_config(self, routes: Sequence[RouteData]) -> RouteEx else: deployment_cleanup_config[deployment.id] = set() valid_revisions: set[DeploymentRevisionID] = set() - if deployment.current_revision_id is not None: - valid_revisions.add(deployment.current_revision_id) - if deployment.deploying_revision_id is not None: - valid_revisions.add(deployment.deploying_revision_id) + if deployment.current_revision is not None: + valid_revisions.add(deployment.current_revision.id) + if deployment.deploying_revision is not None: + valid_revisions.add(deployment.deploying_revision.id) deployment_valid_revisions[deployment.id] = valid_revisions successes: list[RouteData] = [] diff --git a/src/ai/backend/manager/sokovan/deployment/strategy/evaluator.py b/src/ai/backend/manager/sokovan/deployment/strategy/evaluator.py index f3a669ceb0a..6606dd486f6 100644 --- a/src/ai/backend/manager/sokovan/deployment/strategy/evaluator.py +++ b/src/ai/backend/manager/sokovan/deployment/strategy/evaluator.py @@ -91,9 +91,9 @@ async def evaluate( # to count accumulated failures for rollback detection, but old # terminated routes are irrelevant and would bloat the result set. deploying_revision_ids = { - deployment.deploying_revision_id + deployment.deploying_revision.id for deployment in deployments - if deployment.deploying_revision_id is not None + if deployment.deploying_revision is not None } route_conditions: list[QueryCondition] = [ RouteConditions.by_endpoint_ids(endpoint_ids), diff --git a/src/ai/backend/manager/sokovan/deployment/strategy/rolling_update.py b/src/ai/backend/manager/sokovan/deployment/strategy/rolling_update.py index bc1311d558b..d22579c7030 100644 --- a/src/ai/backend/manager/sokovan/deployment/strategy/rolling_update.py +++ b/src/ai/backend/manager/sokovan/deployment/strategy/rolling_update.py @@ -87,13 +87,13 @@ def evaluate_cycle( if not isinstance(spec, RollingUpdateSpec): raise TypeError(f"Expected RollingUpdateSpec, got {type(spec).__name__}") desired = deployment.replica.target_replica_count - deploying_revision_id = deployment.deploying_revision_id - if deploying_revision_id is None: + deploying_revision = deployment.deploying_revision + if deploying_revision is None: raise InvalidEndpointState( f"Deployment {deployment.id} has DEPLOYING lifecycle but deploying_revision_id is None. " "This indicates an inconsistent state — the deployment will be skipped." ) - classified = self._classify_routes(routes, deploying_revision_id) + classified = self._classify_routes(routes, deploying_revision.id) log.info( "deployment {}: sub_step={}, routes total={}, " "old_active={}, new_prov={}, new_healthy={}, new_unhealthy={}, new_failed={}", @@ -267,10 +267,16 @@ def _build_route_creators( count: int, ) -> list[RBACEntityCreator[RoutingRow]]: """Build route creator specs for new revision routes.""" - if deployment.deploying_revision_id is None: + if deployment.deploying_revision is None: raise DeploymentHasNoTargetRevision( f"Cannot create routes: deployment {deployment.id} has no deploying revision" ) + revision_data = deployment.deploying_revision + health_check = ( + revision_data.model_definition.health_check_config() + if revision_data.model_definition + else None + ) creators: list[RBACEntityCreator[RoutingRow]] = [] for _ in range(count): spec = RouteCreatorSpec( @@ -278,7 +284,8 @@ def _build_route_creators( session_owner_id=deployment.metadata.session_owner, domain=deployment.metadata.domain, project_id=deployment.metadata.project, - revision_id=deployment.deploying_revision_id, + revision_id=deployment.deploying_revision.id, + health_check=health_check, ) creators.append( RBACEntityCreator( diff --git a/tests/unit/manager/sokovan/deployment/executor/conftest.py b/tests/unit/manager/sokovan/deployment/executor/conftest.py index 1e142fd480d..6dc6abb243e 100644 --- a/tests/unit/manager/sokovan/deployment/executor/conftest.py +++ b/tests/unit/manager/sokovan/deployment/executor/conftest.py @@ -165,9 +165,8 @@ def _create_deployment_info( url=None, preferred_domain_name=None, ), - model_revisions=[cast(ModelRevisionData, revision)] if has_revision else [], - current_revision_id=DeploymentRevisionID(rev_id) if has_revision else None, options=DeploymentOptions(), + current_revision=cast(ModelRevisionData, revision) if has_revision else None, ) @@ -188,6 +187,7 @@ def _create_route_data( created_at=datetime.now(tzutc()), revision_id=DeploymentRevisionID(uuid4()), traffic_status=RouteTrafficStatus.INACTIVE, + health_check=None, ) diff --git a/tests/unit/manager/sokovan/deployment/route/executor/conftest.py b/tests/unit/manager/sokovan/deployment/route/executor/conftest.py index f7dd03c678d..105e77c1014 100644 --- a/tests/unit/manager/sokovan/deployment/route/executor/conftest.py +++ b/tests/unit/manager/sokovan/deployment/route/executor/conftest.py @@ -173,9 +173,7 @@ def _create_deployment_info( url="http://test.endpoint", preferred_domain_name=None, ), - model_revisions=[], options=DeploymentOptions(), - current_revision_id=DeploymentRevisionID(uuid4()), ) @@ -203,6 +201,7 @@ def _create_route_data( created_at=datetime.now(tzutc()), revision_id=revision_id or DeploymentRevisionID(uuid4()), traffic_status=RouteTrafficStatus.INACTIVE, + health_check=None, ) diff --git a/tests/unit/manager/sokovan/deployment/route/executor/test_check_route_health_register.py b/tests/unit/manager/sokovan/deployment/route/executor/test_check_route_health_register.py index c05810bf3ae..08fe21d21a7 100644 --- a/tests/unit/manager/sokovan/deployment/route/executor/test_check_route_health_register.py +++ b/tests/unit/manager/sokovan/deployment/route/executor/test_check_route_health_register.py @@ -41,6 +41,7 @@ def _route(health_status: RouteHealthStatus) -> RouteData: traffic_ratio=1.0, revision_id=DeploymentRevisionID(uuid4()), traffic_status=RouteTrafficStatus.ACTIVE, + health_check=None, replica_host="10.0.0.1", replica_port=8000, created_at=datetime.now(tzutc()), diff --git a/tests/unit/manager/sokovan/deployment/route/executor/test_initial_delay.py b/tests/unit/manager/sokovan/deployment/route/executor/test_initial_delay.py index 27dee701f0e..46a877582ee 100644 --- a/tests/unit/manager/sokovan/deployment/route/executor/test_initial_delay.py +++ b/tests/unit/manager/sokovan/deployment/route/executor/test_initial_delay.py @@ -1,13 +1,4 @@ -"""Tests for initial_delay calculation based on running_at. - -Test scenarios: -- ID-001: running_at is set → initial_delay_until based on running_at -- ID-002: running_at is None → fallback to redis_time -- ID-003: created_at has expired but running_at has not → still within initial_delay -- ID-004: running_at has also expired → initial_delay over -- ID-005: observer ignores failure within initial_delay (running_at based) -- ID-006: observer writes failure after initial_delay expires (running_at based) -""" +"""Tests for RouteProbeTarget registration and initial_delay behavior.""" from __future__ import annotations @@ -18,6 +9,7 @@ from dateutil.tz import tzutc from ai.backend.common.clients.valkey_client.valkey_schedule import RouteHealthRecord +from ai.backend.common.clients.valkey_client.valkey_schedule.types import RouteProbeTarget from ai.backend.common.config import ModelHealthCheck from ai.backend.common.identifier.deployment import DeploymentID from ai.backend.common.identifier.deployment_revision import DeploymentRevisionID @@ -33,11 +25,13 @@ from ai.backend.manager.sokovan.deployment.route.handlers.observer.health_check import ( RouteHealthObserver, ) +from ai.backend.manager.sokovan.deployment.route.recorder.context import RouteRecorderContext def _make_route( created_at_ts: int = 1000, session_id: SessionId | None = None, + health_check: ModelHealthCheck | None = None, ) -> RouteData: return RouteData( route_id=ReplicaID(uuid4()), @@ -49,118 +43,157 @@ def _make_route( created_at=datetime.fromtimestamp(created_at_ts, tz=tzutc()), revision_id=DeploymentRevisionID(uuid4()), traffic_status=RouteTrafficStatus.INACTIVE, + health_check=health_check, replica_host="10.0.0.1", replica_port=8000, ) # ============================================================================= -# _initialize_health_records: initial_delay_until calculation +# _register_route_probe_targets # ============================================================================= -class TestInitializeHealthRecordsInitialDelay: - """Tests for initial_delay_until calculation in _initialize_health_records.""" +class TestRegisterRouteProbeTargets: + """Tests for _register_route_probe_targets.""" - async def test_running_at_present_uses_running_at( + async def test_registers_probe_target_with_health_config( self, route_executor: RouteExecutor, - mock_deployment_repo: AsyncMock, mock_valkey_schedule: AsyncMock, ) -> None: - """ID-001: When running_at exists in Valkey, initial_delay_until is based on running_at. + """health_path comes from route.health_check when present.""" + route = _make_route(health_check=ModelHealthCheck(path="/healthz", initial_delay=60.0)) + replica_id = ReplicaID(route.route_id) - Given: Route with running_at=5000 in Valkey, initial_delay=720 - When: _initialize_health_records - Then: initial_delay_until = 5000 + 720 = 5720 - """ - route = _make_route(created_at_ts=1000) - route_id_str = str(route.route_id) + await route_executor._register_route_probe_targets( + [route], + {replica_id: RouteSessionKernelInfo(replica_host="10.0.0.2", replica_port=9000)}, + ) - mock_valkey_schedule.get_route_running_at_batch.return_value = { - route_id_str: 5000, - } - mock_valkey_schedule.get_redis_time.return_value = 5100 - mock_deployment_repo.fetch_health_check_configs_by_revision_ids.return_value = { - route.revision_id: ModelHealthCheck(path="/health", initial_delay=720.0), - } + call_args = mock_valkey_schedule.register_route_probe_targets_batch.call_args + targets: list[RouteProbeTarget] = call_args[0][0] + assert len(targets) == 1 + assert targets[0].replica_id == replica_id + assert targets[0].health_path == "/healthz" + assert targets[0].inference_port == 9000 + assert targets[0].replica_host == "10.0.0.2" - await route_executor._initialize_health_records( + async def test_registers_default_health_path_when_no_config( + self, + route_executor: RouteExecutor, + mock_valkey_schedule: AsyncMock, + ) -> None: + """health_path defaults to '/' when route.health_check is None.""" + route = _make_route(health_check=None) + replica_id = ReplicaID(route.route_id) + + await route_executor._register_route_probe_targets( [route], - {route.route_id: RouteSessionKernelInfo(replica_host="10.0.0.1", replica_port=8000)}, + {replica_id: RouteSessionKernelInfo(replica_host="10.0.0.1", replica_port=8000)}, ) - call_args = mock_valkey_schedule.initialize_route_health_records_batch.call_args - records: list[RouteHealthRecord] = call_args[0][0] - assert len(records) == 1 - assert records[0].running_at == 5000 - assert records[0].initial_delay_until == 5000 + 720 + call_args = mock_valkey_schedule.register_route_probe_targets_batch.call_args + targets: list[RouteProbeTarget] = call_args[0][0] + assert targets[0].health_path == "/" - async def test_running_at_none_falls_back_to_redis_time( + async def test_registers_multiple_routes( self, route_executor: RouteExecutor, - mock_deployment_repo: AsyncMock, mock_valkey_schedule: AsyncMock, ) -> None: - """ID-002: When running_at is None, fallback to current redis_time. + """Multiple routes produce multiple RouteProbeTarget entries.""" + routes = [_make_route() for _ in range(3)] + replica_infos = { + ReplicaID(r.route_id): RouteSessionKernelInfo( + replica_host="10.0.0.1", replica_port=8000 + i + ) + for i, r in enumerate(routes) + } - Given: No existing record in Valkey (running_at not set), redis_time=6000 - When: _initialize_health_records - Then: initial_delay_until = 6000 + 720 = 6720, running_at = 6000 - """ - route = _make_route(created_at_ts=1000) + await route_executor._register_route_probe_targets(routes, replica_infos) - mock_valkey_schedule.get_route_running_at_batch.return_value = {} - mock_valkey_schedule.get_redis_time.return_value = 6000 - mock_deployment_repo.fetch_health_check_configs_by_revision_ids.return_value = { - route.revision_id: ModelHealthCheck(path="/health", initial_delay=720.0), - } + call_args = mock_valkey_schedule.register_route_probe_targets_batch.call_args + targets: list[RouteProbeTarget] = call_args[0][0] + assert len(targets) == 3 - await route_executor._initialize_health_records( - [route], - {route.route_id: RouteSessionKernelInfo(replica_host="10.0.0.1", replica_port=8000)}, - ) - call_args = mock_valkey_schedule.initialize_route_health_records_batch.call_args - records: list[RouteHealthRecord] = call_args[0][0] - assert len(records) == 1 - assert records[0].running_at == 6000 - assert records[0].initial_delay_until == 6000 + 720 +class TestSyncRouteProbeTargets: + """Tests for sync_route_probe_targets.""" - async def test_created_at_expired_but_running_at_not_expired( + async def test_syncs_routes_with_replica_info( self, route_executor: RouteExecutor, - mock_deployment_repo: AsyncMock, mock_valkey_schedule: AsyncMock, ) -> None: - """ID-003: created_at-based delay would have expired, but running_at-based has not. + """Routes with replica_host and replica_port are synced.""" + route = _make_route() - Given: created_at=1000, running_at=5000, initial_delay=720, current_time=1800 - created_at + 720 = 1720 < 1800 (expired if created_at based) - running_at + 720 = 5720 > 1800 (NOT expired with running_at based) - When: _initialize_health_records - Then: initial_delay_until = 5720 (based on running_at, not created_at) - """ - route = _make_route(created_at_ts=1000) - route_id_str = str(route.route_id) + with RouteRecorderContext.scope("test", entity_ids=[route.route_id]): + result = await route_executor.sync_route_probe_targets([route]) - mock_valkey_schedule.get_route_running_at_batch.return_value = { - route_id_str: 5000, - } - mock_valkey_schedule.get_redis_time.return_value = 1800 - mock_deployment_repo.fetch_health_check_configs_by_revision_ids.return_value = { - route.revision_id: ModelHealthCheck(path="/health", initial_delay=720.0), - } + mock_valkey_schedule.register_route_probe_targets_batch.assert_awaited_once() + assert result.successes == [] + assert result.errors == [] - await route_executor._initialize_health_records( - [route], - {route.route_id: RouteSessionKernelInfo(replica_host="10.0.0.1", replica_port=8000)}, + async def test_skips_routes_without_replica_info( + self, + route_executor: RouteExecutor, + mock_valkey_schedule: AsyncMock, + ) -> None: + """Routes without replica_host/port are silently skipped.""" + route = _make_route() + route_no_info = RouteData( + route_id=ReplicaID(uuid4()), + deployment_id=DeploymentID(uuid4()), + session_id=SessionId(uuid4()), + status=RouteStatus.RUNNING, + health_status=RouteHealthStatus.NOT_CHECKED, + traffic_ratio=1.0, + created_at=datetime.fromtimestamp(1000, tz=tzutc()), + revision_id=DeploymentRevisionID(uuid4()), + traffic_status=RouteTrafficStatus.INACTIVE, + health_check=None, + replica_host=None, + replica_port=None, ) - call_args = mock_valkey_schedule.initialize_route_health_records_batch.call_args - records: list[RouteHealthRecord] = call_args[0][0] - assert records[0].initial_delay_until == 5720 - assert records[0].created_at == 1000 + with RouteRecorderContext.scope( + "test", entity_ids=[route.route_id, route_no_info.route_id] + ): + await route_executor.sync_route_probe_targets([route, route_no_info]) + + call_args = mock_valkey_schedule.register_route_probe_targets_batch.call_args + targets: list[RouteProbeTarget] = call_args[0][0] + assert len(targets) == 1 + assert targets[0].replica_id == ReplicaID(route.route_id) + + async def test_all_routes_missing_replica_info_returns_empty( + self, + route_executor: RouteExecutor, + mock_valkey_schedule: AsyncMock, + ) -> None: + """No routes with replica info → no Valkey call, empty result.""" + route_no_info = RouteData( + route_id=ReplicaID(uuid4()), + deployment_id=DeploymentID(uuid4()), + session_id=SessionId(uuid4()), + status=RouteStatus.RUNNING, + health_status=RouteHealthStatus.NOT_CHECKED, + traffic_ratio=1.0, + created_at=datetime.fromtimestamp(1000, tz=tzutc()), + revision_id=DeploymentRevisionID(uuid4()), + traffic_status=RouteTrafficStatus.INACTIVE, + health_check=None, + replica_host=None, + replica_port=None, + ) + + result = await route_executor.sync_route_probe_targets([route_no_info]) + + mock_valkey_schedule.register_route_probe_targets_batch.assert_not_awaited() + assert result.successes == [] + assert result.errors == [] # ============================================================================= diff --git a/tests/unit/manager/sokovan/deployment/route/executor/test_register_unregister_routes.py b/tests/unit/manager/sokovan/deployment/route/executor/test_register_unregister_routes.py index 123cb2c432e..a4df7545614 100644 --- a/tests/unit/manager/sokovan/deployment/route/executor/test_register_unregister_routes.py +++ b/tests/unit/manager/sokovan/deployment/route/executor/test_register_unregister_routes.py @@ -62,6 +62,7 @@ def _route_with_replica( traffic_ratio=1.0, revision_id=DeploymentRevisionID(uuid4()), traffic_status=RouteTrafficStatus.ACTIVE, + health_check=None, replica_host=replica_host, replica_port=replica_port, created_at=datetime.now(tzutc()), diff --git a/tests/unit/manager/sokovan/deployment/route/executor/test_route_executor.py b/tests/unit/manager/sokovan/deployment/route/executor/test_route_executor.py index 6b0f31bc3e1..dbfe9bb5b49 100644 --- a/tests/unit/manager/sokovan/deployment/route/executor/test_route_executor.py +++ b/tests/unit/manager/sokovan/deployment/route/executor/test_route_executor.py @@ -491,12 +491,15 @@ async def test_unhealthy_route_marked_for_cleanup( Then: Route in successes (marked for cleanup) """ # Arrange + current_revision_mock = MagicMock() + current_revision_mock.id = unhealthy_route.revision_id + deployment = MagicMock() deployment.id = unhealthy_route.deployment_id deployment.metadata = MagicMock() deployment.metadata.resource_group = "default" - deployment.current_revision_id = unhealthy_route.revision_id - deployment.deploying_revision_id = None + deployment.current_revision = current_revision_mock + deployment.deploying_revision = None mock_deployment_repo.get_deployments_by_ids.return_value = [deployment] mock_deployment_repo.get_scaling_group_cleanup_configs.return_value = { "default": cleanup_config_unhealthy_only @@ -524,12 +527,15 @@ async def test_healthy_route_not_marked_for_cleanup( Then: Route not in successes """ # Arrange + current_revision_mock = MagicMock() + current_revision_mock.id = healthy_route.revision_id + deployment = MagicMock() deployment.id = healthy_route.deployment_id deployment.metadata = MagicMock() deployment.metadata.resource_group = "default" - deployment.current_revision_id = healthy_route.revision_id - deployment.deploying_revision_id = None + deployment.current_revision = current_revision_mock + deployment.deploying_revision = None mock_deployment_repo.get_deployments_by_ids.return_value = [deployment] mock_deployment_repo.get_scaling_group_cleanup_configs.return_value = { "default": cleanup_config_unhealthy_only @@ -588,14 +594,20 @@ async def test_orphan_revision_route_marked_for_cleanup( created_at=datetime.now(tzutc()), revision_id=DeploymentRevisionID(uuid4()), # neither current nor deploying traffic_status=RouteTrafficStatus.ACTIVE, + health_check=None, ) + current_revision_mock = MagicMock() + current_revision_mock.id = current_revision_id + deploying_revision_mock = MagicMock() + deploying_revision_mock.id = deploying_revision_id + deployment = MagicMock() deployment.id = deployment_id deployment.metadata = MagicMock() deployment.metadata.resource_group = "default" - deployment.current_revision_id = current_revision_id - deployment.deploying_revision_id = deploying_revision_id + deployment.current_revision = current_revision_mock + deployment.deploying_revision = deploying_revision_mock mock_deployment_repo.get_deployments_by_ids.return_value = [deployment] mock_deployment_repo.get_scaling_group_cleanup_configs.return_value = { "default": cleanup_config_unhealthy_only @@ -632,14 +644,18 @@ async def test_provisioning_route_for_deploying_revision_kept( created_at=datetime.now(tzutc()), revision_id=deploying_revision_id, traffic_status=RouteTrafficStatus.INACTIVE, + health_check=None, ) + deploying_revision_mock = MagicMock() + deploying_revision_mock.id = deploying_revision_id + deployment = MagicMock() deployment.id = deployment_id deployment.metadata = MagicMock() deployment.metadata.resource_group = "default" - deployment.current_revision_id = None - deployment.deploying_revision_id = deploying_revision_id + deployment.current_revision = None + deployment.deploying_revision = deploying_revision_mock mock_deployment_repo.get_deployments_by_ids.return_value = [deployment] mock_deployment_repo.get_scaling_group_cleanup_configs.return_value = { "default": cleanup_config_unhealthy_only @@ -676,14 +692,15 @@ async def test_orphan_check_skipped_when_no_known_revisions( created_at=datetime.now(tzutc()), revision_id=DeploymentRevisionID(uuid4()), traffic_status=RouteTrafficStatus.ACTIVE, + health_check=None, ) deployment = MagicMock() deployment.id = deployment_id deployment.metadata = MagicMock() deployment.metadata.resource_group = "default" - deployment.current_revision_id = None - deployment.deploying_revision_id = None + deployment.current_revision = None + deployment.deploying_revision = None mock_deployment_repo.get_deployments_by_ids.return_value = [deployment] mock_deployment_repo.get_scaling_group_cleanup_configs.return_value = { "default": cleanup_config_unhealthy_only @@ -930,6 +947,7 @@ def _route_for_endpoint(endpoint_id: DeploymentID) -> RouteData: revision_id=DeploymentRevisionID(uuid4()), created_at=datetime.now(tzutc()), traffic_status=RouteTrafficStatus.ACTIVE, + health_check=None, ) diff --git a/tests/unit/manager/sokovan/deployment/route/executor/test_terminate_routes_drain.py b/tests/unit/manager/sokovan/deployment/route/executor/test_terminate_routes_drain.py index 100ca19f379..fa9ab8638f9 100644 --- a/tests/unit/manager/sokovan/deployment/route/executor/test_terminate_routes_drain.py +++ b/tests/unit/manager/sokovan/deployment/route/executor/test_terminate_routes_drain.py @@ -44,6 +44,7 @@ def _terminating_route() -> RouteData: traffic_ratio=1.0, revision_id=DeploymentRevisionID(uuid4()), traffic_status=RouteTrafficStatus.INACTIVE, + health_check=None, replica_host="10.0.0.1", replica_port=8000, created_at=datetime.now(tzutc()), @@ -162,6 +163,7 @@ async def test_routes_without_session_still_unregister( traffic_ratio=1.0, revision_id=DeploymentRevisionID(uuid4()), traffic_status=RouteTrafficStatus.INACTIVE, + health_check=None, replica_host=None, replica_port=None, created_at=datetime.now(tzutc()), diff --git a/tests/unit/manager/sokovan/deployment/route/handlers/test_health_check_handler.py b/tests/unit/manager/sokovan/deployment/route/handlers/test_health_check_handler.py index 0602c89e260..2f85556fce7 100644 --- a/tests/unit/manager/sokovan/deployment/route/handlers/test_health_check_handler.py +++ b/tests/unit/manager/sokovan/deployment/route/handlers/test_health_check_handler.py @@ -41,6 +41,7 @@ def _route(health_status: RouteHealthStatus) -> RouteData: traffic_ratio=1.0, revision_id=DeploymentRevisionID(uuid4()), traffic_status=RouteTrafficStatus.ACTIVE, + health_check=None, replica_host="10.0.0.1", replica_port=8000, created_at=datetime.now(tzutc()), diff --git a/tests/unit/manager/sokovan/deployment/route/handlers/test_terminating_handler.py b/tests/unit/manager/sokovan/deployment/route/handlers/test_terminating_handler.py index 9ac3e535b82..2a9255a232b 100644 --- a/tests/unit/manager/sokovan/deployment/route/handlers/test_terminating_handler.py +++ b/tests/unit/manager/sokovan/deployment/route/handlers/test_terminating_handler.py @@ -40,6 +40,7 @@ def _terminating_route() -> RouteData: traffic_ratio=1.0, revision_id=DeploymentRevisionID(uuid4()), traffic_status=RouteTrafficStatus.INACTIVE, + health_check=None, replica_host="10.0.0.1", replica_port=8000, created_at=datetime.now(tzutc()), diff --git a/tests/unit/manager/sokovan/deployment/route/test_coordinator_history.py b/tests/unit/manager/sokovan/deployment/route/test_coordinator_history.py index 59a05cceae1..aad4bff89a9 100644 --- a/tests/unit/manager/sokovan/deployment/route/test_coordinator_history.py +++ b/tests/unit/manager/sokovan/deployment/route/test_coordinator_history.py @@ -54,6 +54,7 @@ def sample_route_data() -> RouteData: created_at=datetime.now(tzutc()), revision_id=DeploymentRevisionID(uuid4()), traffic_status=RouteTrafficStatus.INACTIVE, + health_check=None, ) @@ -399,6 +400,7 @@ async def test_records_history_on_stale( created_at=datetime.now(tzutc()), revision_id=DeploymentRevisionID(uuid4()), traffic_status=RouteTrafficStatus.ACTIVE, + health_check=None, ) ] ) From 25ca363442fdf6133f1fef50edd0c0ab39330b2e Mon Sep 17 00:00:00 2001 From: HyeockJinKim Date: Fri, 15 May 2026 12:43:21 +0900 Subject: [PATCH 04/13] fix(BA-6035): backfill health_check on routings from revision model_definition Existing routes have health_check=NULL after the column was added. Backfill by reading health_check config from deployment_revisions.model_definition and applying ModelHealthCheck defaults for optional fields. Only populates when path (required field) exists in the draft config. Co-Authored-By: Claude Sonnet 4.6 --- ...d4e5f6a7b8_add_health_check_to_routings.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/ai/backend/manager/models/alembic/versions/c3d4e5f6a7b8_add_health_check_to_routings.py b/src/ai/backend/manager/models/alembic/versions/c3d4e5f6a7b8_add_health_check_to_routings.py index ffe75d4deb4..5a975cb9d6b 100644 --- a/src/ai/backend/manager/models/alembic/versions/c3d4e5f6a7b8_add_health_check_to_routings.py +++ b/src/ai/backend/manager/models/alembic/versions/c3d4e5f6a7b8_add_health_check_to_routings.py @@ -20,6 +20,26 @@ def upgrade() -> None: conn = op.get_bind() conn.exec_driver_sql("ALTER TABLE routings ADD COLUMN IF NOT EXISTS health_check JSONB") + # Backfill health_check from revision's model_definition for existing routes. + # Applies ModelHealthCheck defaults for optional fields (interval, max_retries, etc.). + # Only populates when health_check.path exists (required field in ModelHealthCheck). + conn.exec_driver_sql(""" + UPDATE routings r + SET health_check = jsonb_build_object( + 'path', dr.model_definition->'health_check'->>'path', + 'interval', COALESCE((dr.model_definition->'health_check'->>'interval')::float, 10.0), + 'max_retries', COALESCE((dr.model_definition->'health_check'->>'max_retries')::int, 10), + 'max_wait_time', COALESCE((dr.model_definition->'health_check'->>'max_wait_time')::float, 15.0), + 'expected_status_code', COALESCE((dr.model_definition->'health_check'->>'expected_status_code')::int, 200), + 'initial_delay', COALESCE((dr.model_definition->'health_check'->>'initial_delay')::float, 60.0) + ) + FROM deployment_revisions dr + WHERE dr.id = r.revision + AND r.health_check IS NULL + AND dr.model_definition IS NOT NULL + AND dr.model_definition->'health_check' IS NOT NULL + AND dr.model_definition->'health_check'->>'path' IS NOT NULL + """) def downgrade() -> None: From d533350738a32b0f9b9cd5214db2cdef00c5a6f0 Mon Sep 17 00:00:00 2001 From: HyeockJinKim Date: Fri, 15 May 2026 12:58:17 +0900 Subject: [PATCH 05/13] fix(BA-6035): replace health records with TTL-based statuses in check_route_health MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - check_warming_up_health: use get_route_health_statuses_batch + DB-based timeout via _deployment_repo.get_db_now() instead of Redis time + revision fetch - check_route_health: TTL 3-way classification — None=DEGRADED, healthy=success, unhealthy=error (removes within_initial_delay, is_stale, get_redis_time calls) - Update TestCheckRouteHealth to use ValkeyRouteHealthStatus fixtures - Rewrite test_check_route_health_register.py to use get_route_health_statuses_batch Co-Authored-By: Claude Sonnet 4.6 --- .../sokovan/deployment/route/executor.py | 113 +++++++----------- .../test_check_route_health_register.py | 62 +++++----- .../route/executor/test_route_executor.py | 90 ++++---------- 3 files changed, 92 insertions(+), 173 deletions(-) diff --git a/src/ai/backend/manager/sokovan/deployment/route/executor.py b/src/ai/backend/manager/sokovan/deployment/route/executor.py index 53e032e144e..b34c4fb61ec 100644 --- a/src/ai/backend/manager/sokovan/deployment/route/executor.py +++ b/src/ai/backend/manager/sokovan/deployment/route/executor.py @@ -406,43 +406,40 @@ async def check_warming_up_health(self, routes: Sequence[RouteData]) -> RouteExe """Check health of PROVISIONING+WARMING_UP routes for initial activation. - success: health probe passed, or no health check configured → RUNNING+ACTIVE - - failure: initial_delay exceeded without a passing probe → TERMINATING - - (no transition): still within initial_delay → route stays WARMING_UP + - failure: last_transition_at + initial_delay exceeded without a passing probe → TERMINATING + - (no transition): still within initial_delay, or last_transition_at unknown → stay WARMING_UP """ - route_id_strs = [str(r.route_id) for r in routes] - revision_ids = {r.revision_id for r in routes} - - records = await self._valkey_schedule.get_route_health_records_batch(route_id_strs) - health_configs = await self._deployment_repo.fetch_health_check_configs_by_revision_ids( - revision_ids - ) - current_time = await self._valkey_schedule.get_redis_time() + statuses = await self._valkey_schedule.get_route_health_statuses_batch([ + route.route_id for route in routes + ]) + now = await self._deployment_repo.get_db_now() successes: list[RouteData] = [] errors: list[RouteExecutionError] = [] for route in routes: - health_config = health_configs.get(route.revision_id) - if health_config is None: + if route.health_check is None: successes.append(route) continue - route_id_str = str(route.route_id) - record = records.get(route_id_str) - - if record is None: + status = statuses.get(route.route_id) + if status is not None and status.healthy: + successes.append(route) continue - if record.last_check > 0 and not record.is_stale(current_time) and record.healthy: - successes.append(route) + if route.last_transition_at is None: continue - if current_time > record.initial_delay_until: + elapsed = (now - route.last_transition_at).total_seconds() + if elapsed > route.health_check.initial_delay: errors.append( RouteExecutionError( route_info=route, reason="Route warming-up timed out waiting for healthy probe", - error_detail=f"Elapsed {current_time - record.initial_delay_until}s after initial_delay", + error_detail=( + f"Elapsed {elapsed:.0f}s exceeds " + f"initial_delay {route.health_check.initial_delay}s" + ), ) ) @@ -450,25 +447,16 @@ async def check_warming_up_health(self, routes: Sequence[RouteData]) -> RouteExe async def check_route_health(self, routes: Sequence[RouteData]) -> RouteExecutionResult: """ - Check health status of routes and push newly-healthy ones to AppProxy. + Check health status of RUNNING routes and push newly-healthy ones to AppProxy. - Reads RouteHealthRecord and classifies based on computed healthy/stale: - - HEALTHY: record.healthy is True - - UNHEALTHY: record.healthy is False and not stale (only past initial_delay) - - DEGRADED: record is stale or missing (only past initial_delay) - - (no transition): within initial_delay without a successful probe yet — - the route keeps whatever health_status it already has so a transient - warmup failure does not downgrade it prematurely. + Reads RouteHealthStatus from Valkey and classifies: + - HEALTHY: status.healthy is True + - UNHEALTHY: status.healthy is False + - DEGRADED: status absent (key missing or TTL expired — no recent check) Routes whose pre-execute health_status was not HEALTHY but whose - probe just passed are pushed to AppProxy synchronously via - :meth:`register_routes_now` so traffic can flow without waiting - for the long-cycle ``AppProxySyncRouteHandler`` fallback. Push - failures are logged and swallowed — the fallback converges - state, and a stuck health-check tick would block all later - observations. The handler keeps ``post_process`` to a thin - logging shim because all the work that changes external state - belongs here in the executor. + probe just passed are pushed to AppProxy synchronously so traffic + can flow without waiting for the long-cycle fallback. Args: routes: Routes to check health for @@ -476,12 +464,12 @@ async def check_route_health(self, routes: Sequence[RouteData]) -> RouteExecutio Returns: Result with successes (healthy), errors (unhealthy), stale (degraded) """ - # Phase 1: Load RouteHealthRecords + # Phase 1: Load RouteHealthStatuses with RouteRecorderContext.shared_phase("load_health_status"): with RouteRecorderContext.shared_step("query_health_check_results"): - route_ids = [str(route.route_id) for route in routes] - records = await self._valkey_schedule.get_route_health_records_batch(route_ids) - current_time = await self._valkey_schedule.get_redis_time() + statuses = await self._valkey_schedule.get_route_health_statuses_batch([ + route.route_id for route in routes + ]) successes: list[RouteData] = [] errors: list[RouteExecutionError] = [] @@ -489,45 +477,24 @@ async def check_route_health(self, routes: Sequence[RouteData]) -> RouteExecutio # Phase 2: Classify health state (per-route) for route in routes: - route_id_str = str(route.route_id) - record = records.get(route_id_str) + status = statuses.get(route.route_id) - if record is None: - # No RouteHealthRecord — not yet initialized + if status is None: + # Key absent or TTL expired → DEGRADED stale.append(route) continue - within_initial_delay = current_time < record.initial_delay_until - - # Success path always wins — a healthy probe transitions the route - # to HEALTHY even if initial_delay has not elapsed, so the kernel - # can start receiving traffic as soon as it is ready. - if record.last_check > 0 and not record.is_stale(current_time) and record.healthy: + if status.healthy: successes.append(route) - continue - - # Within initial_delay and no successful probe yet — hold the - # existing health_status (NOT_CHECKED/HEALTHY/UNHEALTHY/DEGRADED) - # by skipping classification entirely. - if within_initial_delay: - continue - - if record.last_check == 0: - stale.append(route) - continue - - if record.is_stale(current_time): - stale.append(route) - continue - - errors.append( - RouteExecutionError( - route_info=route, - reason="Route health check failed", - error_detail="RouteHealthRecord reports unhealthy", - error_code=None, + else: + errors.append( + RouteExecutionError( + route_info=route, + reason="Route health check failed", + error_detail="RouteHealthStatus reports unhealthy", + error_code=None, + ) ) - ) # Phase 3: Push newly-healthy routes to AppProxy. # ``successes`` carries the pre-execute RouteData snapshot; routes diff --git a/tests/unit/manager/sokovan/deployment/route/executor/test_check_route_health_register.py b/tests/unit/manager/sokovan/deployment/route/executor/test_check_route_health_register.py index 08fe21d21a7..d5f4041c0d0 100644 --- a/tests/unit/manager/sokovan/deployment/route/executor/test_check_route_health_register.py +++ b/tests/unit/manager/sokovan/deployment/route/executor/test_check_route_health_register.py @@ -11,11 +11,14 @@ from __future__ import annotations from datetime import datetime -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch from uuid import uuid4 from dateutil.tz import tzutc +from ai.backend.common.clients.valkey_client.valkey_schedule import ( + RouteHealthStatus as ValkeyRouteHealthStatus, +) from ai.backend.common.identifier.deployment import DeploymentID from ai.backend.common.identifier.deployment_revision import DeploymentRevisionID from ai.backend.common.identifier.replica import ReplicaID @@ -48,14 +51,20 @@ def _route(health_status: RouteHealthStatus) -> RouteData: ) -def _healthy_record(redis_now: int) -> MagicMock: - """RouteHealthRecord stub that signals 'just probed and healthy'.""" - record = MagicMock() - record.healthy = True - record.last_check = redis_now - 1 - record.initial_delay_until = redis_now - 10 - record.is_stale = MagicMock(return_value=False) - return record +def _healthy_status(route: RouteData) -> ValkeyRouteHealthStatus: + return ValkeyRouteHealthStatus( + replica_id=route.route_id, + healthy=True, + last_check=999, + ) + + +def _unhealthy_status(route: RouteData) -> ValkeyRouteHealthStatus: + return ValkeyRouteHealthStatus( + replica_id=route.route_id, + healthy=False, + last_check=999, + ) class TestCheckRouteHealthRegister: @@ -68,10 +77,8 @@ async def test_first_time_healthy_triggers_register( ) -> None: """RR-EXEC-HC-001: NOT_CHECKED → HEALTHY route is pushed to AppProxy.""" not_yet_healthy = _route(RouteHealthStatus.NOT_CHECKED) - redis_now = 1_000_000 - mock_valkey_schedule.get_redis_time = AsyncMock(return_value=redis_now) - mock_valkey_schedule.get_route_health_records_batch = AsyncMock( - return_value={str(not_yet_healthy.route_id): _healthy_record(redis_now)} + mock_valkey_schedule.get_route_health_statuses_batch = AsyncMock( + return_value={not_yet_healthy.route_id: _healthy_status(not_yet_healthy)} ) with patch.object( @@ -96,10 +103,8 @@ async def test_already_healthy_route_is_skipped( ) -> None: """RR-EXEC-HC-002: already-HEALTHY routes do not trigger fresh register.""" already_healthy = _route(RouteHealthStatus.HEALTHY) - redis_now = 1_000_000 - mock_valkey_schedule.get_redis_time = AsyncMock(return_value=redis_now) - mock_valkey_schedule.get_route_health_records_batch = AsyncMock( - return_value={str(already_healthy.route_id): _healthy_record(redis_now)} + mock_valkey_schedule.get_route_health_statuses_batch = AsyncMock( + return_value={already_healthy.route_id: _healthy_status(already_healthy)} ) with patch.object(route_executor, "register_routes_now") as mock_register: @@ -116,10 +121,8 @@ async def test_unhealthy_to_healthy_triggers_register( ) -> None: """RR-EXEC-HC-003: UNHEALTHY → HEALTHY recovery is treated as fresh transition.""" recovered = _route(RouteHealthStatus.UNHEALTHY) - redis_now = 1_000_000 - mock_valkey_schedule.get_redis_time = AsyncMock(return_value=redis_now) - mock_valkey_schedule.get_route_health_records_batch = AsyncMock( - return_value={str(recovered.route_id): _healthy_record(redis_now)} + mock_valkey_schedule.get_route_health_statuses_batch = AsyncMock( + return_value={recovered.route_id: _healthy_status(recovered)} ) with patch.object( @@ -139,15 +142,8 @@ async def test_no_successes_skips_register( ) -> None: """RR-EXEC-HC-004: routes whose probe failed do not trigger register.""" unhealthy_route = _route(RouteHealthStatus.UNHEALTHY) - redis_now = 1_000_000 - record = MagicMock() - record.healthy = False - record.last_check = redis_now - 1 - record.initial_delay_until = redis_now - 10 - record.is_stale = MagicMock(return_value=False) - mock_valkey_schedule.get_redis_time = AsyncMock(return_value=redis_now) - mock_valkey_schedule.get_route_health_records_batch = AsyncMock( - return_value={str(unhealthy_route.route_id): record} + mock_valkey_schedule.get_route_health_statuses_batch = AsyncMock( + return_value={unhealthy_route.route_id: _unhealthy_status(unhealthy_route)} ) with patch.object(route_executor, "register_routes_now") as mock_register: @@ -170,10 +166,8 @@ async def test_register_failure_is_swallowed( cycle would block all later observations. """ not_yet_healthy = _route(RouteHealthStatus.NOT_CHECKED) - redis_now = 1_000_000 - mock_valkey_schedule.get_redis_time = AsyncMock(return_value=redis_now) - mock_valkey_schedule.get_route_health_records_batch = AsyncMock( - return_value={str(not_yet_healthy.route_id): _healthy_record(redis_now)} + mock_valkey_schedule.get_route_health_statuses_batch = AsyncMock( + return_value={not_yet_healthy.route_id: _healthy_status(not_yet_healthy)} ) with patch.object( diff --git a/tests/unit/manager/sokovan/deployment/route/executor/test_route_executor.py b/tests/unit/manager/sokovan/deployment/route/executor/test_route_executor.py index dbfe9bb5b49..22b8bf219bf 100644 --- a/tests/unit/manager/sokovan/deployment/route/executor/test_route_executor.py +++ b/tests/unit/manager/sokovan/deployment/route/executor/test_route_executor.py @@ -22,7 +22,9 @@ import pytest from dateutil.tz import tzutc -from ai.backend.common.clients.valkey_client.valkey_schedule import RouteHealthRecord +from ai.backend.common.clients.valkey_client.valkey_schedule import ( + RouteHealthStatus as ValkeyRouteHealthStatus, +) from ai.backend.common.dto.appproxy_coordinator.v2.endpoint.response import ( BulkUpdateRoutesResponse, ) @@ -196,7 +198,7 @@ class TestCheckRouteHealth: """Tests for check_route_health functionality. Verifies the executor correctly checks route health via Valkey - using RouteHealthRecord-based classification. + using RouteHealthStatus TTL-based classification. """ async def test_healthy_route_in_successes( @@ -207,34 +209,23 @@ async def test_healthy_route_in_successes( ) -> None: """RH-001: Healthy route is in successes. - Given: Route with healthy RouteHealthRecord in Valkey + Given: Route with RouteHealthStatus(healthy=True) in Valkey When: Check route health Then: Route in successes list """ - # Arrange - route_id_str = str(healthy_route.route_id) - current_time = 1000 - record = RouteHealthRecord( - route_id=route_id_str, - created_at=900, - initial_delay_until=930, - health_path="/health", - inference_port=8000, - replica_host="10.0.0.1", - agent_healthy=True, - agent_last_check=current_time - 5, + status = ValkeyRouteHealthStatus( + replica_id=healthy_route.route_id, + healthy=True, + last_check=995, ) - mock_valkey_schedule.get_route_health_records_batch.return_value = { - route_id_str: record, + mock_valkey_schedule.get_route_health_statuses_batch.return_value = { + healthy_route.route_id: status, } - mock_valkey_schedule.get_redis_time.return_value = current_time entity_ids = [healthy_route.route_id] with RouteRecorderContext.scope("test", entity_ids=entity_ids): - # Act result = await route_executor.check_route_health([healthy_route]) - # Assert assert len(result.successes) == 1 assert len(result.errors) == 0 assert len(result.stale) == 0 @@ -247,34 +238,23 @@ async def test_unhealthy_route_in_errors( ) -> None: """RH-002: Unhealthy route is in errors. - Given: Route with unhealthy RouteHealthRecord in Valkey + Given: Route with RouteHealthStatus(healthy=False) in Valkey When: Check route health Then: Route in errors list """ - # Arrange - route_id_str = str(healthy_route.route_id) - current_time = 1000 - record = RouteHealthRecord( - route_id=route_id_str, - created_at=900, - initial_delay_until=930, - health_path="/health", - inference_port=8000, - replica_host="10.0.0.1", - agent_healthy=False, - agent_last_check=current_time - 5, + status = ValkeyRouteHealthStatus( + replica_id=healthy_route.route_id, + healthy=False, + last_check=995, ) - mock_valkey_schedule.get_route_health_records_batch.return_value = { - route_id_str: record, + mock_valkey_schedule.get_route_health_statuses_batch.return_value = { + healthy_route.route_id: status, } - mock_valkey_schedule.get_redis_time.return_value = current_time entity_ids = [healthy_route.route_id] with RouteRecorderContext.scope("test", entity_ids=entity_ids): - # Act result = await route_executor.check_route_health([healthy_route]) - # Assert assert len(result.successes) == 0 assert len(result.errors) == 1 assert len(result.stale) == 0 @@ -285,36 +265,18 @@ async def test_stale_route_in_stale_list( mock_valkey_schedule: AsyncMock, healthy_route: RouteData, ) -> None: - """RH-003: Stale route is in stale list. + """RH-003: Route with expired TTL (no key) is in stale list. - Given: Route with stale RouteHealthRecord in Valkey (old last_check) + Given: Route whose RouteHealthStatus TTL has expired (key absent) When: Check route health - Then: Route in stale list + Then: Route in stale list (DEGRADED) """ - # Arrange - route_id_str = str(healthy_route.route_id) - current_time = 1000 - record = RouteHealthRecord( - route_id=route_id_str, - created_at=100, - initial_delay_until=130, - health_path="/health", - inference_port=8000, - replica_host="10.0.0.1", - agent_healthy=True, - agent_last_check=100, # Very old check, stale - ) - mock_valkey_schedule.get_route_health_records_batch.return_value = { - route_id_str: record, - } - mock_valkey_schedule.get_redis_time.return_value = current_time + mock_valkey_schedule.get_route_health_statuses_batch.return_value = {} entity_ids = [healthy_route.route_id] with RouteRecorderContext.scope("test", entity_ids=entity_ids): - # Act result = await route_executor.check_route_health([healthy_route]) - # Assert assert len(result.successes) == 0 assert len(result.errors) == 0 assert len(result.stale) == 1 @@ -327,20 +289,16 @@ async def test_missing_health_data_treated_as_stale( ) -> None: """RH-004: Missing health data is treated as stale. - Given: Route with no RouteHealthRecord in Valkey + Given: Route with no RouteHealthStatus in Valkey When: Check route health Then: Route in stale list """ - # Arrange - Empty records response - mock_valkey_schedule.get_route_health_records_batch.return_value = {} - mock_valkey_schedule.get_redis_time.return_value = 1000 + mock_valkey_schedule.get_route_health_statuses_batch.return_value = {} entity_ids = [healthy_route.route_id] with RouteRecorderContext.scope("test", entity_ids=entity_ids): - # Act result = await route_executor.check_route_health([healthy_route]) - # Assert assert len(result.successes) == 0 assert len(result.errors) == 0 assert len(result.stale) == 1 From 79283c9a8dd6bb6fad7e3a10b5284b5fca3926ae Mon Sep 17 00:00:00 2001 From: HyeockJinKim Date: Fri, 15 May 2026 15:06:00 +0900 Subject: [PATCH 06/13] fix(BA-6035): update RouteHealthObserver to use ReplicaProbeTarget + batch health recording MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add ReplicaHealthResult dataclass (write input; last_check stamped by client) - Add record_route_health_statuses_batch: single Redis time fetch + pipeline - Rename Valkey-layer types Route* → Replica* (RouteProbeTarget/RouteHealthStatus/RouteHealthResult → Replica*) - Observer: get_route_probe_targets_batch + record_route_health_statuses_batch; within_initial_delay guard removed, TTL handles DEGRADED - Extract _build_probe_target() static method; skip routes without health_check Co-Authored-By: Claude Sonnet 4.6 --- .../valkey_client/valkey_schedule/__init__.py | 7 +- .../valkey_client/valkey_schedule/client.py | 64 +++- .../valkey_client/valkey_schedule/types.py | 20 +- .../sokovan/deployment/route/executor.py | 62 ++-- .../route/handlers/observer/health_check.py | 79 ++--- .../test_valkey_schedule_client.py | 16 +- .../test_valkey_schedule_types.py | 38 +-- .../deployment/route/executor/conftest.py | 3 +- .../test_check_route_health_register.py | 10 +- .../route/executor/test_initial_delay.py | 277 +++++------------- .../route/executor/test_route_executor.py | 6 +- 11 files changed, 234 insertions(+), 348 deletions(-) diff --git a/src/ai/backend/common/clients/valkey_client/valkey_schedule/__init__.py b/src/ai/backend/common/clients/valkey_client/valkey_schedule/__init__.py index 76ad8931cde..74d33c8d464 100644 --- a/src/ai/backend/common/clients/valkey_client/valkey_schedule/__init__.py +++ b/src/ai/backend/common/clients/valkey_client/valkey_schedule/__init__.py @@ -5,14 +5,15 @@ RouteHealthRecord, ValkeyScheduleClient, ) -from .types import RouteHealthStatus, RouteProbeTarget +from .types import ReplicaHealthResult, ReplicaHealthStatus, ReplicaProbeTarget __all__ = [ "HealthCheckStatus", "HealthStatus", "KernelStatus", "RouteHealthRecord", - "RouteHealthStatus", - "RouteProbeTarget", + "ReplicaHealthResult", + "ReplicaHealthStatus", + "ReplicaProbeTarget", "ValkeyScheduleClient", ] diff --git a/src/ai/backend/common/clients/valkey_client/valkey_schedule/client.py b/src/ai/backend/common/clients/valkey_client/valkey_schedule/client.py index a21dc742325..d02000e1094 100644 --- a/src/ai/backend/common/clients/valkey_client/valkey_schedule/client.py +++ b/src/ai/backend/common/clients/valkey_client/valkey_schedule/client.py @@ -13,8 +13,9 @@ create_valkey_client, ) from ai.backend.common.clients.valkey_client.valkey_schedule.types import ( - RouteHealthStatus, - RouteProbeTarget, + ReplicaHealthResult, + ReplicaHealthStatus, + ReplicaProbeTarget, ) from ai.backend.common.exception import BackendAIError from ai.backend.common.identifier.replica import ReplicaID @@ -942,15 +943,17 @@ async def get_route_health_records_batch( return records - # ==================== RouteProbeTarget / RouteHealthStatus Methods ==================== + # ==================== ReplicaProbeTarget / ReplicaHealthStatus Methods ==================== @valkey_schedule_resilience.apply() - async def register_route_probe_targets_batch(self, targets: Sequence[RouteProbeTarget]) -> None: + async def register_route_probe_targets_batch( + self, targets: Sequence[ReplicaProbeTarget] + ) -> None: """ - Batch register RouteProbeTarget entries in Valkey. + Batch register ReplicaProbeTarget entries in Valkey. Called by coordinator when route enters WARMING_UP and replica host/port are known. - :param targets: RouteProbeTarget instances to store + :param targets: ReplicaProbeTarget instances to store """ if not targets: return @@ -967,12 +970,12 @@ async def register_route_probe_targets_batch(self, targets: Sequence[RouteProbeT @valkey_schedule_resilience.apply() async def get_route_probe_targets_batch( self, replica_ids: Sequence[ReplicaID] - ) -> Mapping[ReplicaID, RouteProbeTarget | None]: + ) -> Mapping[ReplicaID, ReplicaProbeTarget | None]: """ - Batch get RouteProbeTargets from Valkey. + Batch get ReplicaProbeTargets from Valkey. :param replica_ids: Replica IDs to look up - :return: Mapping of replica_id to RouteProbeTarget (None if missing or expired) + :return: Mapping of replica_id to ReplicaProbeTarget (None if missing or expired) """ if not replica_ids: return {} @@ -986,7 +989,7 @@ async def get_route_probe_targets_batch( if results is None: return dict.fromkeys(replica_ids) - targets: dict[ReplicaID, RouteProbeTarget | None] = {} + targets: dict[ReplicaID, ReplicaProbeTarget | None] = {} for i, replica_id in enumerate(replica_ids): hgetall_result = results[i] if len(results) > i else None if not hgetall_result: @@ -997,7 +1000,7 @@ async def get_route_probe_targets_batch( targets[replica_id] = None continue data = {k.decode(): v.decode() for k, v in raw.items()} - targets[replica_id] = RouteProbeTarget.from_valkey_hash(data) + targets[replica_id] = ReplicaProbeTarget.from_valkey_hash(data) return targets @@ -1025,16 +1028,45 @@ async def record_route_health_status(self, replica_id: ReplicaID, healthy: bool) async with self._client.client() as conn: await conn.exec(batch, raise_on_error=True) + @valkey_schedule_resilience.apply() + async def record_route_health_statuses_batch( + self, results: Sequence[ReplicaHealthResult] + ) -> None: + """ + Batch record health check results for multiple routes. + Fetches Redis time once and writes all statuses in a single pipeline. + Refreshes TTL on every call; key expiry signals DEGRADED. + + :param results: Sequence of ReplicaHealthResult instances + """ + if not results: + return + + current_time = str(await self._get_redis_time()) + batch = Batch(is_atomic=False) + for result in results: + key = self._get_route_health_status_key(result.replica_id) + data: Mapping[str | bytes, str | bytes] = { + "replica_id": str(result.replica_id), + "healthy": "1" if result.healthy else "0", + "last_check": current_time, + } + batch.hset(key, data) + batch.expire(key, ROUTE_HEALTH_STATUS_TTL_SEC) + + async with self._client.client() as conn: + await conn.exec(batch, raise_on_error=True) + @valkey_schedule_resilience.apply() async def get_route_health_statuses_batch( self, replica_ids: Sequence[ReplicaID] - ) -> Mapping[ReplicaID, RouteHealthStatus | None]: + ) -> Mapping[ReplicaID, ReplicaHealthStatus | None]: """ - Batch get RouteHealthStatus from Valkey. + Batch get ReplicaHealthStatus from Valkey. None means no recent health check (key missing or TTL expired) → DEGRADED. :param replica_ids: Replica IDs to look up - :return: Mapping of replica_id to RouteHealthStatus (None if missing or expired) + :return: Mapping of replica_id to ReplicaHealthStatus (None if missing or expired) """ if not replica_ids: return {} @@ -1048,7 +1080,7 @@ async def get_route_health_statuses_batch( if results is None: return dict.fromkeys(replica_ids) - statuses: dict[ReplicaID, RouteHealthStatus | None] = {} + statuses: dict[ReplicaID, ReplicaHealthStatus | None] = {} for i, replica_id in enumerate(replica_ids): hgetall_result = results[i] if len(results) > i else None if not hgetall_result: @@ -1059,7 +1091,7 @@ async def get_route_health_statuses_batch( statuses[replica_id] = None continue data = {k.decode(): v.decode() for k, v in raw.items()} - statuses[replica_id] = RouteHealthStatus.from_valkey_hash(data) + statuses[replica_id] = ReplicaHealthStatus.from_valkey_hash(data) return statuses diff --git a/src/ai/backend/common/clients/valkey_client/valkey_schedule/types.py b/src/ai/backend/common/clients/valkey_client/valkey_schedule/types.py index 268bf3eabe9..5aae6e27867 100644 --- a/src/ai/backend/common/clients/valkey_client/valkey_schedule/types.py +++ b/src/ai/backend/common/clients/valkey_client/valkey_schedule/types.py @@ -10,7 +10,7 @@ @dataclass -class RouteProbeTarget: +class ReplicaProbeTarget: """Probe configuration for a route stored in Valkey. Stored as a hash at key `route_probe:{replica_id}`. @@ -32,7 +32,7 @@ def to_valkey_hash(self) -> Mapping[str, str]: } @classmethod - def from_valkey_hash(cls, data: Mapping[str, str]) -> RouteProbeTarget: + def from_valkey_hash(cls, data: Mapping[str, str]) -> ReplicaProbeTarget: return cls( replica_id=ReplicaID(UUID(data["replica_id"])), health_path=data["health_path"], @@ -42,7 +42,19 @@ def from_valkey_hash(cls, data: Mapping[str, str]) -> RouteProbeTarget: @dataclass -class RouteHealthStatus: +class ReplicaHealthResult: + """Input type for recording a health check outcome. + + Passed to ``record_route_health_statuses_batch``; ``last_check`` is + assigned by the client using the current Redis time. + """ + + replica_id: ReplicaID + healthy: bool + + +@dataclass +class ReplicaHealthStatus: """Health check result for a route stored in Valkey. Stored as a hash at key `route_health:{replica_id}`. @@ -62,7 +74,7 @@ def to_valkey_hash(self) -> Mapping[str, str]: } @classmethod - def from_valkey_hash(cls, data: Mapping[str, str]) -> RouteHealthStatus: + def from_valkey_hash(cls, data: Mapping[str, str]) -> ReplicaHealthStatus: return cls( replica_id=ReplicaID(UUID(data["replica_id"])), healthy=data.get("healthy", "0") == "1", diff --git a/src/ai/backend/manager/sokovan/deployment/route/executor.py b/src/ai/backend/manager/sokovan/deployment/route/executor.py index b34c4fb61ec..16a1c2e357f 100644 --- a/src/ai/backend/manager/sokovan/deployment/route/executor.py +++ b/src/ai/backend/manager/sokovan/deployment/route/executor.py @@ -7,7 +7,7 @@ from ai.backend.common.clients.http_client.client_pool import ClientPool from ai.backend.common.clients.valkey_client.valkey_schedule import ( - RouteProbeTarget, + ReplicaProbeTarget, ValkeyScheduleClient, ) from ai.backend.common.dto.appproxy_coordinator.v2.endpoint.request import ( @@ -341,55 +341,49 @@ async def check_running_routes(self, routes: Sequence[RouteData]) -> RouteExecut errors=errors, ) + @staticmethod + def _build_probe_target(route: RouteData) -> ReplicaProbeTarget | None: + """Build a ReplicaProbeTarget from a route, or None if any required field is absent.""" + if route.health_check is None or route.replica_host is None or route.replica_port is None: + return None + return ReplicaProbeTarget( + replica_id=route.route_id, + health_path=route.health_check.path, + inference_port=route.replica_port, + replica_host=route.replica_host, + ) + async def _register_route_probe_targets( self, routes: Sequence[RouteData], replica_info: Mapping[ReplicaID, RouteSessionKernelInfo], ) -> None: - """Register RouteProbeTargets in Valkey for routes that just got replica info.""" - targets: list[RouteProbeTarget] = [] - for route in routes: - replica_id = route.route_id - kernel = replica_info[replica_id] - health_path = route.health_check.path if route.health_check else "/" - targets.append( - RouteProbeTarget( - replica_id=replica_id, - health_path=health_path, - inference_port=kernel.replica_port, - replica_host=kernel.replica_host, - ) + """Register ReplicaProbeTargets in Valkey for routes that just got replica info.""" + targets: list[ReplicaProbeTarget] = [ + ReplicaProbeTarget( + replica_id=route.route_id, + health_path=route.health_check.path, + inference_port=replica_info[route.route_id].replica_port, + replica_host=replica_info[route.route_id].replica_host, ) + for route in routes + if route.health_check is not None + ] if targets: await self._valkey_schedule.register_route_probe_targets_batch(targets) - log.debug("Registered {} RouteProbeTargets in Valkey", len(targets)) + log.debug("Registered {} ReplicaProbeTargets in Valkey", len(targets)) async def sync_route_probe_targets(self, routes: Sequence[RouteData]) -> RouteExecutionResult: - """Sync RouteProbeTargets to Valkey for routes with known replica info. + """Sync ReplicaProbeTargets to Valkey for routes with known replica info. Handles two cases: - Valkey data lost (restart, eviction) → re-registers probe targets - TTL refresh for long-running routes - Routes without replica_host/replica_port are skipped silently. + Routes without health_check/replica_host/replica_port are skipped silently. """ - routes_with_info = [route for route in routes if route.replica_host and route.replica_port] - if not routes_with_info: - return RouteExecutionResult(successes=[], errors=[]) - - # Build probe targets (no phase — health config comes from RouteData) - targets: list[RouteProbeTarget] = [] - for route in routes_with_info: - health_path = route.health_check.path if route.health_check else "/" - targets.append( - RouteProbeTarget( - replica_id=route.route_id, - health_path=health_path, - inference_port=route.replica_port, # type: ignore[arg-type] - replica_host=route.replica_host, # type: ignore[arg-type] - ) - ) + targets = [t for route in routes if (t := self._build_probe_target(route)) is not None] if targets: with RouteRecorderContext.shared_phase( @@ -398,7 +392,7 @@ async def sync_route_probe_targets(self, routes: Sequence[RouteData]) -> RouteEx ): with RouteRecorderContext.shared_step("write_probe_targets"): await self._valkey_schedule.register_route_probe_targets_batch(targets) - log.debug("Synced {} RouteProbeTargets to Valkey", len(targets)) + log.debug("Synced {} ReplicaProbeTargets to Valkey", len(targets)) return RouteExecutionResult(successes=[], errors=[]) diff --git a/src/ai/backend/manager/sokovan/deployment/route/handlers/observer/health_check.py b/src/ai/backend/manager/sokovan/deployment/route/handlers/observer/health_check.py index 6a8af7221d0..0cef5584561 100644 --- a/src/ai/backend/manager/sokovan/deployment/route/handlers/observer/health_check.py +++ b/src/ai/backend/manager/sokovan/deployment/route/handlers/observer/health_check.py @@ -1,9 +1,8 @@ """Observer for performing HTTP health checks on routes. -Reads RouteHealthRecord from Valkey, performs HTTP health checks in parallel, -and writes manager_healthy/manager_last_check back to Valkey. -During initial_delay, failures are ignored (not written). -The HealthCheckRouteHandler reads from Valkey and performs DB transitions. +Reads ReplicaProbeTarget from Valkey to get probe config (health_path, host, port), +performs HTTP health checks in parallel, and writes RouteHealthStatus back to Valkey. +The short TTL on RouteHealthStatus automatically signals DEGRADED on expiry. """ from __future__ import annotations @@ -15,7 +14,7 @@ import aiohttp from ai.backend.common.clients.valkey_client.valkey_schedule import ( - RouteHealthRecord, + ReplicaHealthResult, ValkeyScheduleClient, ) from ai.backend.logging import BraceStyleAdapter @@ -30,12 +29,11 @@ class RouteHealthObserver(RouteObserver): - """Performs HTTP health checks on routes using RouteHealthRecord. + """Performs HTTP health checks on routes using ReplicaProbeTarget from Valkey. - Reads RouteHealthRecord from Valkey to get health_path, replica_host, inference_port. + Reads ReplicaProbeTarget (health_path, replica_host, inference_port) from Valkey. HTTP checks run in parallel via asyncio.gather. - During initial_delay period, health check failures are ignored (not written to Valkey). - After initial_delay, both success and failure results are written. + Writes RouteHealthStatus back to Valkey; TTL expiry signals DEGRADED automatically. """ def __init__( @@ -51,59 +49,46 @@ def name(cls) -> str: return "route-health-observer" async def observe(self, routes: Sequence[RouteData]) -> RouteObservationResult: - """Check health for routes using RouteHealthRecord from Valkey.""" + """Check health for routes using ReplicaProbeTarget from Valkey.""" if not routes: return RouteObservationResult(observed_count=0) - # Filter routes that have replica connection info - checkable = [r for r in routes if r.replica_host and r.replica_port] - if not checkable: - return RouteObservationResult(observed_count=0) + # Load ReplicaProbeTargets from Valkey (keyed by ReplicaID) + replica_ids = [r.route_id for r in routes] + probe_targets = await self._valkey_schedule.get_route_probe_targets_batch(replica_ids) - # Load RouteHealthRecords from Valkey - route_ids = [str(r.route_id) for r in checkable] - records = await self._valkey_schedule.get_route_health_records_batch(route_ids) - - # Collect routes that have records - targets: list[tuple[str, RouteHealthRecord]] = [] - for route in checkable: - route_id_str = str(route.route_id) - record = records.get(route_id_str) - if record is not None: - targets.append((route_id_str, record)) + # Collect routes that have probe targets + checkable = [ + (route, target) + for route in routes + if (target := probe_targets.get(route.route_id)) is not None + ] - if not targets: - if checkable: + if not checkable: + if routes: log.warning( - "Health observer: {} checkable routes but 0 have records in Valkey", - len(checkable), + "Health observer: {} routes but 0 have probe targets in Valkey", + len(routes), ) return RouteObservationResult(observed_count=0) # Perform HTTP health checks in parallel check_tasks = [ - self._http_health_check(record.replica_host, record.inference_port, record.health_path) - for _, record in targets + self._http_health_check(target.replica_host, target.inference_port, target.health_path) + for _, target in checkable ] results = await asyncio.gather(*check_tasks) - # Write results to Valkey - current_time = await self._valkey_schedule.get_redis_time() - for (route_id_str, record), is_healthy in zip(targets, results, strict=False): - within_initial_delay = current_time < record.initial_delay_until - - # Always refresh TTL to prevent key expiry - await self._valkey_schedule.refresh_route_health_ttl(route_id_str) - - if is_healthy: - await self._valkey_schedule.update_route_manager_health(route_id_str, True) - elif not within_initial_delay: - await self._valkey_schedule.update_route_manager_health(route_id_str, False) - # else: failure within initial_delay → ignore (don't write) + # Write results to Valkey in a single batch (TTL refreshed on every call) + health_results = [ + ReplicaHealthResult(replica_id=route.route_id, healthy=is_healthy) + for (route, _target), is_healthy in zip(checkable, results, strict=False) + ] + await self._valkey_schedule.record_route_health_statuses_batch(health_results) - if targets: - log.debug("Health observer: checked {} routes", len(targets)) - return RouteObservationResult(observed_count=len(targets)) + if checkable: + log.debug("Health observer: checked {} routes", len(checkable)) + return RouteObservationResult(observed_count=len(checkable)) @staticmethod async def _http_health_check(host: str, port: int, path: str) -> bool: diff --git a/tests/unit/common/clients/valkey_client/test_valkey_schedule_client.py b/tests/unit/common/clients/valkey_client/test_valkey_schedule_client.py index df79a68408e..02d90f7e044 100644 --- a/tests/unit/common/clients/valkey_client/test_valkey_schedule_client.py +++ b/tests/unit/common/clients/valkey_client/test_valkey_schedule_client.py @@ -22,7 +22,7 @@ HealthCheckStatus, ValkeyScheduleClient, ) -from ai.backend.common.clients.valkey_client.valkey_schedule.types import RouteProbeTarget +from ai.backend.common.clients.valkey_client.valkey_schedule.types import ReplicaProbeTarget from ai.backend.common.defs import REDIS_LIVE_DB from ai.backend.common.identifier.replica import ReplicaID from ai.backend.common.typed_validators import HostPortPair as HostPortPairModel @@ -902,8 +902,8 @@ async def test_remove_deletes_only_specified_sessions( assert result == [sid_keep] -class TestRouteProbeTargetClient: - """Test ValkeyScheduleClient methods for RouteProbeTarget.""" +class TestReplicaProbeTargetClient: + """Test ValkeyScheduleClient methods for ReplicaProbeTarget.""" @pytest.fixture async def valkey_schedule_client( @@ -925,8 +925,8 @@ async def valkey_schedule_client( def replica_id(self) -> ReplicaID: return ReplicaID(uuid4()) - def _make_target(self, replica_id: ReplicaID) -> RouteProbeTarget: - return RouteProbeTarget( + def _make_target(self, replica_id: ReplicaID) -> ReplicaProbeTarget: + return ReplicaProbeTarget( replica_id=replica_id, health_path="/health", inference_port=8080, @@ -983,7 +983,7 @@ async def test_register_overwrites_existing( await valkey_schedule_client.register_route_probe_targets_batch([ self._make_target(replica_id) ]) - updated = RouteProbeTarget( + updated = ReplicaProbeTarget( replica_id=replica_id, health_path="/healthz", inference_port=9000, @@ -994,8 +994,8 @@ async def test_register_overwrites_existing( assert results[replica_id] == updated -class TestRouteHealthStatusClient: - """Test ValkeyScheduleClient methods for RouteHealthStatus.""" +class TestReplicaHealthStatusClient: + """Test ValkeyScheduleClient methods for ReplicaHealthStatus.""" @pytest.fixture async def valkey_schedule_client( diff --git a/tests/unit/common/clients/valkey_client/test_valkey_schedule_types.py b/tests/unit/common/clients/valkey_client/test_valkey_schedule_types.py index 8b274e09caf..c55105c5831 100644 --- a/tests/unit/common/clients/valkey_client/test_valkey_schedule_types.py +++ b/tests/unit/common/clients/valkey_client/test_valkey_schedule_types.py @@ -1,4 +1,4 @@ -"""Unit tests for RouteProbeTarget and RouteHealthStatus Valkey type serialization.""" +"""Unit tests for ReplicaProbeTarget and ReplicaHealthStatus Valkey type serialization.""" from __future__ import annotations @@ -7,27 +7,27 @@ import pytest from ai.backend.common.clients.valkey_client.valkey_schedule.types import ( - RouteHealthStatus, - RouteProbeTarget, + ReplicaHealthStatus, + ReplicaProbeTarget, ) from ai.backend.common.identifier.replica import ReplicaID -class TestRouteProbeTargetSerialization: +class TestReplicaProbeTargetSerialization: @pytest.fixture def replica_id(self) -> ReplicaID: return ReplicaID(uuid4()) @pytest.fixture - def target(self, replica_id: ReplicaID) -> RouteProbeTarget: - return RouteProbeTarget( + def target(self, replica_id: ReplicaID) -> ReplicaProbeTarget: + return ReplicaProbeTarget( replica_id=replica_id, health_path="/health", inference_port=8080, replica_host="10.0.0.1", ) - def test_to_valkey_hash(self, target: RouteProbeTarget) -> None: + def test_to_valkey_hash(self, target: ReplicaProbeTarget) -> None: h = target.to_valkey_hash() assert h["replica_id"] == str(target.replica_id) assert h["health_path"] == "/health" @@ -41,38 +41,38 @@ def test_from_valkey_hash(self, replica_id: ReplicaID) -> None: "inference_port": "9000", "replica_host": "192.168.1.100", } - target = RouteProbeTarget.from_valkey_hash(data) + target = ReplicaProbeTarget.from_valkey_hash(data) assert target.replica_id == replica_id assert target.health_path == "/healthz" assert target.inference_port == 9000 assert target.replica_host == "192.168.1.100" - def test_round_trip(self, target: RouteProbeTarget) -> None: - restored = RouteProbeTarget.from_valkey_hash(target.to_valkey_hash()) + def test_round_trip(self, target: ReplicaProbeTarget) -> None: + restored = ReplicaProbeTarget.from_valkey_hash(target.to_valkey_hash()) assert restored == target -class TestRouteHealthStatusSerialization: +class TestReplicaHealthStatusSerialization: @pytest.fixture def replica_id(self) -> ReplicaID: return ReplicaID(uuid4()) @pytest.fixture - def healthy_status(self, replica_id: ReplicaID) -> RouteHealthStatus: - return RouteHealthStatus( + def healthy_status(self, replica_id: ReplicaID) -> ReplicaHealthStatus: + return ReplicaHealthStatus( replica_id=replica_id, healthy=True, last_check=1700000000, ) - def test_to_valkey_hash_healthy(self, healthy_status: RouteHealthStatus) -> None: + def test_to_valkey_hash_healthy(self, healthy_status: ReplicaHealthStatus) -> None: h = healthy_status.to_valkey_hash() assert h["replica_id"] == str(healthy_status.replica_id) assert h["healthy"] == "1" assert h["last_check"] == "1700000000" def test_to_valkey_hash_unhealthy(self, replica_id: ReplicaID) -> None: - status = RouteHealthStatus( + status = ReplicaHealthStatus( replica_id=replica_id, healthy=False, last_check=1700000000, @@ -85,17 +85,17 @@ def test_from_valkey_hash(self, replica_id: ReplicaID) -> None: "healthy": "1", "last_check": "1700000000", } - status = RouteHealthStatus.from_valkey_hash(data) + status = ReplicaHealthStatus.from_valkey_hash(data) assert status.replica_id == replica_id assert status.healthy is True assert status.last_check == 1700000000 def test_from_valkey_hash_missing_optional_fields(self, replica_id: ReplicaID) -> None: """Missing healthy/last_check fields default to safe values.""" - status = RouteHealthStatus.from_valkey_hash({"replica_id": str(replica_id)}) + status = ReplicaHealthStatus.from_valkey_hash({"replica_id": str(replica_id)}) assert status.healthy is False assert status.last_check == 0 - def test_round_trip(self, healthy_status: RouteHealthStatus) -> None: - restored = RouteHealthStatus.from_valkey_hash(healthy_status.to_valkey_hash()) + def test_round_trip(self, healthy_status: ReplicaHealthStatus) -> None: + restored = ReplicaHealthStatus.from_valkey_hash(healthy_status.to_valkey_hash()) assert restored == healthy_status diff --git a/tests/unit/manager/sokovan/deployment/route/executor/conftest.py b/tests/unit/manager/sokovan/deployment/route/executor/conftest.py index 105e77c1014..10644d000dc 100644 --- a/tests/unit/manager/sokovan/deployment/route/executor/conftest.py +++ b/tests/unit/manager/sokovan/deployment/route/executor/conftest.py @@ -72,7 +72,8 @@ def mock_valkey_schedule() -> AsyncMock: """Mock ValkeyScheduleClient.""" client = AsyncMock() client.get_route_health_records_batch = AsyncMock(return_value={}) - client.get_redis_time = AsyncMock(return_value=1000) + client.get_route_health_statuses_batch = AsyncMock(return_value={}) + client.get_route_probe_targets_batch = AsyncMock(return_value={}) return client diff --git a/tests/unit/manager/sokovan/deployment/route/executor/test_check_route_health_register.py b/tests/unit/manager/sokovan/deployment/route/executor/test_check_route_health_register.py index d5f4041c0d0..c1c37265ffe 100644 --- a/tests/unit/manager/sokovan/deployment/route/executor/test_check_route_health_register.py +++ b/tests/unit/manager/sokovan/deployment/route/executor/test_check_route_health_register.py @@ -17,7 +17,7 @@ from dateutil.tz import tzutc from ai.backend.common.clients.valkey_client.valkey_schedule import ( - RouteHealthStatus as ValkeyRouteHealthStatus, + ReplicaHealthStatus as ValkeyReplicaHealthStatus, ) from ai.backend.common.identifier.deployment import DeploymentID from ai.backend.common.identifier.deployment_revision import DeploymentRevisionID @@ -51,16 +51,16 @@ def _route(health_status: RouteHealthStatus) -> RouteData: ) -def _healthy_status(route: RouteData) -> ValkeyRouteHealthStatus: - return ValkeyRouteHealthStatus( +def _healthy_status(route: RouteData) -> ValkeyReplicaHealthStatus: + return ValkeyReplicaHealthStatus( replica_id=route.route_id, healthy=True, last_check=999, ) -def _unhealthy_status(route: RouteData) -> ValkeyRouteHealthStatus: - return ValkeyRouteHealthStatus( +def _unhealthy_status(route: RouteData) -> ValkeyReplicaHealthStatus: + return ValkeyReplicaHealthStatus( replica_id=route.route_id, healthy=False, last_check=999, diff --git a/tests/unit/manager/sokovan/deployment/route/executor/test_initial_delay.py b/tests/unit/manager/sokovan/deployment/route/executor/test_initial_delay.py index 46a877582ee..3820122cbc4 100644 --- a/tests/unit/manager/sokovan/deployment/route/executor/test_initial_delay.py +++ b/tests/unit/manager/sokovan/deployment/route/executor/test_initial_delay.py @@ -1,15 +1,17 @@ -"""Tests for RouteProbeTarget registration and initial_delay behavior.""" +"""Tests for ReplicaProbeTarget registration and initial_delay behavior.""" from __future__ import annotations from datetime import datetime -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock from uuid import uuid4 from dateutil.tz import tzutc -from ai.backend.common.clients.valkey_client.valkey_schedule import RouteHealthRecord -from ai.backend.common.clients.valkey_client.valkey_schedule.types import RouteProbeTarget +from ai.backend.common.clients.valkey_client.valkey_schedule.types import ( + ReplicaHealthResult, + ReplicaProbeTarget, +) from ai.backend.common.config import ModelHealthCheck from ai.backend.common.identifier.deployment import DeploymentID from ai.backend.common.identifier.deployment_revision import DeploymentRevisionID @@ -54,7 +56,7 @@ def _make_route( # ============================================================================= -class TestRegisterRouteProbeTargets: +class TestRegisterReplicaProbeTargets: """Tests for _register_route_probe_targets.""" async def test_registers_probe_target_with_health_config( @@ -72,19 +74,19 @@ async def test_registers_probe_target_with_health_config( ) call_args = mock_valkey_schedule.register_route_probe_targets_batch.call_args - targets: list[RouteProbeTarget] = call_args[0][0] + targets: list[ReplicaProbeTarget] = call_args[0][0] assert len(targets) == 1 assert targets[0].replica_id == replica_id assert targets[0].health_path == "/healthz" assert targets[0].inference_port == 9000 assert targets[0].replica_host == "10.0.0.2" - async def test_registers_default_health_path_when_no_config( + async def test_skips_route_without_health_check( self, route_executor: RouteExecutor, mock_valkey_schedule: AsyncMock, ) -> None: - """health_path defaults to '/' when route.health_check is None.""" + """Route with health_check=None is not registered.""" route = _make_route(health_check=None) replica_id = ReplicaID(route.route_id) @@ -93,17 +95,16 @@ async def test_registers_default_health_path_when_no_config( {replica_id: RouteSessionKernelInfo(replica_host="10.0.0.1", replica_port=8000)}, ) - call_args = mock_valkey_schedule.register_route_probe_targets_batch.call_args - targets: list[RouteProbeTarget] = call_args[0][0] - assert targets[0].health_path == "/" + mock_valkey_schedule.register_route_probe_targets_batch.assert_not_awaited() async def test_registers_multiple_routes( self, route_executor: RouteExecutor, mock_valkey_schedule: AsyncMock, ) -> None: - """Multiple routes produce multiple RouteProbeTarget entries.""" - routes = [_make_route() for _ in range(3)] + """Multiple routes with health_check produce multiple ReplicaProbeTarget entries.""" + health_check = ModelHealthCheck(path="/health", initial_delay=60.0) + routes = [_make_route(health_check=health_check) for _ in range(3)] replica_infos = { ReplicaID(r.route_id): RouteSessionKernelInfo( replica_host="10.0.0.1", replica_port=8000 + i @@ -114,11 +115,11 @@ async def test_registers_multiple_routes( await route_executor._register_route_probe_targets(routes, replica_infos) call_args = mock_valkey_schedule.register_route_probe_targets_batch.call_args - targets: list[RouteProbeTarget] = call_args[0][0] + targets: list[ReplicaProbeTarget] = call_args[0][0] assert len(targets) == 3 -class TestSyncRouteProbeTargets: +class TestSyncReplicaProbeTargets: """Tests for sync_route_probe_targets.""" async def test_syncs_routes_with_replica_info( @@ -126,8 +127,8 @@ async def test_syncs_routes_with_replica_info( route_executor: RouteExecutor, mock_valkey_schedule: AsyncMock, ) -> None: - """Routes with replica_host and replica_port are synced.""" - route = _make_route() + """Routes with health_check, replica_host and replica_port are synced.""" + route = _make_route(health_check=ModelHealthCheck(path="/health", initial_delay=60.0)) with RouteRecorderContext.scope("test", entity_ids=[route.route_id]): result = await route_executor.sync_route_probe_targets([route]) @@ -136,14 +137,15 @@ async def test_syncs_routes_with_replica_info( assert result.successes == [] assert result.errors == [] - async def test_skips_routes_without_replica_info( + async def test_skips_routes_without_health_check_or_replica_info( self, route_executor: RouteExecutor, mock_valkey_schedule: AsyncMock, ) -> None: - """Routes without replica_host/port are silently skipped.""" - route = _make_route() - route_no_info = RouteData( + """Routes missing health_check or replica info are silently skipped.""" + route = _make_route(health_check=ModelHealthCheck(path="/health", initial_delay=60.0)) + route_no_health_check = _make_route(health_check=None) + route_no_replica = RouteData( route_id=ReplicaID(uuid4()), deployment_id=DeploymentID(uuid4()), session_id=SessionId(uuid4()), @@ -153,18 +155,23 @@ async def test_skips_routes_without_replica_info( created_at=datetime.fromtimestamp(1000, tz=tzutc()), revision_id=DeploymentRevisionID(uuid4()), traffic_status=RouteTrafficStatus.INACTIVE, - health_check=None, + health_check=ModelHealthCheck(path="/health", initial_delay=60.0), replica_host=None, replica_port=None, ) with RouteRecorderContext.scope( - "test", entity_ids=[route.route_id, route_no_info.route_id] + "test", + entity_ids=[route.route_id, route_no_health_check.route_id, route_no_replica.route_id], ): - await route_executor.sync_route_probe_targets([route, route_no_info]) + await route_executor.sync_route_probe_targets([ + route, + route_no_health_check, + route_no_replica, + ]) call_args = mock_valkey_schedule.register_route_probe_targets_batch.call_args - targets: list[RouteProbeTarget] = call_args[0][0] + targets: list[ReplicaProbeTarget] = call_args[0][0] assert len(targets) == 1 assert targets[0].replica_id == ReplicaID(route.route_id) @@ -197,88 +204,62 @@ async def test_all_routes_missing_replica_info_returns_empty( # ============================================================================= -# RouteHealthObserver: within_initial_delay based on running_at +# RouteHealthObserver: probe target based observation # ============================================================================= -class TestObserverInitialDelay: - """Tests for observer's initial_delay behavior with running_at-based records.""" +class TestObserverSetsHealthStatus: + """Tests for RouteHealthObserver writing RouteHealthStatus to Valkey.""" - async def test_observer_ignores_failure_within_initial_delay(self) -> None: - """ID-005: Observer does not write failure during initial_delay period. + async def test_observer_writes_success_result(self) -> None: + """ID-005: Observer writes healthy=True on successful probe. - Given: running_at=5000, initial_delay=720 → initial_delay_until=5720 - current redis_time=5500 (within initial_delay) - health check fails + Given: ReplicaProbeTarget exists in Valkey, HTTP check succeeds When: Observer runs - Then: update_route_manager_health NOT called for failure + Then: record_route_health_status called with (replica_id, True) """ mock_deployment_repo = AsyncMock() mock_valkey = AsyncMock() - route = _make_route(created_at_ts=1000) - route_id_str = str(route.route_id) - route_data = MagicMock() - route_data.route_id = route.route_id - route_data.replica_host = "10.0.0.1" - route_data.replica_port = 8000 - - record = RouteHealthRecord( - route_id=route_id_str, - created_at=1000, - initial_delay_until=5720, + route = _make_route() + probe_target = ReplicaProbeTarget( + replica_id=route.route_id, health_path="/health", inference_port=8000, replica_host="10.0.0.1", - running_at=5000, ) - mock_valkey.get_route_health_records_batch.return_value = {route_id_str: record} - mock_valkey.get_redis_time.return_value = 5500 # Within initial_delay + mock_valkey.get_route_probe_targets_batch.return_value = {route.route_id: probe_target} observer = RouteHealthObserver( deployment_repository=mock_deployment_repo, valkey_schedule=mock_valkey, ) - # Patch HTTP check to always fail - observer._http_health_check = AsyncMock(return_value=False) # type: ignore[method-assign] + observer._http_health_check = AsyncMock(return_value=True) # type: ignore[method-assign] - await observer.observe([route_data]) + await observer.observe([route]) - # refresh_route_health_ttl should be called (always) - mock_valkey.refresh_route_health_ttl.assert_awaited_once_with(route_id_str) - # update_route_manager_health should NOT be called (failure ignored during initial_delay) - mock_valkey.update_route_manager_health.assert_not_awaited() + mock_valkey.record_route_health_statuses_batch.assert_awaited_once_with([ + ReplicaHealthResult(replica_id=route.route_id, healthy=True) + ]) - async def test_observer_writes_failure_after_initial_delay(self) -> None: - """ID-006: Observer writes failure after initial_delay expires. + async def test_observer_writes_failure_result(self) -> None: + """ID-006: Observer writes healthy=False on failed probe. - Given: running_at=5000, initial_delay=720 → initial_delay_until=5720 - current redis_time=5800 (past initial_delay) - health check fails + Given: ReplicaProbeTarget exists in Valkey, HTTP check fails When: Observer runs - Then: update_route_manager_health called with False + Then: record_route_health_status called with (replica_id, False) """ mock_deployment_repo = AsyncMock() mock_valkey = AsyncMock() - route = _make_route(created_at_ts=1000) - route_id_str = str(route.route_id) - route_data = MagicMock() - route_data.route_id = route.route_id - route_data.replica_host = "10.0.0.1" - route_data.replica_port = 8000 - - record = RouteHealthRecord( - route_id=route_id_str, - created_at=1000, - initial_delay_until=5720, + route = _make_route() + probe_target = ReplicaProbeTarget( + replica_id=route.route_id, health_path="/health", inference_port=8000, replica_host="10.0.0.1", - running_at=5000, ) - mock_valkey.get_route_health_records_batch.return_value = {route_id_str: record} - mock_valkey.get_redis_time.return_value = 5800 # Past initial_delay + mock_valkey.get_route_probe_targets_batch.return_value = {route.route_id: probe_target} observer = RouteHealthObserver( deployment_repository=mock_deployment_repo, @@ -286,38 +267,24 @@ async def test_observer_writes_failure_after_initial_delay(self) -> None: ) observer._http_health_check = AsyncMock(return_value=False) # type: ignore[method-assign] - await observer.observe([route_data]) + await observer.observe([route]) - mock_valkey.update_route_manager_health.assert_awaited_once_with(route_id_str, False) + mock_valkey.record_route_health_statuses_batch.assert_awaited_once_with([ + ReplicaHealthResult(replica_id=route.route_id, healthy=False) + ]) - async def test_observer_writes_success_within_initial_delay(self) -> None: - """ID-007: Observer writes success even during initial_delay. + async def test_observer_skips_route_without_probe_target(self) -> None: + """ID-007: Observer skips route when no probe target in Valkey. - Given: Within initial_delay, health check succeeds + Given: No ReplicaProbeTarget in Valkey When: Observer runs - Then: update_route_manager_health called with True + Then: record_route_health_status NOT called, observed_count=0 """ mock_deployment_repo = AsyncMock() mock_valkey = AsyncMock() - route = _make_route(created_at_ts=1000) - route_id_str = str(route.route_id) - route_data = MagicMock() - route_data.route_id = route.route_id - route_data.replica_host = "10.0.0.1" - route_data.replica_port = 8000 - - record = RouteHealthRecord( - route_id=route_id_str, - created_at=1000, - initial_delay_until=5720, - health_path="/health", - inference_port=8000, - replica_host="10.0.0.1", - running_at=5000, - ) - mock_valkey.get_route_health_records_batch.return_value = {route_id_str: record} - mock_valkey.get_redis_time.return_value = 5500 # Within initial_delay + route = _make_route() + mock_valkey.get_route_probe_targets_batch.return_value = {route.route_id: None} observer = RouteHealthObserver( deployment_repository=mock_deployment_repo, @@ -325,113 +292,7 @@ async def test_observer_writes_success_within_initial_delay(self) -> None: ) observer._http_health_check = AsyncMock(return_value=True) # type: ignore[method-assign] - await observer.observe([route_data]) - - mock_valkey.update_route_manager_health.assert_awaited_once_with(route_id_str, True) - - -# ============================================================================= -# RouteHealthRecord serialization: running_at -# ============================================================================= - - -class TestRouteHealthRecordRunningAt: - """Tests for RouteHealthRecord running_at serialization.""" + result = await observer.observe([route]) - def test_running_at_none_not_in_hash(self) -> None: - """running_at=None should not appear in serialized hash.""" - record = RouteHealthRecord( - route_id="r1", - created_at=1000, - initial_delay_until=1720, - health_path="/health", - inference_port=8000, - replica_host="10.0.0.1", - running_at=None, - ) - h = record.to_valkey_hash() - assert "running_at" not in h - - def test_running_at_present_in_hash(self) -> None: - """running_at with value should appear in serialized hash.""" - record = RouteHealthRecord( - route_id="r1", - created_at=1000, - initial_delay_until=5720, - health_path="/health", - inference_port=8000, - replica_host="10.0.0.1", - running_at=5000, - ) - h = record.to_valkey_hash() - assert h["running_at"] == "5000" - - def test_from_hash_missing_running_at_is_none(self) -> None: - """Deserializing hash without running_at field yields None.""" - data = { - "route_id": "r1", - "created_at": "1000", - "initial_delay_until": "1720", - "health_path": "/health", - "inference_port": "8000", - "replica_host": "10.0.0.1", - } - record = RouteHealthRecord.from_valkey_hash(data) - assert record.running_at is None - - def test_from_hash_zero_running_at_is_none(self) -> None: - """Deserializing hash with running_at=0 yields None (backward compat).""" - data = { - "route_id": "r1", - "created_at": "1000", - "initial_delay_until": "1720", - "health_path": "/health", - "inference_port": "8000", - "replica_host": "10.0.0.1", - "running_at": "0", - } - record = RouteHealthRecord.from_valkey_hash(data) - assert record.running_at is None - - def test_from_hash_valid_running_at(self) -> None: - """Deserializing hash with valid running_at yields int.""" - data = { - "route_id": "r1", - "created_at": "1000", - "initial_delay_until": "5720", - "health_path": "/health", - "inference_port": "8000", - "replica_host": "10.0.0.1", - "running_at": "5000", - } - record = RouteHealthRecord.from_valkey_hash(data) - assert record.running_at == 5000 - - def test_roundtrip_with_running_at(self) -> None: - """Serialize → deserialize preserves running_at.""" - original = RouteHealthRecord( - route_id="r1", - created_at=1000, - initial_delay_until=5720, - health_path="/health", - inference_port=8000, - replica_host="10.0.0.1", - running_at=5000, - ) - restored = RouteHealthRecord.from_valkey_hash(original.to_valkey_hash()) - assert restored.running_at == 5000 - assert restored.initial_delay_until == 5720 - - def test_roundtrip_without_running_at(self) -> None: - """Serialize → deserialize preserves running_at=None.""" - original = RouteHealthRecord( - route_id="r1", - created_at=1000, - initial_delay_until=1720, - health_path="/health", - inference_port=8000, - replica_host="10.0.0.1", - running_at=None, - ) - restored = RouteHealthRecord.from_valkey_hash(original.to_valkey_hash()) - assert restored.running_at is None + mock_valkey.record_route_health_statuses_batch.assert_not_awaited() + assert result.observed_count == 0 diff --git a/tests/unit/manager/sokovan/deployment/route/executor/test_route_executor.py b/tests/unit/manager/sokovan/deployment/route/executor/test_route_executor.py index 22b8bf219bf..b47c56a181b 100644 --- a/tests/unit/manager/sokovan/deployment/route/executor/test_route_executor.py +++ b/tests/unit/manager/sokovan/deployment/route/executor/test_route_executor.py @@ -23,7 +23,7 @@ from dateutil.tz import tzutc from ai.backend.common.clients.valkey_client.valkey_schedule import ( - RouteHealthStatus as ValkeyRouteHealthStatus, + ReplicaHealthStatus as ValkeyReplicaHealthStatus, ) from ai.backend.common.dto.appproxy_coordinator.v2.endpoint.response import ( BulkUpdateRoutesResponse, @@ -213,7 +213,7 @@ async def test_healthy_route_in_successes( When: Check route health Then: Route in successes list """ - status = ValkeyRouteHealthStatus( + status = ValkeyReplicaHealthStatus( replica_id=healthy_route.route_id, healthy=True, last_check=995, @@ -242,7 +242,7 @@ async def test_unhealthy_route_in_errors( When: Check route health Then: Route in errors list """ - status = ValkeyRouteHealthStatus( + status = ValkeyReplicaHealthStatus( replica_id=healthy_route.route_id, healthy=False, last_check=995, From cda42ff8e449e62f9ab0bc35f5aae04feb448e27 Mon Sep 17 00:00:00 2001 From: HyeockJinKim Date: Fri, 15 May 2026 15:17:12 +0900 Subject: [PATCH 07/13] fix(BA-6035): add ReplicaProbeTargetSyncHandler and fix AppProxy sync target statuses MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add RouteLifecycleType.PROBE_TARGET_SYNC - Add ReplicaProbeTargetSyncHandler: targets RUNNING + [HEALTHY, UNHEALTHY, DEGRADED] (NOT_CHECKED excluded — probe targets registered at WARMING_UP entry) - Register handler + task spec in coordinator (long_interval=60s, initial_delay=30s) - AppProxySyncRouteHandler: target RUNNING + ACTIVE only (remove health=HEALTHY filter) Initial registration waits for HEALTHY via check_route_health; sync handler converges all active routes regardless of health Co-Authored-By: Claude Sonnet 4.6 --- .../sokovan/deployment/route/coordinator.py | 11 +++ .../deployment/route/handlers/__init__.py | 2 + .../route/handlers/appproxy_sync.py | 4 +- .../route/handlers/probe_target_sync.py | 80 +++++++++++++++++++ .../manager/sokovan/deployment/route/types.py | 1 + 5 files changed, 96 insertions(+), 2 deletions(-) create mode 100644 src/ai/backend/manager/sokovan/deployment/route/handlers/probe_target_sync.py diff --git a/src/ai/backend/manager/sokovan/deployment/route/coordinator.py b/src/ai/backend/manager/sokovan/deployment/route/coordinator.py index 7c4cba36905..874d7057e28 100644 --- a/src/ai/backend/manager/sokovan/deployment/route/coordinator.py +++ b/src/ai/backend/manager/sokovan/deployment/route/coordinator.py @@ -38,6 +38,7 @@ AppProxySyncRouteHandler, HealthCheckRouteHandler, ProvisioningRouteHandler, + ReplicaProbeTargetSyncHandler, RouteEvictionHandler, RouteHandler, RunningRouteHandler, @@ -184,6 +185,9 @@ def _init_handlers(self, executor: RouteExecutor) -> Mapping[RouteLifecycleType, route_executor=executor, event_producer=self._event_producer, ), + RouteLifecycleType.PROBE_TARGET_SYNC: ReplicaProbeTargetSyncHandler( + route_executor=executor, + ), } async def process_route_lifecycle( @@ -505,6 +509,13 @@ def _create_task_specs() -> list[RouteTaskSpec]: long_interval=30.0, initial_delay=15.0, ), + # Probe target sync - refresh ReplicaProbeTarget TTL and recover lost Valkey entries + RouteTaskSpec( + RouteLifecycleType.PROBE_TARGET_SYNC, + short_interval=None, + long_interval=60.0, + initial_delay=30.0, + ), ] def create_task_specs(self) -> list[EventTaskSpec]: diff --git a/src/ai/backend/manager/sokovan/deployment/route/handlers/__init__.py b/src/ai/backend/manager/sokovan/deployment/route/handlers/__init__.py index 237517fb9d0..9eb8233ebe9 100644 --- a/src/ai/backend/manager/sokovan/deployment/route/handlers/__init__.py +++ b/src/ai/backend/manager/sokovan/deployment/route/handlers/__init__.py @@ -3,6 +3,7 @@ from .appproxy_sync import AppProxySyncRouteHandler from .base import RouteHandler from .health_check import HealthCheckRouteHandler +from .probe_target_sync import ReplicaProbeTargetSyncHandler from .provisioning import ProvisioningRouteHandler from .route_eviction import RouteEvictionHandler from .running import RunningRouteHandler @@ -14,6 +15,7 @@ __all__ = [ "AppProxySyncRouteHandler", "HealthCheckRouteHandler", + "ReplicaProbeTargetSyncHandler", "ProvisioningRouteHandler", "RouteEvictionHandler", "RouteHandler", diff --git a/src/ai/backend/manager/sokovan/deployment/route/handlers/appproxy_sync.py b/src/ai/backend/manager/sokovan/deployment/route/handlers/appproxy_sync.py index bda30a65756..0ebbd140c7a 100644 --- a/src/ai/backend/manager/sokovan/deployment/route/handlers/appproxy_sync.py +++ b/src/ai/backend/manager/sokovan/deployment/route/handlers/appproxy_sync.py @@ -7,10 +7,10 @@ from ai.backend.logging import BraceStyleAdapter from ai.backend.manager.data.deployment.types import ( RouteHandlerCategory, - RouteHealthStatus, RouteStatus, RouteStatusTransitions, RouteTargetStatuses, + RouteTrafficStatus, ) from ai.backend.manager.defs import LockID from ai.backend.manager.repositories.deployment.types import RouteData @@ -58,7 +58,7 @@ def category(cls) -> RouteHandlerCategory: def target_statuses(cls) -> RouteTargetStatuses: return RouteTargetStatuses( lifecycle=[RouteStatus.RUNNING], - health=[RouteHealthStatus.HEALTHY], + traffic=[RouteTrafficStatus.ACTIVE], ) @classmethod diff --git a/src/ai/backend/manager/sokovan/deployment/route/handlers/probe_target_sync.py b/src/ai/backend/manager/sokovan/deployment/route/handlers/probe_target_sync.py new file mode 100644 index 00000000000..c79ff5963df --- /dev/null +++ b/src/ai/backend/manager/sokovan/deployment/route/handlers/probe_target_sync.py @@ -0,0 +1,80 @@ +"""Handler for syncing ReplicaProbeTargets to Valkey for active routes.""" + +import logging +from collections.abc import Sequence + +from ai.backend.logging import BraceStyleAdapter +from ai.backend.manager.data.deployment.types import ( + RouteHandlerCategory, + RouteHealthStatus, + RouteStatus, + RouteStatusTransitions, + RouteTargetStatuses, +) +from ai.backend.manager.defs import LockID +from ai.backend.manager.repositories.deployment.types import RouteData +from ai.backend.manager.sokovan.deployment.route.executor import RouteExecutor +from ai.backend.manager.sokovan.deployment.route.types import RouteExecutionResult + +from .base import RouteHandler + +log = BraceStyleAdapter(logging.getLogger(__name__)) + + +class ReplicaProbeTargetSyncHandler(RouteHandler): + """Periodically syncs ReplicaProbeTarget entries to Valkey for active routes. + + Targets RUNNING routes that have already been health-checked (HEALTHY, + UNHEALTHY, DEGRADED). NOT_CHECKED routes are excluded intentionally — + their probe targets are registered when they first transition to WARMING_UP. + + Covers two cases: + - Valkey data lost (restart, eviction) → re-registers probe targets + - TTL refresh for long-running routes that would otherwise expire + """ + + def __init__(self, route_executor: RouteExecutor) -> None: + self._route_executor = route_executor + + @classmethod + def name(cls) -> str: + return "replica-probe-target-sync" + + @property + def lock_id(self) -> LockID | None: + return None + + @classmethod + def category(cls) -> RouteHandlerCategory: + return RouteHandlerCategory.SYNC + + @classmethod + def target_statuses(cls) -> RouteTargetStatuses: + return RouteTargetStatuses( + lifecycle=[RouteStatus.RUNNING], + health=[ + RouteHealthStatus.HEALTHY, + RouteHealthStatus.UNHEALTHY, + RouteHealthStatus.DEGRADED, + ], + ) + + @classmethod + def status_transitions(cls) -> RouteStatusTransitions: + return RouteStatusTransitions( + success=None, + failure=None, + stale=None, + ) + + async def execute(self, routes: Sequence[RouteData]) -> RouteExecutionResult: + """Sync probe targets for routes that have health_check and replica info.""" + return await self._route_executor.sync_route_probe_targets(routes) + + async def post_process(self, result: RouteExecutionResult) -> None: + if result.errors: + log.warning( + "Probe target sync: {} succeeded, {} failed", + len(result.successes), + len(result.errors), + ) diff --git a/src/ai/backend/manager/sokovan/deployment/route/types.py b/src/ai/backend/manager/sokovan/deployment/route/types.py index 3cbd0dbd981..b661ab4cb7a 100644 --- a/src/ai/backend/manager/sokovan/deployment/route/types.py +++ b/src/ai/backend/manager/sokovan/deployment/route/types.py @@ -19,6 +19,7 @@ class RouteLifecycleType(StrEnum): SERVICE_DISCOVERY_SYNC = "service_discovery_sync" APPPROXY_SYNC = "appproxy_sync" OBSERVE_HEALTH = "observe_health" + PROBE_TARGET_SYNC = "probe_target_sync" @dataclass From c07043f4e989c651b0f3bfd8ea7923d4b0f6dc9a Mon Sep 17 00:00:00 2001 From: HyeockJinKim Date: Fri, 15 May 2026 15:21:43 +0900 Subject: [PATCH 08/13] fix(BA-6035): delete RouteHealthRecord and legacy Valkey health methods - Remove RouteHealthRecord class (replaced by ReplicaProbeTarget + ReplicaHealthStatus) - Remove legacy methods: mark_route_running_at, get_route_running_at_batch, refresh_route_health_ttl, update_route_manager_health, initialize_route_health_records_batch, get_route_health_record, get_route_health_records_batch, record_route_health_status (single) - Remove RouteHealthRecord from __init__.py exports - Update test_valkey_schedule_client.py to use record_route_health_statuses_batch Co-Authored-By: Claude Sonnet 4.6 --- .../valkey_client/valkey_schedule/__init__.py | 2 - .../valkey_client/valkey_schedule/client.py | 272 ------------------ .../test_valkey_schedule_client.py | 29 +- 3 files changed, 22 insertions(+), 281 deletions(-) diff --git a/src/ai/backend/common/clients/valkey_client/valkey_schedule/__init__.py b/src/ai/backend/common/clients/valkey_client/valkey_schedule/__init__.py index 74d33c8d464..90acbe862c9 100644 --- a/src/ai/backend/common/clients/valkey_client/valkey_schedule/__init__.py +++ b/src/ai/backend/common/clients/valkey_client/valkey_schedule/__init__.py @@ -2,7 +2,6 @@ HealthCheckStatus, HealthStatus, KernelStatus, - RouteHealthRecord, ValkeyScheduleClient, ) from .types import ReplicaHealthResult, ReplicaHealthStatus, ReplicaProbeTarget @@ -11,7 +10,6 @@ "HealthCheckStatus", "HealthStatus", "KernelStatus", - "RouteHealthRecord", "ReplicaHealthResult", "ReplicaHealthStatus", "ReplicaProbeTarget", diff --git a/src/ai/backend/common/clients/valkey_client/valkey_schedule/client.py b/src/ai/backend/common/clients/valkey_client/valkey_schedule/client.py index d02000e1094..3e7b5e32bf1 100644 --- a/src/ai/backend/common/clients/valkey_client/valkey_schedule/client.py +++ b/src/ai/backend/common/clients/valkey_client/valkey_schedule/client.py @@ -92,87 +92,6 @@ def get_status(self) -> HealthCheckStatus | None: return self.readiness -@dataclass -class RouteHealthRecord: - """Health record for a route stored in Valkey. - - Stored as a hash at key `route_health:{route_id}`. - Created when replica_host/port become available (session RUNNING). - Agent and Manager each update their own fields independently. - """ - - route_id: str - created_at: int # Unix timestamp when route was created - initial_delay_until: int # Unix timestamp = running_at + initial_delay - health_path: str # extracted from model_definition - inference_port: int # extracted from kernel - replica_host: str # extracted from kernel - - # Timestamp when route entered RUNNING state (set by coordinator) - running_at: int | None = None - - # Agent check results - agent_healthy: bool = False - agent_last_check: int = 0 # Unix timestamp - - # Manager check results - manager_healthy: bool = False - manager_last_check: int = 0 # Unix timestamp - - @property - def last_check(self) -> int: - """Most recent check timestamp from either agent or manager.""" - return max(self.agent_last_check, self.manager_last_check) - - @property - def healthy(self) -> bool: - """Health based on the most recent checker (agent or manager).""" - if self.agent_last_check >= self.manager_last_check: - return self.agent_healthy - return self.manager_healthy - - def is_stale(self, current_time: int, staleness_sec: int = MAX_HEALTH_STALENESS_SEC) -> bool: - """Check if health data is stale (no recent check from either side).""" - if self.last_check == 0: - return False # Never checked yet — not stale, just NOT_CHECKED - return (current_time - self.last_check) > staleness_sec - - def to_valkey_hash(self) -> Mapping[str, str]: - """Serialize to Valkey hash fields.""" - data: dict[str, str] = { - "route_id": self.route_id, - "created_at": str(self.created_at), - "initial_delay_until": str(self.initial_delay_until), - "health_path": self.health_path, - "inference_port": str(self.inference_port), - "replica_host": self.replica_host, - "agent_healthy": "1" if self.agent_healthy else "0", - "agent_last_check": str(self.agent_last_check), - "manager_healthy": "1" if self.manager_healthy else "0", - "manager_last_check": str(self.manager_last_check), - } - if self.running_at is not None: - data["running_at"] = str(self.running_at) - return data - - @classmethod - def from_valkey_hash(cls, data: Mapping[str, str]) -> RouteHealthRecord: - """Deserialize from Valkey hash fields.""" - return cls( - route_id=data["route_id"], - created_at=int(data["created_at"]), - initial_delay_until=int(data["initial_delay_until"]), - health_path=data["health_path"], - inference_port=int(data["inference_port"]), - replica_host=data["replica_host"], - running_at=int(raw) if (raw := data.get("running_at")) and raw != "0" else None, - agent_healthy=data.get("agent_healthy", "0") == "1", - agent_last_check=int(data.get("agent_last_check", "0")), - manager_healthy=data.get("manager_healthy", "0") == "1", - manager_last_check=int(data.get("manager_last_check", "0")), - ) - - @dataclass class KernelStatus: """Presence status for a kernel.""" @@ -689,86 +608,6 @@ async def update_route_liveness(self, route_id: str, liveness: bool) -> None: async with self._client.client() as conn: await conn.exec(batch, raise_on_error=True) - @valkey_schedule_resilience.apply() - async def mark_route_running_at(self, route_id: str) -> None: - """ - Record the RUNNING transition timestamp for a route. - Called when a route transitions to RUNNING status. - Uses Redis time for consistency with health check comparisons. - - :param route_id: The route ID that entered RUNNING state - """ - key = self._get_route_health_key(route_id) - current_time = str(await self._get_redis_time()) - async with self._client.client() as conn: - await conn.hset(key, {"running_at": current_time}) - await conn.expire(key, ROUTE_HEALTH_TTL_SEC) - - @valkey_schedule_resilience.apply() - async def get_route_running_at_batch(self, route_ids: Sequence[str]) -> dict[str, int | None]: - """ - Batch read running_at field from route health hashes. - Works even on partial hashes (before full RouteHealthRecord is initialized). - - :param route_ids: Route IDs to look up - :return: Mapping of route_id to running_at timestamp (None if not set) - """ - if not route_ids: - return {} - - batch = Batch(is_atomic=False) - for route_id in route_ids: - key = self._get_route_health_key(route_id) - batch.hget(key, "running_at") - - async with self._client.client() as conn: - results = await conn.exec(batch, raise_on_error=False) - if results is None: - return dict.fromkeys(route_ids) - - running_at_map: dict[str, int | None] = {} - for i, route_id in enumerate(route_ids): - raw = results[i] if len(results) > i else None - if raw and raw != b"0": - running_at_map[route_id] = int(raw) - else: - running_at_map[route_id] = None - return running_at_map - - @valkey_schedule_resilience.apply() - async def refresh_route_health_ttl(self, route_id: str) -> None: - """ - Refresh TTL for a route health record without changing any data. - Called by observer on every check to prevent key expiry. - - :param route_id: The route ID to refresh - """ - key = self._get_route_health_key(route_id) - async with self._client.client() as conn: - await conn.expire(key, ROUTE_HEALTH_TTL_SEC) - - @valkey_schedule_resilience.apply() - async def update_route_manager_health(self, route_id: str, healthy: bool) -> None: - """ - Update manager health check result for a route. - Called by RouteHealthObserver after HTTP health check. - - :param route_id: The route ID to update - :param healthy: Whether the route passed the health check - """ - key = self._get_route_health_key(route_id) - current_time = str(await self._get_redis_time()) - data: Mapping[str | bytes, str | bytes] = { - "manager_healthy": "1" if healthy else "0", - "manager_last_check": current_time, - } - - batch = Batch(is_atomic=False) - batch.hset(key, data) - batch.expire(key, ROUTE_HEALTH_TTL_SEC) - async with self._client.client() as conn: - await conn.exec(batch, raise_on_error=True) - @valkey_schedule_resilience.apply() async def check_route_health_status( self, route_ids: list[str] @@ -856,93 +695,6 @@ async def update_routes_readiness_batch(self, route_readiness: Mapping[str, bool async with self._client.client() as conn: await conn.exec(batch, raise_on_error=True) - # ==================== RouteHealthRecord Methods ==================== - - @valkey_schedule_resilience.apply() - async def initialize_route_health_records_batch( - self, records: Sequence[RouteHealthRecord] - ) -> None: - """ - Batch initialize RouteHealthRecord entries in Valkey. - Called when replica info becomes available (session RUNNING + kernel host/port known). - - :param records: RouteHealthRecord instances to store - """ - if not records: - return - - batch = Batch(is_atomic=False) - for record in records: - key = self._get_route_health_key(record.route_id) - batch.hset(key, record.to_valkey_hash()) - batch.expire(key, ROUTE_HEALTH_TTL_SEC) - - async with self._client.client() as conn: - await conn.exec(batch, raise_on_error=True) - - @valkey_schedule_resilience.apply() - async def get_route_health_record(self, route_id: str) -> RouteHealthRecord | None: - """ - Get a RouteHealthRecord from Valkey. - - :param route_id: The route ID to look up - :return: RouteHealthRecord or None if not found - """ - key = self._get_route_health_key(route_id) - async with self._client.client() as conn: - result = await conn.hgetall(key) - if not result: - return None - - data = {k.decode(): v.decode() for k, v in result.items()} - if "route_id" not in data: - return None - return RouteHealthRecord.from_valkey_hash(data) - - @valkey_schedule_resilience.apply() - async def get_route_health_records_batch( - self, route_ids: Sequence[str] - ) -> Mapping[str, RouteHealthRecord | None]: - """ - Batch get RouteHealthRecords from Valkey. - - :param route_ids: Route IDs to look up - :return: Mapping of route_id to RouteHealthRecord (None if missing) - """ - if not route_ids: - return {} - - batch = Batch(is_atomic=False) - for route_id in route_ids: - key = self._get_route_health_key(route_id) - batch.hgetall(key) - - async with self._client.client() as conn: - results = await conn.exec(batch, raise_on_error=False) - if results is None: - return dict.fromkeys(route_ids) - - records: dict[str, RouteHealthRecord | None] = {} - for i, route_id in enumerate(route_ids): - hgetall_result = results[i] if len(results) > i else None - if not hgetall_result: - records[route_id] = None - continue - - raw = cast(dict[bytes, bytes], hgetall_result) - if not raw: - records[route_id] = None - continue - - data = {k.decode(): v.decode() for k, v in raw.items()} - if "route_id" not in data: - # Partial hash (e.g., only running_at set by mark_route_running_at) - records[route_id] = None - continue - records[route_id] = RouteHealthRecord.from_valkey_hash(data) - - return records - # ==================== ReplicaProbeTarget / ReplicaHealthStatus Methods ==================== @valkey_schedule_resilience.apply() @@ -1004,30 +756,6 @@ async def get_route_probe_targets_batch( return targets - @valkey_schedule_resilience.apply() - async def record_route_health_status(self, replica_id: ReplicaID, healthy: bool) -> None: - """ - Record health check result for a route. - Called by RouteHealthObserver after each HTTP probe. - Refreshes TTL on every call; key expiry signals DEGRADED. - - :param replica_id: The replica ID to update - :param healthy: Whether the route passed the health check - """ - key = self._get_route_health_status_key(replica_id) - current_time = str(await self._get_redis_time()) - data: Mapping[str | bytes, str | bytes] = { - "replica_id": str(replica_id), - "healthy": "1" if healthy else "0", - "last_check": current_time, - } - batch = Batch(is_atomic=False) - batch.hset(key, data) - batch.expire(key, ROUTE_HEALTH_STATUS_TTL_SEC) - - async with self._client.client() as conn: - await conn.exec(batch, raise_on_error=True) - @valkey_schedule_resilience.apply() async def record_route_health_statuses_batch( self, results: Sequence[ReplicaHealthResult] diff --git a/tests/unit/common/clients/valkey_client/test_valkey_schedule_client.py b/tests/unit/common/clients/valkey_client/test_valkey_schedule_client.py index 02d90f7e044..f61d397ca8c 100644 --- a/tests/unit/common/clients/valkey_client/test_valkey_schedule_client.py +++ b/tests/unit/common/clients/valkey_client/test_valkey_schedule_client.py @@ -22,7 +22,10 @@ HealthCheckStatus, ValkeyScheduleClient, ) -from ai.backend.common.clients.valkey_client.valkey_schedule.types import ReplicaProbeTarget +from ai.backend.common.clients.valkey_client.valkey_schedule.types import ( + ReplicaHealthResult, + ReplicaProbeTarget, +) from ai.backend.common.defs import REDIS_LIVE_DB from ai.backend.common.identifier.replica import ReplicaID from ai.backend.common.typed_validators import HostPortPair as HostPortPairModel @@ -1022,7 +1025,9 @@ async def test_record_healthy_and_get( valkey_schedule_client: ValkeyScheduleClient, replica_id: ReplicaID, ) -> None: - await valkey_schedule_client.record_route_health_status(replica_id, healthy=True) + await valkey_schedule_client.record_route_health_statuses_batch([ + ReplicaHealthResult(replica_id=replica_id, healthy=True) + ]) results = await valkey_schedule_client.get_route_health_statuses_batch([replica_id]) status = results[replica_id] assert status is not None @@ -1034,7 +1039,9 @@ async def test_record_unhealthy_and_get( valkey_schedule_client: ValkeyScheduleClient, replica_id: ReplicaID, ) -> None: - await valkey_schedule_client.record_route_health_status(replica_id, healthy=False) + await valkey_schedule_client.record_route_health_statuses_batch([ + ReplicaHealthResult(replica_id=replica_id, healthy=False) + ]) results = await valkey_schedule_client.get_route_health_statuses_batch([replica_id]) status = results[replica_id] assert status is not None @@ -1061,7 +1068,9 @@ async def test_key_deletion_simulates_ttl_expiry( replica_id: ReplicaID, ) -> None: """Deleting the key (simulating TTL expiry) results in None → DEGRADED.""" - await valkey_schedule_client.record_route_health_status(replica_id, healthy=True) + await valkey_schedule_client.record_route_health_statuses_batch([ + ReplicaHealthResult(replica_id=replica_id, healthy=True) + ]) key = valkey_schedule_client._get_route_health_status_key(replica_id) async with valkey_schedule_client._client.client() as conn: await conn.delete([key]) @@ -1074,7 +1083,9 @@ async def test_record_batch_multiple( ) -> None: replica_ids = [ReplicaID(uuid4()) for _ in range(3)] for rid in replica_ids: - await valkey_schedule_client.record_route_health_status(rid, healthy=True) + await valkey_schedule_client.record_route_health_statuses_batch([ + ReplicaHealthResult(replica_id=rid, healthy=True) + ]) results = await valkey_schedule_client.get_route_health_statuses_batch(replica_ids) assert len(results) == 3 for rid in replica_ids: @@ -1087,8 +1098,12 @@ async def test_record_overwrites_previous( valkey_schedule_client: ValkeyScheduleClient, replica_id: ReplicaID, ) -> None: - await valkey_schedule_client.record_route_health_status(replica_id, healthy=True) - await valkey_schedule_client.record_route_health_status(replica_id, healthy=False) + await valkey_schedule_client.record_route_health_statuses_batch([ + ReplicaHealthResult(replica_id=replica_id, healthy=True) + ]) + await valkey_schedule_client.record_route_health_statuses_batch([ + ReplicaHealthResult(replica_id=replica_id, healthy=False) + ]) results = await valkey_schedule_client.get_route_health_statuses_batch([replica_id]) status = results[replica_id] assert status is not None From 8c7834a3d2e6ab6d72ef76c1a9d34b5e0c8bb24f Mon Sep 17 00:00:00 2001 From: HyeockJinKim Date: Fri, 15 May 2026 15:27:08 +0900 Subject: [PATCH 09/13] fix(BA-6035): fix remaining tests using old DeploymentInfo fields MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Update tests broken by the DeploymentInfo refactor (model_revisions → current_revision/deploying_revision) and RouteCreatorSpec health_check addition. Co-Authored-By: Claude Sonnet 4.6 --- .../deployment/test_deployment_repository.py | 3 +++ .../deployment/test_deployment_crud_actions.py | 1 - .../deployment/test_deployment_service.py | 7 ++----- .../deployment/handlers/test_deploying_handler.py | 11 ++++++----- .../deployment/strategy/test_rolling_update.py | 15 +++++++++++---- 5 files changed, 22 insertions(+), 15 deletions(-) diff --git a/tests/unit/manager/repositories/deployment/test_deployment_repository.py b/tests/unit/manager/repositories/deployment/test_deployment_repository.py index 875f206a39f..17d43027744 100644 --- a/tests/unit/manager/repositories/deployment/test_deployment_repository.py +++ b/tests/unit/manager/repositories/deployment/test_deployment_repository.py @@ -3359,6 +3359,7 @@ async def test_create_route( domain=test_domain_name, project_id=test_group_id, revision_id=DeploymentRevisionID(uuid.uuid4()), + health_check=None, traffic_ratio=1.0, traffic_status=RouteTrafficStatus.ACTIVE, ) @@ -3393,6 +3394,7 @@ async def test_update_route_status( domain=test_domain_name, project_id=test_group_id, revision_id=DeploymentRevisionID(uuid.uuid4()), + health_check=None, ) creator = RBACEntityCreator( spec=spec, @@ -3441,6 +3443,7 @@ async def test_update_route_with_unified_spec( domain=test_domain_name, project_id=test_group_id, revision_id=DeploymentRevisionID(uuid.uuid4()), + health_check=None, ) creator = RBACEntityCreator( spec=spec, diff --git a/tests/unit/manager/services/deployment/test_deployment_crud_actions.py b/tests/unit/manager/services/deployment/test_deployment_crud_actions.py index f71c2a80b46..e3266f8d64d 100644 --- a/tests/unit/manager/services/deployment/test_deployment_crud_actions.py +++ b/tests/unit/manager/services/deployment/test_deployment_crud_actions.py @@ -155,7 +155,6 @@ def endpoint_info(self, endpoint_id: uuid.UUID) -> DeploymentInfo: network=DeploymentNetworkData( open_to_public=False, access_token_ids=None, url=None, preferred_domain_name=None ), - model_revisions=[], options=DeploymentOptions(), ) diff --git a/tests/unit/manager/services/deployment/test_deployment_service.py b/tests/unit/manager/services/deployment/test_deployment_service.py index 6c32576718f..0b31a0cf890 100644 --- a/tests/unit/manager/services/deployment/test_deployment_service.py +++ b/tests/unit/manager/services/deployment/test_deployment_service.py @@ -384,7 +384,6 @@ def deployment_info(self, deployment_id: uuid.UUID) -> DeploymentInfo: network=DeploymentNetworkData( open_to_public=False, access_token_ids=None, url=None, preferred_domain_name=None ), - model_revisions=[], options=DeploymentOptions(), ) @@ -560,7 +559,6 @@ def deployment_info( network=DeploymentNetworkData( open_to_public=False, access_token_ids=None, url=None, preferred_domain_name=None ), - model_revisions=[], options=DeploymentOptions(), ) @@ -726,10 +724,9 @@ def test_current_revision_resolved_by_id_match_not_list_order( network=DeploymentNetworkData( open_to_public=False, access_token_ids=None, url=None, preferred_domain_name=None ), - model_revisions=[deploying_data, current_data], options=DeploymentOptions(), - current_revision_id=DeploymentRevisionID(current_data.id), - deploying_revision_id=DeploymentRevisionID(deploying_data.id), + current_revision=current_data, + deploying_revision=deploying_data, ) deployment_data = _convert_deployment_info_to_data(deployment_info) diff --git a/tests/unit/manager/sokovan/deployment/handlers/test_deploying_handler.py b/tests/unit/manager/sokovan/deployment/handlers/test_deploying_handler.py index 97afe20929a..c18cbfd0825 100644 --- a/tests/unit/manager/sokovan/deployment/handlers/test_deploying_handler.py +++ b/tests/unit/manager/sokovan/deployment/handlers/test_deploying_handler.py @@ -10,6 +10,7 @@ import dataclasses from datetime import datetime +from typing import cast from unittest.mock import AsyncMock, MagicMock from uuid import uuid4 @@ -26,6 +27,7 @@ DeploymentNetworkData, DeploymentOptions, DeploymentState, + ModelRevisionData, ReplicaData, ) from ai.backend.manager.data.resource.types import ScalingGroupProxyTarget @@ -133,9 +135,8 @@ def deployment_created_without_revision(self) -> DeploymentWithHistory: url=None, preferred_domain_name=None, ), - model_revisions=[revision], - current_revision_id=None, - deploying_revision_id=deploying_rev_id, + current_revision=None, + deploying_revision=cast(ModelRevisionData, revision), sub_step=DeploymentLifecycleSubStep.DEPLOYING_PROVISIONING, options=DeploymentOptions(), ), @@ -158,7 +159,7 @@ async def test_registers_endpoint_for_deployment_created_without_revision( dep, revision_id = call_args[0] info = deployment_created_without_revision.deployment_info assert dep is deployment_created_without_revision - assert revision_id == info.deploying_revision_id + assert revision_id == info.deploying_revision.id # type: ignore[union-attr] async def test_deployment_already_with_url_is_not_reregistered( self, @@ -186,7 +187,7 @@ async def test_deployment_without_deploying_revision_is_filtered( ) -> None: """Deployments with no deploying_revision_id must be filtered out.""" info = deployment_created_without_revision.deployment_info - info_no_rev = dataclasses.replace(info, deploying_revision_id=None) + info_no_rev = dataclasses.replace(info, deploying_revision=None) deployment = DeploymentWithHistory(deployment_info=info_no_rev, last_history=None) await handler.execute([deployment]) diff --git a/tests/unit/manager/sokovan/deployment/strategy/test_rolling_update.py b/tests/unit/manager/sokovan/deployment/strategy/test_rolling_update.py index e25cc3ac94a..58b2f348416 100644 --- a/tests/unit/manager/sokovan/deployment/strategy/test_rolling_update.py +++ b/tests/unit/manager/sokovan/deployment/strategy/test_rolling_update.py @@ -13,8 +13,11 @@ from __future__ import annotations +import dataclasses from dataclasses import dataclass from datetime import UTC, datetime +from typing import cast +from unittest.mock import MagicMock from uuid import UUID, uuid4 import pytest @@ -33,6 +36,7 @@ DeploymentNetworkData, DeploymentOptions, DeploymentState, + ModelRevisionData, ReplicaData, RouteHealthStatus, RouteInfo, @@ -101,6 +105,10 @@ def make_deployment( current_revision_id: UUID = OLD_REV, endpoint_id: UUID = ENDPOINT_ID, ) -> DeploymentInfo: + deploying_mock = MagicMock() + deploying_mock.id = DeploymentRevisionID(deploying_revision_id) + current_mock = MagicMock() + current_mock.id = DeploymentRevisionID(current_revision_id) return DeploymentInfo( id=DeploymentID(endpoint_id), metadata=DeploymentMetadata( @@ -125,10 +133,9 @@ def make_deployment( network=DeploymentNetworkData( open_to_public=False, access_token_ids=None, url=None, preferred_domain_name=None ), - model_revisions=[], options=DeploymentOptions(), - current_revision_id=DeploymentRevisionID(current_revision_id), - deploying_revision_id=DeploymentRevisionID(deploying_revision_id), + current_revision=cast(ModelRevisionData, current_mock), + deploying_revision=cast(ModelRevisionData, deploying_mock), ) @@ -745,7 +752,7 @@ def test_all_old_inactive_no_new_creates_desired(self) -> None: def test_deploying_rev_none_rejected(self) -> None: """If deploying_revision_id is None, evaluate_cycle raises.""" - deployment = make_deployment(desired=1, deploying_revision_id=None) # type: ignore[arg-type] + deployment = dataclasses.replace(make_deployment(desired=1), deploying_revision=None) spec = RollingUpdateSpec( max_surge=make_int_or_percent(1), max_unavailable=make_int_or_percent(0) ) From d3dc32b7787814ecd586ba13c2ba7381b680eba8 Mon Sep 17 00:00:00 2001 From: HyeockJinKim Date: Fri, 15 May 2026 15:28:29 +0900 Subject: [PATCH 10/13] fix(BA-6035): remove remaining model_revisions from coordinator history test Co-Authored-By: Claude Sonnet 4.6 --- .../unit/manager/sokovan/deployment/test_coordinator_history.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/manager/sokovan/deployment/test_coordinator_history.py b/tests/unit/manager/sokovan/deployment/test_coordinator_history.py index 3fdf5862279..261ba81575a 100644 --- a/tests/unit/manager/sokovan/deployment/test_coordinator_history.py +++ b/tests/unit/manager/sokovan/deployment/test_coordinator_history.py @@ -71,7 +71,6 @@ def sample_deployment_info() -> DeploymentInfo: url=None, preferred_domain_name=None, ), - model_revisions=[], options=DeploymentOptions(), ) From 3e3d4806a704d1392f7dadcc1759c9ed97dce66c Mon Sep 17 00:00:00 2001 From: HyeockJinKim Date: Fri, 15 May 2026 15:31:44 +0900 Subject: [PATCH 11/13] changelog: add news fragment for PR #11632 --- changes/11632.enhance.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/11632.enhance.md diff --git a/changes/11632.enhance.md b/changes/11632.enhance.md new file mode 100644 index 00000000000..c33f2a9890b --- /dev/null +++ b/changes/11632.enhance.md @@ -0,0 +1 @@ +Split route health Valkey record into ReplicaProbeTarget (probe config) and ReplicaHealthStatus (TTL-based result), removing initial_delay from Valkey and switching to DB-based timeout in check_warming_up_health. From 6f50520d91a84beb61adea7eb592386722007765 Mon Sep 17 00:00:00 2001 From: HyeockJinKim Date: Fri, 15 May 2026 15:39:11 +0900 Subject: [PATCH 12/13] fix(BA-6035): rebase migration c3d4e5f6a7b8 onto b2d4f6e8c1a3 --- .../versions/c3d4e5f6a7b8_add_health_check_to_routings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ai/backend/manager/models/alembic/versions/c3d4e5f6a7b8_add_health_check_to_routings.py b/src/ai/backend/manager/models/alembic/versions/c3d4e5f6a7b8_add_health_check_to_routings.py index 5a975cb9d6b..29a7b6d37c4 100644 --- a/src/ai/backend/manager/models/alembic/versions/c3d4e5f6a7b8_add_health_check_to_routings.py +++ b/src/ai/backend/manager/models/alembic/versions/c3d4e5f6a7b8_add_health_check_to_routings.py @@ -12,7 +12,7 @@ # revision identifiers, used by Alembic. revision = "c3d4e5f6a7b8" -down_revision = "b2c3d4e5f6a7" +down_revision = "b2d4f6e8c1a3" branch_labels = None depends_on = None From 5c102fc1d87ac1d1be672cdfda7273a201d4d6a5 Mon Sep 17 00:00:00 2001 From: HyeockJinKim Date: Fri, 15 May 2026 18:22:20 +0900 Subject: [PATCH 13/13] fix(BA-6035): treat no-health-check RUNNING routes as healthy in rolling update MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add health_check: ModelHealthCheck | None to RouteInfo (populated from RoutingRow) - Rolling update _classify_routes: RUNNING routes without health check count as new_healthy_count so DEPLOYING→READY transition can complete - DB health_status stays DEGRADED (no probe data — correct behaviour) - Add TestNoHealthCheck unit tests - Fix CLI revision add --auto-activate flag (nest under options dict) Co-Authored-By: Claude Sonnet 4.6 --- .../client/cli/v2/deployment/revision.py | 3 +- .../backend/manager/data/deployment/types.py | 3 +- src/ai/backend/manager/models/routing/row.py | 1 + .../deployment/strategy/rolling_update.py | 6 +- .../test_deployment_crud_actions.py | 4 ++ .../strategy/test_rolling_update.py | 71 +++++++++++++++++++ 6 files changed, 85 insertions(+), 3 deletions(-) diff --git a/src/ai/backend/client/cli/v2/deployment/revision.py b/src/ai/backend/client/cli/v2/deployment/revision.py index 99d386c96f1..499168253ce 100644 --- a/src/ai/backend/client/cli/v2/deployment/revision.py +++ b/src/ai/backend/client/cli/v2/deployment/revision.py @@ -49,7 +49,8 @@ def add(deployment_id: str, config: str, preset_id: str | None, auto_activate: b data["deployment_id"] = deployment_id if preset_id is not None: data["revision_preset_id"] = preset_id - data["auto_activate"] = auto_activate + if auto_activate: + data.setdefault("options", {})["auto_activate"] = True body = AddRevisionInput.model_validate(data) async def _run() -> None: diff --git a/src/ai/backend/manager/data/deployment/types.py b/src/ai/backend/manager/data/deployment/types.py index 1e94c790dde..bc85ea5dd36 100644 --- a/src/ai/backend/manager/data/deployment/types.py +++ b/src/ai/backend/manager/data/deployment/types.py @@ -12,7 +12,7 @@ import yarl from pydantic import ConfigDict, Field -from ai.backend.common.config import ModelDefinition, ModelDefinitionDraft +from ai.backend.common.config import ModelDefinition, ModelDefinitionDraft, ModelHealthCheck from ai.backend.common.data.endpoint.types import EndpointLifecycle, ScalingState from ai.backend.common.data.model_deployment.types import ( ActivenessStatus, @@ -802,6 +802,7 @@ class RouteInfo: created_at: datetime revision_id: UUID traffic_status: RouteTrafficStatus + health_check: ModelHealthCheck | None error_data: dict[str, Any] = field(default_factory=dict) @property diff --git a/src/ai/backend/manager/models/routing/row.py b/src/ai/backend/manager/models/routing/row.py index ba4e3e7a825..3b8e6795a89 100644 --- a/src/ai/backend/manager/models/routing/row.py +++ b/src/ai/backend/manager/models/routing/row.py @@ -271,5 +271,6 @@ def to_route_info(self) -> RouteInfo: created_at=self.created_at, revision_id=self.revision, traffic_status=self.traffic_status, + health_check=self.health_check, error_data=self.error_data or {}, ) diff --git a/src/ai/backend/manager/sokovan/deployment/strategy/rolling_update.py b/src/ai/backend/manager/sokovan/deployment/strategy/rolling_update.py index d22579c7030..9cf8cea86f9 100644 --- a/src/ai/backend/manager/sokovan/deployment/strategy/rolling_update.py +++ b/src/ai/backend/manager/sokovan/deployment/strategy/rolling_update.py @@ -129,7 +129,11 @@ def _classify_routes( elif route.status.is_inactive(): classified.new_failed_count += 1 elif route.status == RouteStatus.RUNNING: - if route.health_status == RouteHealthStatus.HEALTHY: + if route.health_status == RouteHealthStatus.HEALTHY or route.health_check is None: + # Routes without a health check have no probe data, so we treat them + # as ready once their process is RUNNING (health_status stays DEGRADED + # in DB because no Valkey entry exists — that is correct behaviour, but + # it must not block the DEPLOYING → READY transition). classified.new_healthy_count += 1 else: # UNHEALTHY / DEGRADED / NOT_CHECKED all count here: diff --git a/tests/unit/manager/services/deployment/test_deployment_crud_actions.py b/tests/unit/manager/services/deployment/test_deployment_crud_actions.py index e3266f8d64d..fec1f73b085 100644 --- a/tests/unit/manager/services/deployment/test_deployment_crud_actions.py +++ b/tests/unit/manager/services/deployment/test_deployment_crud_actions.py @@ -314,6 +314,7 @@ def route_info(self, endpoint_id: uuid.UUID) -> RouteInfo: created_at=datetime(2024, 1, 1, tzinfo=UTC), revision_id=uuid.uuid4(), traffic_status=RouteTrafficStatus.ACTIVE, + health_check=None, ) async def test_existing_replica_returns_data( @@ -364,6 +365,7 @@ async def test_zero_weight_traffic_inactive( created_at=datetime(2024, 1, 1, tzinfo=UTC), revision_id=uuid.uuid4(), traffic_status=RouteTrafficStatus.INACTIVE, + health_check=None, ) mock_deployment_repository.get_route = AsyncMock(return_value=inactive_route) @@ -391,6 +393,7 @@ async def test_unassigned_session_id_is_none( created_at=datetime(2024, 1, 1, tzinfo=UTC), revision_id=uuid.uuid4(), traffic_status=RouteTrafficStatus.ACTIVE, + health_check=None, ) mock_deployment_repository.get_route = AsyncMock(return_value=route) @@ -416,6 +419,7 @@ def route_info(self, endpoint_id: uuid.UUID) -> RouteInfo: created_at=datetime(2024, 1, 1, tzinfo=UTC), revision_id=uuid.uuid4(), traffic_status=RouteTrafficStatus.ACTIVE, + health_check=None, ) async def test_default_pagination( diff --git a/tests/unit/manager/sokovan/deployment/strategy/test_rolling_update.py b/tests/unit/manager/sokovan/deployment/strategy/test_rolling_update.py index 58b2f348416..98f4c7f8f94 100644 --- a/tests/unit/manager/sokovan/deployment/strategy/test_rolling_update.py +++ b/tests/unit/manager/sokovan/deployment/strategy/test_rolling_update.py @@ -23,6 +23,7 @@ import pytest from pydantic import ValidationError +from ai.backend.common.config import ModelHealthCheck from ai.backend.common.data.endpoint.types import EndpointLifecycle, ScalingState from ai.backend.common.dto.manager.v2.deployment.types import IntOrPercent from ai.backend.common.exception import BackendAISchemaValidationFailed @@ -61,6 +62,7 @@ def make_int_or_percent(value: int | float) -> IntOrPercent: OLD_REV = UUID("11111111-1111-1111-1111-111111111111") +_STUB_HEALTH_CHECK = ModelHealthCheck(path="/health", interval=10.0, initial_delay=30.0) NEW_REV = UUID("22222222-2222-2222-2222-222222222222") PROJECT_ID = UUID("cccccccc-cccc-cccc-cccc-cccccccccccc") USER_ID = UUID("dddddddd-dddd-dddd-dddd-dddddddddddd") @@ -146,6 +148,7 @@ def make_route( health_status: RouteHealthStatus = RouteHealthStatus.HEALTHY, endpoint_id: UUID = ENDPOINT_ID, route_id: UUID | None = None, + health_check: ModelHealthCheck | None = _STUB_HEALTH_CHECK, ) -> RouteInfo: return RouteInfo( route_id=route_id or uuid4(), @@ -159,6 +162,7 @@ def make_route( traffic_status=RouteTrafficStatus.ACTIVE if status.is_active() else RouteTrafficStatus.INACTIVE, + health_check=health_check, ) @@ -1132,3 +1136,70 @@ def test_small_fraction_with_few_replicas(self) -> None: result = RollingUpdateStrategy().evaluate_cycle(deployment, routes, spec) assert len(result.route_changes.rollout_specs) == 1 + + +# =========================================================================== +# No-health-check scenario +# =========================================================================== + + +class TestNoHealthCheck: + """Routes without health_check stay DEGRADED in DB but must allow READY transition.""" + + def test_running_degraded_no_health_check_completes(self) -> None: + """RUNNING + DEGRADED + no health_check → counts as healthy → COMPLETED.""" + deployment = make_deployment(desired=1) + spec = RollingUpdateSpec( + max_surge=make_int_or_percent(1), max_unavailable=make_int_or_percent(0) + ) + routes = [ + make_route( + revision_id=NEW_REV, + status=RouteStatus.RUNNING, + health_status=RouteHealthStatus.DEGRADED, + health_check=None, + ) + ] + + result = RollingUpdateStrategy().evaluate_cycle(deployment, routes, spec) + + assert result.sub_step == DeploymentLifecycleSubStep.DEPLOYING_COMPLETED + + def test_running_degraded_with_health_check_does_not_complete(self) -> None: + """RUNNING + DEGRADED + has health_check → still unhealthy → PROVISIONING.""" + deployment = make_deployment(desired=1) + spec = RollingUpdateSpec( + max_surge=make_int_or_percent(1), max_unavailable=make_int_or_percent(0) + ) + routes = [ + make_route( + revision_id=NEW_REV, + status=RouteStatus.RUNNING, + health_status=RouteHealthStatus.DEGRADED, + health_check=_STUB_HEALTH_CHECK, + ) + ] + + result = RollingUpdateStrategy().evaluate_cycle(deployment, routes, spec) + + assert result.sub_step == DeploymentLifecycleSubStep.DEPLOYING_PROVISIONING + + def test_multiple_replicas_no_health_check_completes(self) -> None: + """2 desired, 2 RUNNING DEGRADED no-health-check → COMPLETED.""" + deployment = make_deployment(desired=2) + spec = RollingUpdateSpec( + max_surge=make_int_or_percent(1), max_unavailable=make_int_or_percent(0) + ) + routes = [ + make_route( + revision_id=NEW_REV, + status=RouteStatus.RUNNING, + health_status=RouteHealthStatus.DEGRADED, + health_check=None, + ) + for _ in range(2) + ] + + result = RollingUpdateStrategy().evaluate_cycle(deployment, routes, spec) + + assert result.sub_step == DeploymentLifecycleSubStep.DEPLOYING_COMPLETED