Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions docs/guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ Key variables to understand protocol behavior:
- `A2A_CLIENT_BEARER_TOKEN`: optional bearer token attached to outbound peer
calls made by the embedded A2A client and `a2a_call` tool path.
- `A2A_CLIENT_SUPPORTED_TRANSPORTS`: ordered outbound transport preference list.
- `A2A_TASK_STORE_BACKEND`: task store backend. Supported values: `memory`,
`database`. Default: `memory`.
- `A2A_TASK_STORE_DATABASE_URL`: database URL used when
`A2A_TASK_STORE_BACKEND=database`. For local persistence, prefer
`sqlite+aiosqlite:///./opencode-a2a.db`.
- `A2A_TASK_STORE_TABLE_NAME` / `A2A_TASK_STORE_CREATE_TABLE`: database task
store table name and whether to auto-create database tables on startup.
- Runtime authentication is bearer-token only via `A2A_BEARER_TOKEN`.
- The same outbound client flags are also honored by the server-side embedded
A2A client used for peer calls and `a2a_call` tool execution:
Expand Down Expand Up @@ -157,6 +164,26 @@ OPENCODE_WORKSPACE_ROOT=/abs/path/to/workspace \
opencode-a2a
```

To persist A2A task records across restarts, switch the task store backend to
SQLite:

```bash
OPENCODE_BASE_URL=http://127.0.0.1:4096 \
A2A_BEARER_TOKEN=dev-token \
A2A_TASK_STORE_BACKEND=database \
A2A_TASK_STORE_DATABASE_URL=sqlite+aiosqlite:///./opencode-a2a.db \
opencode-a2a
```

When `A2A_TASK_STORE_BACKEND=database`, the service now persists:

- task records
- session binding / ownership state
- interrupt request bindings and tombstones

In-flight asyncio locks, outbound A2A client caches, and stream-local
aggregation buffers remain process-local runtime state.

## Troubleshooting Provider Auth State

If one deployment works while another fails against the same upstream provider,
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ classifiers = [
]
dependencies = [
"a2a-sdk==0.3.25",
"aiosqlite>=0.20,<1.0",
"fastapi>=0.110,<1.0",
"httpx>=0.27,<1.0",
"pydantic>=2.6,<3.0",
"pydantic-settings>=2.2,<3.0",
"sqlalchemy>=2.0,<3.0",
"sse-starlette>=2.1,<4.0",
"uvicorn>=0.29,<1.0",
]
Expand Down
24 changes: 24 additions & 0 deletions src/opencode_a2a/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"custom",
]
OutsideWorkspaceAccess = Literal["unknown", "allowed", "disallowed", "custom"]
TaskStoreBackend = Literal["memory", "database"]


def _parse_declared_list(value: Any) -> tuple[str, ...]:
Expand Down Expand Up @@ -176,7 +177,30 @@ class Settings(BaseSettings):
alias="A2A_CLIENT_SUPPORTED_TRANSPORTS",
)

# Task store settings
a2a_task_store_backend: TaskStoreBackend = Field(
default="memory",
alias="A2A_TASK_STORE_BACKEND",
)
a2a_task_store_database_url: str | None = Field(
default=None,
alias="A2A_TASK_STORE_DATABASE_URL",
)
a2a_task_store_table_name: str = Field(
default="tasks",
min_length=1,
alias="A2A_TASK_STORE_TABLE_NAME",
)
a2a_task_store_create_table: bool = Field(
default=True,
alias="A2A_TASK_STORE_CREATE_TABLE",
)

@model_validator(mode="after")
def _validate_sandbox_policy(self) -> Settings:
SandboxPolicy.from_settings(self).validate_configuration()
if self.a2a_task_store_backend == "database" and not self.a2a_task_store_database_url:
raise ValueError(
"A2A_TASK_STORE_DATABASE_URL is required when A2A_TASK_STORE_BACKEND=database"
)
return self
3 changes: 3 additions & 0 deletions src/opencode_a2a/execution/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

if TYPE_CHECKING:
from ..server.application import A2AClientManager
from ..server.state_store import SessionStateRepository

import httpx
from a2a.server.agent_execution import AgentExecutor, RequestContext
Expand Down Expand Up @@ -531,6 +532,7 @@ def __init__(
session_cache_ttl_seconds: int = 3600,
session_cache_maxsize: int = 10_000,
a2a_client_manager: A2AClientManager | None = None,
session_state_repository: SessionStateRepository | None = None,
) -> None:
self._client = client
self._streaming_enabled = streaming_enabled
Expand All @@ -544,6 +546,7 @@ def __init__(
client=client,
session_cache_ttl_seconds=session_cache_ttl_seconds,
session_cache_maxsize=session_cache_maxsize,
state_repository=session_state_repository,
)
self._stream_runtime = StreamRuntime(
client=client,
Expand Down
67 changes: 43 additions & 24 deletions src/opencode_a2a/execution/session_manager.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import asyncio
from typing import Any, cast

from .stream_state import _TTLCache
from ..server.state_store import MemorySessionStateRepository, SessionStateRepository


class SessionManager:
Expand All @@ -12,18 +13,21 @@ def __init__(
client,
session_cache_ttl_seconds: int = 3600,
session_cache_maxsize: int = 10_000,
state_repository: SessionStateRepository | None = None,
) -> None:
self._client = client
self._sessions = _TTLCache(
self._state_repository = state_repository or MemorySessionStateRepository(
ttl_seconds=session_cache_ttl_seconds,
maxsize=session_cache_maxsize,
)
self._session_owners = _TTLCache(
ttl_seconds=session_cache_ttl_seconds,
maxsize=session_cache_maxsize,
refresh_on_get=True,
)
self._pending_session_claims: dict[str, str] = {}
if isinstance(self._state_repository, MemorySessionStateRepository):
self._sessions = self._state_repository.sessions
self._session_owners = self._state_repository.session_owners
self._pending_session_claims = self._state_repository.pending_session_claims
else:
self._sessions = cast("Any", None)
self._session_owners = cast("Any", None)
self._pending_session_claims = cast("Any", None)
self._lock = asyncio.Lock()
self._inflight_session_creates: dict[tuple[str, str], asyncio.Task[str]] = {}
self._session_locks: dict[str, asyncio.Lock] = {}
Expand All @@ -49,7 +53,10 @@ async def get_or_create_session(
task: asyncio.Task[str] | None = None
cache_key = (identity, context_id)
async with self._lock:
existing = self._sessions.get(cache_key)
existing = await self._state_repository.get_session(
identity=cache_key[0],
context_id=cache_key[1],
)
if existing:
return existing, False
task = self._inflight_session_creates.get(cache_key)
Expand All @@ -68,14 +75,18 @@ async def get_or_create_session(
raise

async with self._lock:
owner = self._session_owners.get(session_id)
owner = await self._state_repository.get_owner(session_id=session_id)
if owner and owner != identity:
if self._inflight_session_creates.get(cache_key) is task:
self._inflight_session_creates.pop(cache_key, None)
raise PermissionError(f"Session {session_id} is not owned by you")
self._sessions.set(cache_key, session_id)
await self._state_repository.set_session(
identity=cache_key[0],
context_id=cache_key[1],
session_id=session_id,
)
if not owner:
self._session_owners.set(session_id, identity)
await self._state_repository.set_owner(session_id=session_id, identity=identity)
if self._inflight_session_creates.get(cache_key) is task:
self._inflight_session_creates.pop(cache_key, None)
return session_id, False
Expand All @@ -89,37 +100,45 @@ async def finalize_preferred_session_binding(
) -> None:
await self.finalize_session_claim(identity=identity, session_id=session_id)
async with self._lock:
self._sessions.set((identity, context_id), session_id)
await self._state_repository.set_session(
identity=identity,
context_id=context_id,
session_id=session_id,
)

async def claim_preferred_session(self, *, identity: str, session_id: str) -> bool:
async with self._lock:
owner = self._session_owners.get(session_id)
pending_owner = self._pending_session_claims.get(session_id)
owner = await self._state_repository.get_owner(session_id=session_id)
pending_owner = await self._state_repository.get_pending_claim(session_id=session_id)
if owner and owner != identity:
raise PermissionError(f"Session {session_id} is not owned by you")
if pending_owner and pending_owner != identity:
raise PermissionError(f"Session {session_id} is not owned by you")
if owner == identity:
return False
self._pending_session_claims[session_id] = identity
await self._state_repository.set_pending_claim(session_id=session_id, identity=identity)
return True

async def finalize_session_claim(self, *, identity: str, session_id: str) -> None:
async with self._lock:
owner = self._session_owners.get(session_id)
pending_owner = self._pending_session_claims.get(session_id)
owner = await self._state_repository.get_owner(session_id=session_id)
pending_owner = await self._state_repository.get_pending_claim(session_id=session_id)
if owner and owner != identity:
raise PermissionError(f"Session {session_id} is not owned by you")
if pending_owner and pending_owner != identity:
raise PermissionError(f"Session {session_id} is not owned by you")
self._session_owners.set(session_id, identity)
if self._pending_session_claims.get(session_id) == identity:
self._pending_session_claims.pop(session_id, None)
await self._state_repository.set_owner(session_id=session_id, identity=identity)
await self._state_repository.clear_pending_claim(
session_id=session_id,
identity=identity,
)

async def release_preferred_session_claim(self, *, identity: str, session_id: str) -> None:
async with self._lock:
if self._pending_session_claims.get(session_id) == identity:
self._pending_session_claims.pop(session_id, None)
await self._state_repository.clear_pending_claim(
session_id=session_id,
identity=identity,
)

async def get_session_lock(self, session_id: str) -> asyncio.Lock:
async with self._lock:
Expand All @@ -136,5 +155,5 @@ async def pop_cached_session(
context_id: str,
) -> asyncio.Task[str] | None:
async with self._lock:
self._sessions.pop((identity, context_id))
await self._state_repository.pop_session(identity=identity, context_id=context_id)
return self._inflight_session_creates.pop((identity, context_id), None)
4 changes: 2 additions & 2 deletions src/opencode_a2a/execution/stream_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def _tool_chunks(
None,
)
if callable(remember_request):
remember_request(
await remember_request(
request_id=request_id,
session_id=session_id,
interrupt_type=asked["interrupt_type"],
Expand All @@ -456,7 +456,7 @@ def _tool_chunks(
None,
)
if callable(discard_request):
discard_request(resolved_request_id)
await discard_request(resolved_request_id)
if cleared_pending:
await _emit_interrupt_status(
state=TaskState.working,
Expand Down
8 changes: 4 additions & 4 deletions src/opencode_a2a/jsonrpc/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ async def _handle_interrupt_callback_request(
)
resolve_request = getattr(self._upstream_client, "resolve_interrupt_request", None)
if callable(resolve_request):
status, binding = resolve_request(request_id)
status, binding = await resolve_request(request_id)
if status != "active" or binding is None:
return self._generate_error_response(
base_request.id,
Expand Down Expand Up @@ -818,7 +818,7 @@ async def _handle_interrupt_callback_request(
else:
resolve_session = getattr(self._upstream_client, "resolve_interrupt_session", None)
if callable(resolve_session):
if not resolve_session(request_id):
if not await resolve_session(request_id):
return self._generate_error_response(
base_request.id,
interrupt_not_found_error(
Expand Down Expand Up @@ -869,7 +869,7 @@ async def _handle_interrupt_callback_request(
await self._upstream_client.question_reject(request_id, directory=directory)
discard_request = getattr(self._upstream_client, "discard_interrupt_request", None)
if callable(discard_request):
discard_request(request_id)
await discard_request(request_id)
except ValueError as exc:
return self._generate_error_response(
base_request.id,
Expand All @@ -880,7 +880,7 @@ async def _handle_interrupt_callback_request(
if upstream_status == 404:
discard_request = getattr(self._upstream_client, "discard_interrupt_request", None)
if callable(discard_request):
discard_request(request_id)
await discard_request(request_id)
return self._generate_error_response(
base_request.id,
interrupt_not_found_error(
Expand Down
Loading
Loading