Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 97 additions & 66 deletions src/opencode_a2a/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,19 @@ def __init__(
self._client: Client | None = None
self._agent_card: object | None = None
self._lock = asyncio.Lock()
self._request_lock = asyncio.Lock()
self._active_requests = 0

async def close(self) -> None:
"""Close cached client resources and owned HTTP transport."""
self._client = None
if self._httpx_client is not None and self._owns_httpx_client:
await self._httpx_client.aclose()

def is_busy(self) -> bool:
"""Report whether this facade currently has in-flight work."""
return self._active_requests > 0

async def get_agent_card(self) -> Any:
"""Fetch and cache peer Agent Card."""
if self._agent_card is not None:
Expand Down Expand Up @@ -101,29 +107,33 @@ async def send_message(
extensions: list[str] | None = None,
) -> AsyncIterator[A2AClientEvent]:
"""Send one user message and stream protocol events."""
client = await self._ensure_client()
request_metadata, extra_headers = split_request_metadata(metadata)
request = self._build_user_message(
text=text,
context_id=context_id,
task_id=task_id,
message_id=message_id,
)
await self._acquire_operation()
try:
async for event in client.send_message(
request,
context=build_call_context(self._settings.bearer_token, extra_headers),
request_metadata=request_metadata,
extensions=extensions,
):
yield event
except (
A2AClientHTTPError,
A2AClientJSONRPCError,
httpx.TimeoutException,
httpx.TransportError,
) as exc:
raise map_operation_error("message/send", exc) from exc
client = await self._ensure_client()
request_metadata, extra_headers = split_request_metadata(metadata)
request = self._build_user_message(
text=text,
context_id=context_id,
task_id=task_id,
message_id=message_id,
)
try:
async for event in client.send_message(
request,
context=build_call_context(self._settings.bearer_token, extra_headers),
request_metadata=request_metadata,
extensions=extensions,
):
yield event
except (
A2AClientHTTPError,
A2AClientJSONRPCError,
httpx.TimeoutException,
httpx.TransportError,
) as exc:
raise map_operation_error("message/send", exc) from exc
finally:
await self._release_operation()

async def send(
self,
Expand Down Expand Up @@ -156,24 +166,28 @@ async def get_task(
metadata: Mapping[str, Any] | None = None,
) -> Task:
"""Fetch one task by id."""
client = await self._ensure_client()
request_metadata, extra_headers = split_request_metadata(metadata)
await self._acquire_operation()
try:
return await client.get_task(
TaskQueryParams(
id=task_id,
history_length=history_length,
metadata=request_metadata or {},
),
context=build_call_context(self._settings.bearer_token, extra_headers),
)
except (
A2AClientHTTPError,
A2AClientJSONRPCError,
httpx.TimeoutException,
httpx.TransportError,
) as exc:
raise map_operation_error("tasks/get", exc) from exc
client = await self._ensure_client()
request_metadata, extra_headers = split_request_metadata(metadata)
try:
return await client.get_task(
TaskQueryParams(
id=task_id,
history_length=history_length,
metadata=request_metadata or {},
),
context=build_call_context(self._settings.bearer_token, extra_headers),
)
except (
A2AClientHTTPError,
A2AClientJSONRPCError,
httpx.TimeoutException,
httpx.TransportError,
) as exc:
raise map_operation_error("tasks/get", exc) from exc
finally:
await self._release_operation()

async def cancel_task(
self,
Expand All @@ -182,20 +196,24 @@ async def cancel_task(
metadata: Mapping[str, Any] | None = None,
) -> Task:
"""Cancel one task by id."""
client = await self._ensure_client()
request_metadata, extra_headers = split_request_metadata(metadata)
await self._acquire_operation()
try:
return await client.cancel_task(
TaskIdParams(id=task_id, metadata=request_metadata or {}),
context=build_call_context(self._settings.bearer_token, extra_headers),
)
except (
A2AClientHTTPError,
A2AClientJSONRPCError,
httpx.TimeoutException,
httpx.TransportError,
) as exc:
raise map_operation_error("tasks/cancel", exc) from exc
client = await self._ensure_client()
request_metadata, extra_headers = split_request_metadata(metadata)
try:
return await client.cancel_task(
TaskIdParams(id=task_id, metadata=request_metadata or {}),
context=build_call_context(self._settings.bearer_token, extra_headers),
)
except (
A2AClientHTTPError,
A2AClientJSONRPCError,
httpx.TimeoutException,
httpx.TransportError,
) as exc:
raise map_operation_error("tasks/cancel", exc) from exc
finally:
await self._release_operation()

async def resubscribe_task(
self,
Expand All @@ -204,21 +222,25 @@ async def resubscribe_task(
metadata: Mapping[str, Any] | None = None,
) -> AsyncIterator[tuple[Task, TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None]]:
"""Resubscribe to task updates."""
client = await self._ensure_client()
request_metadata, extra_headers = split_request_metadata(metadata)
await self._acquire_operation()
try:
async for event in client.resubscribe(
TaskIdParams(id=task_id, metadata=request_metadata or {}),
context=build_call_context(self._settings.bearer_token, extra_headers),
):
yield event
except (
A2AClientHTTPError,
A2AClientJSONRPCError,
httpx.TimeoutException,
httpx.TransportError,
) as exc:
raise map_operation_error("tasks/resubscribe", exc) from exc
client = await self._ensure_client()
request_metadata, extra_headers = split_request_metadata(metadata)
try:
async for event in client.resubscribe(
TaskIdParams(id=task_id, metadata=request_metadata or {}),
context=build_call_context(self._settings.bearer_token, extra_headers),
):
yield event
except (
A2AClientHTTPError,
A2AClientJSONRPCError,
httpx.TimeoutException,
httpx.TransportError,
) as exc:
raise map_operation_error("tasks/resubscribe", exc) from exc
finally:
await self._release_operation()

async def _ensure_client(self) -> Client:
async with self._lock:
Expand Down Expand Up @@ -254,6 +276,15 @@ async def _get_httpx_client(self) -> httpx.AsyncClient:
self._httpx_client = httpx.AsyncClient(timeout=self._settings.default_timeout)
return self._httpx_client

async def _acquire_operation(self) -> None:
async with self._request_lock:
self._active_requests += 1

async def _release_operation(self) -> None:
async with self._request_lock:
if self._active_requests > 0:
self._active_requests -= 1

def _build_user_message(
self,
*,
Expand Down
10 changes: 10 additions & 0 deletions src/opencode_a2a/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,16 @@ class Settings(BaseSettings):
default=False, alias="A2A_CLIENT_USE_CLIENT_PREFERENCE"
)
a2a_client_bearer_token: str | None = Field(default=None, alias="A2A_CLIENT_BEARER_TOKEN")
a2a_client_cache_ttl_seconds: float = Field(
default=900.0,
ge=0.0,
alias="A2A_CLIENT_CACHE_TTL_SECONDS",
)
a2a_client_cache_maxsize: int = Field(
default=256,
ge=0,
alias="A2A_CLIENT_CACHE_MAXSIZE",
)
a2a_client_supported_transports: DeclaredStringList = Field(
default=("JSONRPC", "HTTP+JSON"),
alias="A2A_CLIENT_SUPPORTED_TRANSPORTS",
Expand Down
12 changes: 6 additions & 6 deletions src/opencode_a2a/execution/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,14 +642,14 @@ async def _handle_a2a_call_tool(self, part: dict[str, Any]) -> dict[str, Any]:
}

try:
client = await mgr.get_client(agent_url)
event = None
result_text = ""
async for current_event in client.send_message(message):
event = current_event
extracted = client.extract_text(current_event)
if extracted:
result_text = self._merge_streamed_tool_output(result_text, extracted)
async with mgr.borrow_client(agent_url) as client:
async for current_event in client.send_message(message):
event = current_event
extracted = client.extract_text(current_event)
if extracted:
result_text = self._merge_streamed_tool_output(result_text, extracted)

from a2a.types import Task

Expand Down
Loading
Loading