diff --git a/src/opencode_a2a/client/client.py b/src/opencode_a2a/client/client.py index b609193..7b3801a 100644 --- a/src/opencode_a2a/client/client.py +++ b/src/opencode_a2a/client/client.py @@ -57,6 +57,8 @@ def __init__( self._client: Client | None = None self._agent_card: object | None = None self._lock = asyncio.Lock() + self._request_lock = asyncio.Lock() + self._active_requests = 0 async def close(self) -> None: """Close cached client resources and owned HTTP transport.""" @@ -64,6 +66,10 @@ async def close(self) -> None: if self._httpx_client is not None and self._owns_httpx_client: await self._httpx_client.aclose() + def is_busy(self) -> bool: + """Report whether this facade currently has in-flight work.""" + return self._active_requests > 0 + async def get_agent_card(self) -> Any: """Fetch and cache peer Agent Card.""" if self._agent_card is not None: @@ -101,29 +107,33 @@ async def send_message( extensions: list[str] | None = None, ) -> AsyncIterator[A2AClientEvent]: """Send one user message and stream protocol events.""" - client = await self._ensure_client() - request_metadata, extra_headers = split_request_metadata(metadata) - request = self._build_user_message( - text=text, - context_id=context_id, - task_id=task_id, - message_id=message_id, - ) + await self._acquire_operation() try: - async for event in client.send_message( - request, - context=build_call_context(self._settings.bearer_token, extra_headers), - request_metadata=request_metadata, - extensions=extensions, - ): - yield event - except ( - A2AClientHTTPError, - A2AClientJSONRPCError, - httpx.TimeoutException, - httpx.TransportError, - ) as exc: - raise map_operation_error("message/send", exc) from exc + client = await self._ensure_client() + request_metadata, extra_headers = split_request_metadata(metadata) + request = self._build_user_message( + text=text, + context_id=context_id, + task_id=task_id, + message_id=message_id, + ) + try: + async for event in client.send_message( + request, + context=build_call_context(self._settings.bearer_token, extra_headers), + request_metadata=request_metadata, + extensions=extensions, + ): + yield event + except ( + A2AClientHTTPError, + A2AClientJSONRPCError, + httpx.TimeoutException, + httpx.TransportError, + ) as exc: + raise map_operation_error("message/send", exc) from exc + finally: + await self._release_operation() async def send( self, @@ -156,24 +166,28 @@ async def get_task( metadata: Mapping[str, Any] | None = None, ) -> Task: """Fetch one task by id.""" - client = await self._ensure_client() - request_metadata, extra_headers = split_request_metadata(metadata) + await self._acquire_operation() try: - return await client.get_task( - TaskQueryParams( - id=task_id, - history_length=history_length, - metadata=request_metadata or {}, - ), - context=build_call_context(self._settings.bearer_token, extra_headers), - ) - except ( - A2AClientHTTPError, - A2AClientJSONRPCError, - httpx.TimeoutException, - httpx.TransportError, - ) as exc: - raise map_operation_error("tasks/get", exc) from exc + client = await self._ensure_client() + request_metadata, extra_headers = split_request_metadata(metadata) + try: + return await client.get_task( + TaskQueryParams( + id=task_id, + history_length=history_length, + metadata=request_metadata or {}, + ), + context=build_call_context(self._settings.bearer_token, extra_headers), + ) + except ( + A2AClientHTTPError, + A2AClientJSONRPCError, + httpx.TimeoutException, + httpx.TransportError, + ) as exc: + raise map_operation_error("tasks/get", exc) from exc + finally: + await self._release_operation() async def cancel_task( self, @@ -182,20 +196,24 @@ async def cancel_task( metadata: Mapping[str, Any] | None = None, ) -> Task: """Cancel one task by id.""" - client = await self._ensure_client() - request_metadata, extra_headers = split_request_metadata(metadata) + await self._acquire_operation() try: - return await client.cancel_task( - TaskIdParams(id=task_id, metadata=request_metadata or {}), - context=build_call_context(self._settings.bearer_token, extra_headers), - ) - except ( - A2AClientHTTPError, - A2AClientJSONRPCError, - httpx.TimeoutException, - httpx.TransportError, - ) as exc: - raise map_operation_error("tasks/cancel", exc) from exc + client = await self._ensure_client() + request_metadata, extra_headers = split_request_metadata(metadata) + try: + return await client.cancel_task( + TaskIdParams(id=task_id, metadata=request_metadata or {}), + context=build_call_context(self._settings.bearer_token, extra_headers), + ) + except ( + A2AClientHTTPError, + A2AClientJSONRPCError, + httpx.TimeoutException, + httpx.TransportError, + ) as exc: + raise map_operation_error("tasks/cancel", exc) from exc + finally: + await self._release_operation() async def resubscribe_task( self, @@ -204,21 +222,25 @@ async def resubscribe_task( metadata: Mapping[str, Any] | None = None, ) -> AsyncIterator[tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None]]: """Resubscribe to task updates.""" - client = await self._ensure_client() - request_metadata, extra_headers = split_request_metadata(metadata) + await self._acquire_operation() try: - async for event in client.resubscribe( - TaskIdParams(id=task_id, metadata=request_metadata or {}), - context=build_call_context(self._settings.bearer_token, extra_headers), - ): - yield event - except ( - A2AClientHTTPError, - A2AClientJSONRPCError, - httpx.TimeoutException, - httpx.TransportError, - ) as exc: - raise map_operation_error("tasks/resubscribe", exc) from exc + client = await self._ensure_client() + request_metadata, extra_headers = split_request_metadata(metadata) + try: + async for event in client.resubscribe( + TaskIdParams(id=task_id, metadata=request_metadata or {}), + context=build_call_context(self._settings.bearer_token, extra_headers), + ): + yield event + except ( + A2AClientHTTPError, + A2AClientJSONRPCError, + httpx.TimeoutException, + httpx.TransportError, + ) as exc: + raise map_operation_error("tasks/resubscribe", exc) from exc + finally: + await self._release_operation() async def _ensure_client(self) -> Client: async with self._lock: @@ -254,6 +276,15 @@ async def _get_httpx_client(self) -> httpx.AsyncClient: self._httpx_client = httpx.AsyncClient(timeout=self._settings.default_timeout) return self._httpx_client + async def _acquire_operation(self) -> None: + async with self._request_lock: + self._active_requests += 1 + + async def _release_operation(self) -> None: + async with self._request_lock: + if self._active_requests > 0: + self._active_requests -= 1 + def _build_user_message( self, *, diff --git a/src/opencode_a2a/config.py b/src/opencode_a2a/config.py index 4699aad..ad72f65 100644 --- a/src/opencode_a2a/config.py +++ b/src/opencode_a2a/config.py @@ -158,6 +158,16 @@ class Settings(BaseSettings): default=False, alias="A2A_CLIENT_USE_CLIENT_PREFERENCE" ) a2a_client_bearer_token: str | None = Field(default=None, alias="A2A_CLIENT_BEARER_TOKEN") + a2a_client_cache_ttl_seconds: float = Field( + default=900.0, + ge=0.0, + alias="A2A_CLIENT_CACHE_TTL_SECONDS", + ) + a2a_client_cache_maxsize: int = Field( + default=256, + ge=0, + alias="A2A_CLIENT_CACHE_MAXSIZE", + ) a2a_client_supported_transports: DeclaredStringList = Field( default=("JSONRPC", "HTTP+JSON"), alias="A2A_CLIENT_SUPPORTED_TRANSPORTS", diff --git a/src/opencode_a2a/execution/executor.py b/src/opencode_a2a/execution/executor.py index 55b0a48..8e516ed 100644 --- a/src/opencode_a2a/execution/executor.py +++ b/src/opencode_a2a/execution/executor.py @@ -642,14 +642,14 @@ async def _handle_a2a_call_tool(self, part: dict[str, Any]) -> dict[str, Any]: } try: - client = await mgr.get_client(agent_url) event = None result_text = "" - async for current_event in client.send_message(message): - event = current_event - extracted = client.extract_text(current_event) - if extracted: - result_text = self._merge_streamed_tool_output(result_text, extracted) + async with mgr.borrow_client(agent_url) as client: + async for current_event in client.send_message(message): + event = current_event + extracted = client.extract_text(current_event) + if extracted: + result_text = self._merge_streamed_tool_output(result_text, extracted) from a2a.types import Task diff --git a/src/opencode_a2a/server/application.py b/src/opencode_a2a/server/application.py index 2a2ceb7..69af4a5 100644 --- a/src/opencode_a2a/server/application.py +++ b/src/opencode_a2a/server/application.py @@ -340,6 +340,8 @@ async def bearer_auth(request: Request, call_next): class A2AClientManager: def __init__(self, settings: Settings) -> None: + import time + from ..client.config import load_settings as load_client_settings self.client_settings = load_client_settings( @@ -353,21 +355,148 @@ def __init__(self, settings: Settings) -> None: "A2A_CLIENT_SUPPORTED_TRANSPORTS": settings.a2a_client_supported_transports, } ) - self.clients: dict[str, A2AClient] = {} + self._cache_ttl_seconds = float(settings.a2a_client_cache_ttl_seconds) + self._cache_maxsize = int(settings.a2a_client_cache_maxsize) + self._now = time.monotonic + self.clients: dict[str, _ClientCacheEntry] = {} self._lock = asyncio.Lock() - async def get_client(self, agent_url: str) -> A2AClient: + @property + def cache_ttl_seconds(self) -> float: + return self._cache_ttl_seconds + + @property + def cache_maxsize(self) -> int: + return self._cache_maxsize + + @asynccontextmanager + async def borrow_client(self, agent_url: str): + url = agent_url.rstrip("/") + if self._cache_maxsize <= 0: + client = A2AClient(url, settings=self.client_settings) + try: + yield client + finally: + await client.close() + return + + to_close: list[A2AClient] = [] async with self._lock: - url = agent_url.rstrip("/") - if url not in self.clients: - self.clients[url] = A2AClient(url, settings=self.client_settings) - return self.clients[url] + now = self._now() + entry = self.clients.get(url) + if entry is not None and self._entry_expired(entry, now=now): + if self._entry_in_use(entry): + entry.pending_eviction = True + else: + self.clients.pop(url, None) + to_close.append(entry.client) + entry = None + to_close.extend(self._evict_locked(now=now, protected_keys={url})) + if entry is None: + entry = _ClientCacheEntry( + client=A2AClient(url, settings=self.client_settings), + last_used=now, + expires_at=self._expires_at_for(now), + ) + self.clients[url] = entry + else: + entry.last_used = now + entry.expires_at = self._expires_at_for(now) + entry.pending_eviction = False + entry.borrow_count += 1 + to_close.extend(self._evict_locked(now=now, protected_keys={url})) + await self._close_clients(to_close) + + try: + yield entry.client + finally: + async with self._lock: + now = self._now() + current = self.clients.get(url) + if current is entry: + if current.borrow_count > 0: + current.borrow_count -= 1 + current.last_used = now + current.expires_at = self._expires_at_for(now) + to_close = self._evict_locked(now=now) + await self._close_clients(to_close) async def close_all(self) -> None: async with self._lock: - for client in self.clients.values(): - await client.close() + clients = [entry.client for entry in self.clients.values()] self.clients.clear() + for client in clients: + await client.close() + + def _expires_at_for(self, now: float) -> float | None: + if self._cache_ttl_seconds <= 0: + return None + return now + self._cache_ttl_seconds + + def _evict_locked( + self, + *, + now: float, + protected_keys: set[str] | None = None, + ) -> list[A2AClient]: + protected = protected_keys or set() + to_close: list[A2AClient] = [] + + for key, entry in list(self.clients.items()): + expired = self._entry_expired(entry, now=now) + if not expired and not entry.pending_eviction: + continue + if key in protected or self._entry_in_use(entry): + entry.pending_eviction = True + continue + self.clients.pop(key, None) + to_close.append(entry.client) + + if self._cache_maxsize <= 0 or len(self.clients) <= self._cache_maxsize: + return to_close + + if any(entry.pending_eviction for entry in self.clients.values()): + return to_close + + for key, entry in sorted(self.clients.items(), key=lambda item: item[1].last_used): + if len(self.clients) <= self._cache_maxsize: + break + if key in protected: + continue + if self._entry_in_use(entry): + entry.pending_eviction = True + continue + self.clients.pop(key, None) + to_close.append(entry.client) + + return to_close + + def _entry_expired(self, entry: _ClientCacheEntry, *, now: float) -> bool: + return entry.expires_at is not None and entry.expires_at <= now + + def _entry_in_use(self, entry: _ClientCacheEntry) -> bool: + return entry.borrow_count > 0 or entry.client.is_busy() + + async def _close_clients(self, clients: list[A2AClient]) -> None: + for client in clients: + await client.close() + + +class _ClientCacheEntry: + def __init__( + self, + *, + client: A2AClient, + last_used: float, + expires_at: float | None, + borrow_count: int = 0, + pending_eviction: bool = False, + ) -> None: + self.client = client + self.last_used = last_used + self.expires_at = expires_at + self.borrow_count = borrow_count + self.pending_eviction = pending_eviction def create_app(settings: Settings) -> FastAPI: diff --git a/tests/execution/test_opencode_agent_session_binding.py b/tests/execution/test_opencode_agent_session_binding.py index e0105d2..66b01dc 100644 --- a/tests/execution/test_opencode_agent_session_binding.py +++ b/tests/execution/test_opencode_agent_session_binding.py @@ -263,8 +263,17 @@ async def close(self): pass class MockManager: - async def get_client(self, url: str): - return MockA2AClient() + class _BorrowedClient: + async def __aenter__(self): + return MockA2AClient() + + async def __aexit__(self, exc_type, exc, tb): + del exc_type, exc, tb + return False + + def borrow_client(self, url: str): + del url + return self._BorrowedClient() client = DummyChatOpencodeUpstreamClient() manager = MockManager() @@ -322,27 +331,36 @@ async def send_message(self, *args, **kwargs) -> OpencodeMessage: return OpencodeMessage(text="done", session_id="s1", message_id="m2", raw={}) class MockManager: - async def get_client(self, url: str): - mock_client = MagicMock() - - async def _send_message(_text: str): - task = Task(id="t", context_id="c", status=TaskStatus(state=TaskState.working)) - yield ( - task, - TaskArtifactUpdateEvent( - task_id="t", - context_id="c", - artifact=Artifact( - artifact_id="artifact-1", - name="response", - parts=[Part(root=TextPart(text="streamed tool output"))], + class _BorrowedClient: + async def __aenter__(self): + mock_client = MagicMock() + + async def _send_message(_text: str): + task = Task(id="t", context_id="c", status=TaskStatus(state=TaskState.working)) + yield ( + task, + TaskArtifactUpdateEvent( + task_id="t", + context_id="c", + artifact=Artifact( + artifact_id="artifact-1", + name="response", + parts=[Part(root=TextPart(text="streamed tool output"))], + ), ), - ), - ) + ) - mock_client.send_message = _send_message - mock_client.extract_text = A2AClient.extract_text - return mock_client + mock_client.send_message = _send_message + mock_client.extract_text = A2AClient.extract_text + return mock_client + + async def __aexit__(self, exc_type, exc, tb): + del exc_type, exc, tb + return False + + def borrow_client(self, url: str): + del url + return self._BorrowedClient() from unittest.mock import MagicMock @@ -437,9 +455,17 @@ def send_message(self, text: str): return _AuthErrorStream() class MockManager: - async def get_client(self, url: str): + class _BorrowedClient: + async def __aenter__(self): + return MockA2AClient() + + async def __aexit__(self, exc_type, exc, tb): + del exc_type, exc, tb + return False + + def borrow_client(self, url: str): del url - return MockA2AClient() + return self._BorrowedClient() client = DummyChatOpencodeUpstreamClient() executor = OpencodeAgentExecutor( diff --git a/tests/server/test_a2a_client_manager.py b/tests/server/test_a2a_client_manager.py new file mode 100644 index 0000000..82e769e --- /dev/null +++ b/tests/server/test_a2a_client_manager.py @@ -0,0 +1,199 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from opencode_a2a.server import application as app_module + + +def _make_settings(**overrides: object) -> SimpleNamespace: + values = { + "a2a_client_timeout_seconds": 30.0, + "a2a_client_card_fetch_timeout_seconds": 5.0, + "a2a_client_use_client_preference": False, + "a2a_client_bearer_token": None, + "a2a_client_supported_transports": ("JSONRPC", "HTTP+JSON"), + "a2a_client_cache_ttl_seconds": 60.0, + "a2a_client_cache_maxsize": 2, + } + values.update(overrides) + return SimpleNamespace(**values) + + +@pytest.mark.asyncio +async def test_client_manager_evicts_lru_idle_clients(monkeypatch: pytest.MonkeyPatch) -> None: + created: list[_FakeClient] = [] + + class _FakeClient: + def __init__(self, agent_url: str, *, settings) -> None: + del settings + self.agent_url = agent_url + self.closed = False + self.busy = False + created.append(self) + + def is_busy(self) -> bool: + return self.busy + + async def close(self) -> None: + self.closed = True + + monkeypatch.setattr(app_module, "A2AClient", _FakeClient) + + manager = app_module.A2AClientManager(_make_settings(a2a_client_cache_maxsize=2)) + + async with manager.borrow_client("http://peer-1"): + pass + async with manager.borrow_client("http://peer-2"): + pass + async with manager.borrow_client("http://peer-3"): + pass + + assert set(manager.clients) == {"http://peer-2", "http://peer-3"} + assert created[0].closed is True + assert created[1].closed is False + assert created[2].closed is False + + +@pytest.mark.asyncio +async def test_client_manager_defers_busy_client_eviction(monkeypatch: pytest.MonkeyPatch) -> None: + created: list[_FakeClient] = [] + + class _FakeClient: + def __init__(self, agent_url: str, *, settings) -> None: + del settings + self.agent_url = agent_url + self.closed = False + self.busy = False + created.append(self) + + def is_busy(self) -> bool: + return self.busy + + async def close(self) -> None: + self.closed = True + + monkeypatch.setattr(app_module, "A2AClient", _FakeClient) + + manager = app_module.A2AClientManager(_make_settings(a2a_client_cache_maxsize=1)) + + async with manager.borrow_client("http://peer-1") as first_client: + first_client.busy = True + async with manager.borrow_client("http://peer-2"): + pass + assert set(manager.clients) == {"http://peer-1", "http://peer-2"} + assert first_client.closed is False + first_client.busy = False + + assert set(manager.clients) == {"http://peer-2"} + assert created[0].closed is True + assert created[1].closed is False + + +@pytest.mark.asyncio +async def test_client_manager_preserves_borrowed_client_before_operation_starts( + monkeypatch: pytest.MonkeyPatch, +) -> None: + created: list[_FakeClient] = [] + + class _FakeClient: + def __init__(self, agent_url: str, *, settings) -> None: + del settings + self.agent_url = agent_url + self.closed = False + created.append(self) + + def is_busy(self) -> bool: + return False + + async def close(self) -> None: + self.closed = True + + monkeypatch.setattr(app_module, "A2AClient", _FakeClient) + + manager = app_module.A2AClientManager(_make_settings(a2a_client_cache_maxsize=1)) + + async with manager.borrow_client("http://peer-1"): + async with manager.borrow_client("http://peer-2"): + pass + assert set(manager.clients) == {"http://peer-1", "http://peer-2"} + assert created[0].closed is False + + assert set(manager.clients) == {"http://peer-2"} + assert created[0].closed is True + assert created[1].closed is False + + +@pytest.mark.asyncio +async def test_client_manager_evicts_expired_clients(monkeypatch: pytest.MonkeyPatch) -> None: + created: list[_FakeClient] = [] + + class _FakeClient: + def __init__(self, agent_url: str, *, settings) -> None: + del settings + self.agent_url = agent_url + self.closed = False + self.busy = False + created.append(self) + + def is_busy(self) -> bool: + return self.busy + + async def close(self) -> None: + self.closed = True + + monkeypatch.setattr(app_module, "A2AClient", _FakeClient) + + now = 100.0 + manager = app_module.A2AClientManager(_make_settings(a2a_client_cache_ttl_seconds=10.0)) + manager._now = lambda: now + + async with manager.borrow_client("http://peer-1"): + pass + + now = 111.0 + async with manager.borrow_client("http://peer-2"): + pass + + assert set(manager.clients) == {"http://peer-2"} + assert created[0].closed is True + assert created[1].closed is False + + +@pytest.mark.asyncio +async def test_client_manager_rebuilds_expired_entry_for_same_url( + monkeypatch: pytest.MonkeyPatch, +) -> None: + created: list[_FakeClient] = [] + + class _FakeClient: + def __init__(self, agent_url: str, *, settings) -> None: + del settings + self.agent_url = agent_url + self.closed = False + self.busy = False + created.append(self) + + def is_busy(self) -> bool: + return self.busy + + async def close(self) -> None: + self.closed = True + + monkeypatch.setattr(app_module, "A2AClient", _FakeClient) + + now = 100.0 + manager = app_module.A2AClientManager(_make_settings(a2a_client_cache_ttl_seconds=10.0)) + manager._now = lambda: now + + async with manager.borrow_client("http://peer-1") as first_client: + pass + + now = 111.0 + async with manager.borrow_client("http://peer-1") as second_client: + pass + + assert first_client is not second_client + assert created[0].closed is True + assert created[1].closed is False diff --git a/tests/server/test_transport_contract.py b/tests/server/test_transport_contract.py index 5a74ce8..e15612d 100644 --- a/tests/server/test_transport_contract.py +++ b/tests/server/test_transport_contract.py @@ -539,6 +539,8 @@ def test_create_app_propagates_outbound_client_settings(monkeypatch) -> None: a2a_client_card_fetch_timeout_seconds=7.0, a2a_client_use_client_preference=True, a2a_client_bearer_token="peer-token", + a2a_client_cache_ttl_seconds=321.0, + a2a_client_cache_maxsize=12, a2a_client_supported_transports=("http-json", "json-rpc"), ) ) @@ -550,6 +552,8 @@ def test_create_app_propagates_outbound_client_settings(monkeypatch) -> None: assert settings.use_client_preference is True assert settings.bearer_token == "peer-token" assert settings.supported_transports == ("HTTP+JSON", "JSONRPC") + assert client_manager.cache_ttl_seconds == 321.0 + assert client_manager.cache_maxsize == 12 def test_create_app_requires_control_guard_hooks(monkeypatch) -> None: