From 10573d5b69dc308b1ce83b349432901ef89b4311 Mon Sep 17 00:00:00 2001 From: "helen@cloud" Date: Mon, 23 Mar 2026 10:04:42 -0400 Subject: [PATCH 1/4] feat(server): add database-backed task store backend (#304) --- docs/guide.md | 22 +++++ pyproject.toml | 2 + src/opencode_a2a/config.py | 24 +++++ src/opencode_a2a/server/application.py | 6 +- src/opencode_a2a/server/task_store.py | 87 ++++++++++++++++++ tests/server/test_task_store_factory.py | 56 ++++++++++++ tests/server/test_transport_contract.py | 33 +++++++ uv.lock | 113 ++++++++++++++++++++++++ 8 files changed, 341 insertions(+), 2 deletions(-) create mode 100644 src/opencode_a2a/server/task_store.py create mode 100644 tests/server/test_task_store_factory.py diff --git a/docs/guide.md b/docs/guide.md index 300014c..026ce9c 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -76,6 +76,13 @@ Key variables to understand protocol behavior: - `A2A_CLIENT_BEARER_TOKEN`: optional bearer token attached to outbound peer calls made by the embedded A2A client and `a2a_call` tool path. - `A2A_CLIENT_SUPPORTED_TRANSPORTS`: ordered outbound transport preference list. +- `A2A_TASK_STORE_BACKEND`: task store backend. Supported values: `memory`, + `database`. Default: `memory`. +- `A2A_TASK_STORE_DATABASE_URL`: database URL used when + `A2A_TASK_STORE_BACKEND=database`. For local persistence, prefer + `sqlite+aiosqlite:///./opencode-a2a.db`. +- `A2A_TASK_STORE_TABLE_NAME` / `A2A_TASK_STORE_CREATE_TABLE`: database task + store table name and whether to auto-create it on startup. - Runtime authentication is bearer-token only via `A2A_BEARER_TOKEN`. - The same outbound client flags are also honored by the server-side embedded A2A client used for peer calls and `a2a_call` tool execution: @@ -157,6 +164,21 @@ OPENCODE_WORKSPACE_ROOT=/abs/path/to/workspace \ opencode-a2a ``` +To persist A2A task records across restarts, switch the task store backend to +SQLite: + +```bash +OPENCODE_BASE_URL=http://127.0.0.1:4096 \ +A2A_BEARER_TOKEN=dev-token \ +A2A_TASK_STORE_BACKEND=database \ +A2A_TASK_STORE_DATABASE_URL=sqlite+aiosqlite:///./opencode-a2a.db \ +opencode-a2a +``` + +At the moment, this database-backed store persists task records only. +Session binding state and interrupt request bindings remain in-process runtime +state and are not yet persisted. + ## Troubleshooting Provider Auth State If one deployment works while another fails against the same upstream provider, diff --git a/pyproject.toml b/pyproject.toml index 9effa1b..a4d7b7d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,10 +25,12 @@ classifiers = [ ] dependencies = [ "a2a-sdk==0.3.25", + "aiosqlite>=0.20,<1.0", "fastapi>=0.110,<1.0", "httpx>=0.27,<1.0", "pydantic>=2.6,<3.0", "pydantic-settings>=2.2,<3.0", + "sqlalchemy>=2.0,<3.0", "sse-starlette>=2.1,<4.0", "uvicorn>=0.29,<1.0", ] diff --git a/src/opencode_a2a/config.py b/src/opencode_a2a/config.py index 1638274..d82043f 100644 --- a/src/opencode_a2a/config.py +++ b/src/opencode_a2a/config.py @@ -35,6 +35,7 @@ "custom", ] OutsideWorkspaceAccess = Literal["unknown", "allowed", "disallowed", "custom"] +TaskStoreBackend = Literal["memory", "database"] def _parse_declared_list(value: Any) -> tuple[str, ...]: @@ -176,7 +177,30 @@ class Settings(BaseSettings): alias="A2A_CLIENT_SUPPORTED_TRANSPORTS", ) + # Task store settings + a2a_task_store_backend: TaskStoreBackend = Field( + default="memory", + alias="A2A_TASK_STORE_BACKEND", + ) + a2a_task_store_database_url: str | None = Field( + default=None, + alias="A2A_TASK_STORE_DATABASE_URL", + ) + a2a_task_store_table_name: str = Field( + default="tasks", + min_length=1, + alias="A2A_TASK_STORE_TABLE_NAME", + ) + a2a_task_store_create_table: bool = Field( + default=True, + alias="A2A_TASK_STORE_CREATE_TABLE", + ) + @model_validator(mode="after") def _validate_sandbox_policy(self) -> Settings: SandboxPolicy.from_settings(self).validate_configuration() + if self.a2a_task_store_backend == "database" and not self.a2a_task_store_database_url: + raise ValueError( + "A2A_TASK_STORE_DATABASE_URL is required when A2A_TASK_STORE_BACKEND=database" + ) return self diff --git a/src/opencode_a2a/server/application.py b/src/opencode_a2a/server/application.py index f741bf6..f14c2aa 100644 --- a/src/opencode_a2a/server/application.py +++ b/src/opencode_a2a/server/application.py @@ -18,7 +18,6 @@ TERMINAL_TASK_STATES, DefaultRequestHandler, ) -from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore from a2a.types import ( Task, TaskIdParams, @@ -79,6 +78,7 @@ _request_body_too_large_response, _RequestBodyTooLargeError, ) +from .task_store import build_task_store, close_task_store, initialize_task_store logger = logging.getLogger(__name__) @@ -499,7 +499,7 @@ def create_app(settings: Settings) -> FastAPI: session_cache_maxsize=settings.a2a_session_cache_maxsize, a2a_client_manager=client_manager, ) - task_store = InMemoryTaskStore() + task_store = build_task_store(settings) handler = OpencodeRequestHandler( agent_executor=executor, task_store=task_store, @@ -549,7 +549,9 @@ def create_app(settings: Settings) -> FastAPI: @asynccontextmanager async def lifespan(_app: FastAPI): + await initialize_task_store(task_store) yield + await close_task_store(task_store) await client_manager.close_all() await upstream_client.close() diff --git a/src/opencode_a2a/server/task_store.py b/src/opencode_a2a/server/task_store.py new file mode 100644 index 0000000..05a42e9 --- /dev/null +++ b/src/opencode_a2a/server/task_store.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, cast + +from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore +from a2a.server.tasks.task_store import TaskStore + +from ..config import Settings + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncEngine + +_CUSTOM_TASK_MODELS: dict[str, object] = {} + + +class _ConfiguredDatabaseTaskStore(TaskStore): + def __init__( + self, + *, + engine: "AsyncEngine", + create_table: bool, + table_name: str, + ) -> None: + from a2a.server.models import TaskModel, create_task_model + from a2a.server.tasks.database_task_store import DatabaseTaskStore + from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + + self._delegate = DatabaseTaskStore.__new__(DatabaseTaskStore) + self._delegate.engine = engine + self._delegate.async_session_maker = async_sessionmaker( + engine, + expire_on_commit=False, + class_=AsyncSession, + ) + self._delegate.create_table = create_table + self._delegate._initialized = False + if table_name == "tasks": + self._delegate.task_model = TaskModel + else: + task_model = cast("type[TaskModel] | None", _CUSTOM_TASK_MODELS.get(table_name)) + if task_model is None: + task_model = create_task_model(table_name) + _CUSTOM_TASK_MODELS[table_name] = task_model + self._delegate.task_model = task_model + + @property + def engine(self) -> "AsyncEngine": + return self._delegate.engine + + async def initialize(self) -> None: + await self._delegate.initialize() + + async def save(self, task, context=None) -> None: # noqa: ANN001 + await self._delegate.save(task, context) + + async def get(self, task_id, context=None): # noqa: ANN001 + return await self._delegate.get(task_id, context) + + async def delete(self, task_id, context=None) -> None: # noqa: ANN001 + await self._delegate.delete(task_id, context) + + +def build_task_store(settings: Settings) -> TaskStore: + if settings.a2a_task_store_backend == "memory": + return InMemoryTaskStore() + + from sqlalchemy.ext.asyncio import create_async_engine + + database_url = cast(str, settings.a2a_task_store_database_url) + engine = create_async_engine(database_url) + return _ConfiguredDatabaseTaskStore( + engine=engine, + create_table=settings.a2a_task_store_create_table, + table_name=settings.a2a_task_store_table_name, + ) + + +async def initialize_task_store(task_store: TaskStore) -> None: + initialize = getattr(task_store, "initialize", None) + if callable(initialize): + await initialize() + + +async def close_task_store(task_store: TaskStore) -> None: + engine = cast("AsyncEngine | None", getattr(task_store, "engine", None)) + if engine is not None: + await engine.dispose() diff --git a/tests/server/test_task_store_factory.py b/tests/server/test_task_store_factory.py new file mode 100644 index 0000000..2f418a9 --- /dev/null +++ b/tests/server/test_task_store_factory.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest +from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore +from a2a.types import Task, TaskState, TaskStatus + +from opencode_a2a.server.task_store import ( + build_task_store, + close_task_store, + initialize_task_store, +) +from tests.support.helpers import make_settings + + +def _task(task_id: str, *, context_id: str = "ctx-1") -> Task: + return Task( + id=task_id, + contextId=context_id, + status=TaskStatus(state=TaskState.working), + ) + + +def test_build_task_store_defaults_to_memory_backend() -> None: + store = build_task_store(make_settings(a2a_bearer_token="test-token")) + + assert isinstance(store, InMemoryTaskStore) + + +@pytest.mark.asyncio +async def test_database_task_store_persists_tasks_across_rebuilds(tmp_path: Path) -> None: + database_path = tmp_path / "tasks.db" + database_url = f"sqlite+aiosqlite:///{database_path}" + settings = make_settings( + a2a_bearer_token="test-token", + a2a_task_store_backend="database", + a2a_task_store_database_url=database_url, + a2a_task_store_table_name="tasks_test", + ) + + writer = build_task_store(settings) + await initialize_task_store(writer) + await writer.save(_task("task-1")) + await close_task_store(writer) + + reader = build_task_store(settings) + await initialize_task_store(reader) + restored = await reader.get("task-1") + + assert restored is not None + assert restored.id == "task-1" + assert restored.context_id == "ctx-1" + assert restored.status.state == TaskState.working + + await close_task_store(reader) diff --git a/tests/server/test_transport_contract.py b/tests/server/test_transport_contract.py index 065bbe2..a76d7ec 100644 --- a/tests/server/test_transport_contract.py +++ b/tests/server/test_transport_contract.py @@ -591,3 +591,36 @@ async def cancel(self, _context, _event_queue) -> None: # noqa: ANN001 with pytest.raises(ValueError, match="Control methods require guard hooks"): app_module.create_app(make_settings(a2a_bearer_token="test-token")) + + +def test_create_app_builds_configured_task_store(monkeypatch) -> None: + import opencode_a2a.server.application as app_module + + captured: dict[str, object] = {} + + def _build_task_store(settings): # noqa: ANN001 + captured["backend"] = settings.a2a_task_store_backend + captured["database_url"] = settings.a2a_task_store_database_url + captured["table_name"] = settings.a2a_task_store_table_name + captured["create_table"] = settings.a2a_task_store_create_table + return MagicMock() + + monkeypatch.setattr(app_module, "OpencodeUpstreamClient", DummyChatOpencodeUpstreamClient) + monkeypatch.setattr(app_module, "build_task_store", _build_task_store) + + app_module.create_app( + make_settings( + a2a_bearer_token="test-token", + a2a_task_store_backend="database", + a2a_task_store_database_url="sqlite+aiosqlite:///./test.db", + a2a_task_store_table_name="a2a_tasks", + a2a_task_store_create_table=False, + ) + ) + + assert captured == { + "backend": "database", + "database_url": "sqlite+aiosqlite:///./test.db", + "table_name": "a2a_tasks", + "create_table": False, + } diff --git a/uv.lock b/uv.lock index d385103..c520ab9 100644 --- a/uv.lock +++ b/uv.lock @@ -22,6 +22,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bd/f9/6a62520b7ecb945188a6e1192275f4732ff9341cd4629bc975a6c146aeab/a2a_sdk-0.3.25-py3-none-any.whl", hash = "sha256:2fce38faea82eb0b6f9f9c2bcf761b0d78612c80ef0e599b50d566db1b2654b5", size = 149609, upload-time = "2026-03-10T13:08:44.7Z" }, ] +[[package]] +name = "aiosqlite" +version = "0.22.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4e/8a/64761f4005f17809769d23e518d915db74e6310474e733e3593cfc854ef1/aiosqlite-0.22.1.tar.gz", hash = "sha256:043e0bd78d32888c0a9ca90fc788b38796843360c855a7262a532813133a0650", size = 14821, upload-time = "2025-12-23T19:25:43.997Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/b7/e3bf5133d697a08128598c8d0abc5e16377b51465a33756de24fa7dee953/aiosqlite-0.22.1-py3-none-any.whl", hash = "sha256:21c002eb13823fad740196c5a2e9d8e62f6243bd9e7e4a1f87fb5e44ecb4fceb", size = 17405, upload-time = "2025-12-23T19:25:42.139Z" }, +] + [[package]] name = "annotated-doc" version = "0.0.4" @@ -396,6 +405,53 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c4/ab/09169d5a4612a5f92490806649ac8d41e3ec9129c636754575b3553f4ea4/googleapis_common_protos-1.72.0-py3-none-any.whl", hash = "sha256:4299c5a82d5ae1a9702ada957347726b167f9f8d1fc352477702a1e851ff4038", size = 297515, upload-time = "2025-11-06T18:29:13.14Z" }, ] +[[package]] +name = "greenlet" +version = "3.3.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a3/51/1664f6b78fc6ebbd98019a1fd730e83fa78f2db7058f72b1463d3612b8db/greenlet-3.3.2.tar.gz", hash = "sha256:2eaf067fc6d886931c7962e8c6bede15d2f01965560f3359b27c80bde2d151f2", size = 188267, upload-time = "2026-02-20T20:54:15.531Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f3/47/16400cb42d18d7a6bb46f0626852c1718612e35dcb0dffa16bbaffdf5dd2/greenlet-3.3.2-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:c56692189a7d1c7606cb794be0a8381470d95c57ce5be03fb3d0ef57c7853b86", size = 278890, upload-time = "2026-02-20T20:19:39.263Z" }, + { url = "https://files.pythonhosted.org/packages/a3/90/42762b77a5b6aa96cd8c0e80612663d39211e8ae8a6cd47c7f1249a66262/greenlet-3.3.2-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1ebd458fa8285960f382841da585e02201b53a5ec2bac6b156fc623b5ce4499f", size = 581120, upload-time = "2026-02-20T20:47:30.161Z" }, + { url = "https://files.pythonhosted.org/packages/bf/6f/f3d64f4fa0a9c7b5c5b3c810ff1df614540d5aa7d519261b53fba55d4df9/greenlet-3.3.2-cp311-cp311-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a443358b33c4ec7b05b79a7c8b466f5d275025e750298be7340f8fc63dff2a55", size = 594363, upload-time = "2026-02-20T20:55:56.965Z" }, + { url = "https://files.pythonhosted.org/packages/72/83/3e06a52aca8128bdd4dcd67e932b809e76a96ab8c232a8b025b2850264c5/greenlet-3.3.2-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8e2cd90d413acbf5e77ae41e5d3c9b3ac1d011a756d7284d7f3f2b806bbd6358", size = 594156, upload-time = "2026-02-20T20:20:59.955Z" }, + { url = "https://files.pythonhosted.org/packages/70/79/0de5e62b873e08fe3cef7dbe84e5c4bc0e8ed0c7ff131bccb8405cd107c8/greenlet-3.3.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:442b6057453c8cb29b4fb36a2ac689382fc71112273726e2423f7f17dc73bf99", size = 1554649, upload-time = "2026-02-20T20:49:32.293Z" }, + { url = "https://files.pythonhosted.org/packages/5a/00/32d30dee8389dc36d42170a9c66217757289e2afb0de59a3565260f38373/greenlet-3.3.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:45abe8eb6339518180d5a7fa47fa01945414d7cca5ecb745346fc6a87d2750be", size = 1619472, upload-time = "2026-02-20T20:21:07.966Z" }, + { url = "https://files.pythonhosted.org/packages/f1/3a/efb2cf697fbccdf75b24e2c18025e7dfa54c4f31fab75c51d0fe79942cef/greenlet-3.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:1e692b2dae4cc7077cbb11b47d258533b48c8fde69a33d0d8a82e2fe8d8531d5", size = 230389, upload-time = "2026-02-20T20:17:18.772Z" }, + { url = "https://files.pythonhosted.org/packages/e1/a1/65bbc059a43a7e2143ec4fc1f9e3f673e04f9c7b371a494a101422ac4fd5/greenlet-3.3.2-cp311-cp311-win_arm64.whl", hash = "sha256:02b0a8682aecd4d3c6c18edf52bc8e51eacdd75c8eac52a790a210b06aa295fd", size = 229645, upload-time = "2026-02-20T20:18:18.695Z" }, + { url = "https://files.pythonhosted.org/packages/ea/ab/1608e5a7578e62113506740b88066bf09888322a311cff602105e619bd87/greenlet-3.3.2-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:ac8d61d4343b799d1e526db579833d72f23759c71e07181c2d2944e429eb09cd", size = 280358, upload-time = "2026-02-20T20:17:43.971Z" }, + { url = "https://files.pythonhosted.org/packages/a5/23/0eae412a4ade4e6623ff7626e38998cb9b11e9ff1ebacaa021e4e108ec15/greenlet-3.3.2-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3ceec72030dae6ac0c8ed7591b96b70410a8be370b6a477b1dbc072856ad02bd", size = 601217, upload-time = "2026-02-20T20:47:31.462Z" }, + { url = "https://files.pythonhosted.org/packages/f8/16/5b1678a9c07098ecb9ab2dd159fafaf12e963293e61ee8d10ecb55273e5e/greenlet-3.3.2-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a2a5be83a45ce6188c045bcc44b0ee037d6a518978de9a5d97438548b953a1ac", size = 611792, upload-time = "2026-02-20T20:55:58.423Z" }, + { url = "https://files.pythonhosted.org/packages/50/1f/5155f55bd71cabd03765a4aac9ac446be129895271f73872c36ebd4b04b6/greenlet-3.3.2-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:43e99d1749147ac21dde49b99c9abffcbc1e2d55c67501465ef0930d6e78e070", size = 613875, upload-time = "2026-02-20T20:21:01.102Z" }, + { url = "https://files.pythonhosted.org/packages/fc/dd/845f249c3fcd69e32df80cdab059b4be8b766ef5830a3d0aa9d6cad55beb/greenlet-3.3.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4c956a19350e2c37f2c48b336a3afb4bff120b36076d9d7fb68cb44e05d95b79", size = 1571467, upload-time = "2026-02-20T20:49:33.495Z" }, + { url = "https://files.pythonhosted.org/packages/2a/50/2649fe21fcc2b56659a452868e695634722a6655ba245d9f77f5656010bf/greenlet-3.3.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:6c6f8ba97d17a1e7d664151284cb3315fc5f8353e75221ed4324f84eb162b395", size = 1640001, upload-time = "2026-02-20T20:21:09.154Z" }, + { url = "https://files.pythonhosted.org/packages/9b/40/cc802e067d02af8b60b6771cea7d57e21ef5e6659912814babb42b864713/greenlet-3.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:34308836d8370bddadb41f5a7ce96879b72e2fdfb4e87729330c6ab52376409f", size = 231081, upload-time = "2026-02-20T20:17:28.121Z" }, + { url = "https://files.pythonhosted.org/packages/58/2e/fe7f36ff1982d6b10a60d5e0740c759259a7d6d2e1dc41da6d96de32fff6/greenlet-3.3.2-cp312-cp312-win_arm64.whl", hash = "sha256:d3a62fa76a32b462a97198e4c9e99afb9ab375115e74e9a83ce180e7a496f643", size = 230331, upload-time = "2026-02-20T20:17:23.34Z" }, + { url = "https://files.pythonhosted.org/packages/ac/48/f8b875fa7dea7dd9b33245e37f065af59df6a25af2f9561efa8d822fde51/greenlet-3.3.2-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:aa6ac98bdfd716a749b84d4034486863fd81c3abde9aa3cf8eff9127981a4ae4", size = 279120, upload-time = "2026-02-20T20:19:01.9Z" }, + { url = "https://files.pythonhosted.org/packages/49/8d/9771d03e7a8b1ee456511961e1b97a6d77ae1dea4a34a5b98eee706689d3/greenlet-3.3.2-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ab0c7e7901a00bc0a7284907273dc165b32e0d109a6713babd04471327ff7986", size = 603238, upload-time = "2026-02-20T20:47:32.873Z" }, + { url = "https://files.pythonhosted.org/packages/59/0e/4223c2bbb63cd5c97f28ffb2a8aee71bdfb30b323c35d409450f51b91e3e/greenlet-3.3.2-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:d248d8c23c67d2291ffd47af766e2a3aa9fa1c6703155c099feb11f526c63a92", size = 614219, upload-time = "2026-02-20T20:55:59.817Z" }, + { url = "https://files.pythonhosted.org/packages/7a/34/259b28ea7a2a0c904b11cd36c79b8cef8019b26ee5dbe24e73b469dea347/greenlet-3.3.2-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b6997d360a4e6a4e936c0f9625b1c20416b8a0ea18a8e19cabbefc712e7397ab", size = 616774, upload-time = "2026-02-20T20:21:02.454Z" }, + { url = "https://files.pythonhosted.org/packages/0a/03/996c2d1689d486a6e199cb0f1cf9e4aa940c500e01bdf201299d7d61fa69/greenlet-3.3.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:64970c33a50551c7c50491671265d8954046cb6e8e2999aacdd60e439b70418a", size = 1571277, upload-time = "2026-02-20T20:49:34.795Z" }, + { url = "https://files.pythonhosted.org/packages/d9/c4/2570fc07f34a39f2caf0bf9f24b0a1a0a47bc2e8e465b2c2424821389dfc/greenlet-3.3.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1a9172f5bf6bd88e6ba5a84e0a68afeac9dc7b6b412b245dd64f52d83c81e55b", size = 1640455, upload-time = "2026-02-20T20:21:10.261Z" }, + { url = "https://files.pythonhosted.org/packages/91/39/5ef5aa23bc545aa0d31e1b9b55822b32c8da93ba657295840b6b34124009/greenlet-3.3.2-cp313-cp313-win_amd64.whl", hash = "sha256:a7945dd0eab63ded0a48e4dcade82939783c172290a7903ebde9e184333ca124", size = 230961, upload-time = "2026-02-20T20:16:58.461Z" }, + { url = "https://files.pythonhosted.org/packages/62/6b/a89f8456dcb06becff288f563618e9f20deed8dd29beea14f9a168aef64b/greenlet-3.3.2-cp313-cp313-win_arm64.whl", hash = "sha256:394ead29063ee3515b4e775216cb756b2e3b4a7e55ae8fd884f17fa579e6b327", size = 230221, upload-time = "2026-02-20T20:17:37.152Z" }, + { url = "https://files.pythonhosted.org/packages/3f/ae/8bffcbd373b57a5992cd077cbe8858fff39110480a9d50697091faea6f39/greenlet-3.3.2-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:8d1658d7291f9859beed69a776c10822a0a799bc4bfe1bd4272bb60e62507dab", size = 279650, upload-time = "2026-02-20T20:18:00.783Z" }, + { url = "https://files.pythonhosted.org/packages/d1/c0/45f93f348fa49abf32ac8439938726c480bd96b2a3c6f4d949ec0124b69f/greenlet-3.3.2-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:18cb1b7337bca281915b3c5d5ae19f4e76d35e1df80f4ad3c1a7be91fadf1082", size = 650295, upload-time = "2026-02-20T20:47:34.036Z" }, + { url = "https://files.pythonhosted.org/packages/b3/de/dd7589b3f2b8372069ab3e4763ea5329940fc7ad9dcd3e272a37516d7c9b/greenlet-3.3.2-cp314-cp314-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c2e47408e8ce1c6f1ceea0dffcdf6ebb85cc09e55c7af407c99f1112016e45e9", size = 662163, upload-time = "2026-02-20T20:56:01.295Z" }, + { url = "https://files.pythonhosted.org/packages/d2/d8/09bfa816572a4d83bccd6750df1926f79158b1c36c5f73786e26dbe4ee38/greenlet-3.3.2-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:63d10328839d1973e5ba35e98cccbca71b232b14051fd957b6f8b6e8e80d0506", size = 664160, upload-time = "2026-02-20T20:21:04.015Z" }, + { url = "https://files.pythonhosted.org/packages/48/cf/56832f0c8255d27f6c35d41b5ec91168d74ec721d85f01a12131eec6b93c/greenlet-3.3.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:8e4ab3cfb02993c8cc248ea73d7dae6cec0253e9afa311c9b37e603ca9fad2ce", size = 1619181, upload-time = "2026-02-20T20:49:36.052Z" }, + { url = "https://files.pythonhosted.org/packages/0a/23/b90b60a4aabb4cec0796e55f25ffbfb579a907c3898cd2905c8918acaa16/greenlet-3.3.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:94ad81f0fd3c0c0681a018a976e5c2bd2ca2d9d94895f23e7bb1af4e8af4e2d5", size = 1687713, upload-time = "2026-02-20T20:21:11.684Z" }, + { url = "https://files.pythonhosted.org/packages/f3/ca/2101ca3d9223a1dc125140dbc063644dca76df6ff356531eb27bc267b446/greenlet-3.3.2-cp314-cp314-win_amd64.whl", hash = "sha256:8c4dd0f3997cf2512f7601563cc90dfb8957c0cff1e3a1b23991d4ea1776c492", size = 232034, upload-time = "2026-02-20T20:20:08.186Z" }, + { url = "https://files.pythonhosted.org/packages/f6/4a/ecf894e962a59dea60f04877eea0fd5724618da89f1867b28ee8b91e811f/greenlet-3.3.2-cp314-cp314-win_arm64.whl", hash = "sha256:cd6f9e2bbd46321ba3bbb4c8a15794d32960e3b0ae2cc4d49a1a53d314805d71", size = 231437, upload-time = "2026-02-20T20:18:59.722Z" }, + { url = "https://files.pythonhosted.org/packages/98/6d/8f2ef704e614bcf58ed43cfb8d87afa1c285e98194ab2cfad351bf04f81e/greenlet-3.3.2-cp314-cp314t-macosx_11_0_universal2.whl", hash = "sha256:e26e72bec7ab387ac80caa7496e0f908ff954f31065b0ffc1f8ecb1338b11b54", size = 286617, upload-time = "2026-02-20T20:19:29.856Z" }, + { url = "https://files.pythonhosted.org/packages/5e/0d/93894161d307c6ea237a43988f27eba0947b360b99ac5239ad3fe09f0b47/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8b466dff7a4ffda6ca975979bab80bdadde979e29fc947ac3be4451428d8b0e4", size = 655189, upload-time = "2026-02-20T20:47:35.742Z" }, + { url = "https://files.pythonhosted.org/packages/f5/2c/d2d506ebd8abcb57386ec4f7ba20f4030cbe56eae541bc6fd6ef399c0b41/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:b8bddc5b73c9720bea487b3bffdb1840fe4e3656fba3bd40aa1489e9f37877ff", size = 658225, upload-time = "2026-02-20T20:56:02.527Z" }, + { url = "https://files.pythonhosted.org/packages/8e/30/3a09155fbf728673a1dea713572d2d31159f824a37c22da82127056c44e4/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b26b0f4428b871a751968285a1ac9648944cea09807177ac639b030bddebcea4", size = 657907, upload-time = "2026-02-20T20:21:05.259Z" }, + { url = "https://files.pythonhosted.org/packages/f3/fd/d05a4b7acd0154ed758797f0a43b4c0962a843bedfe980115e842c5b2d08/greenlet-3.3.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:1fb39a11ee2e4d94be9a76671482be9398560955c9e568550de0224e41104727", size = 1618857, upload-time = "2026-02-20T20:49:37.309Z" }, + { url = "https://files.pythonhosted.org/packages/6f/e1/50ee92a5db521de8f35075b5eff060dd43d39ebd46c2181a2042f7070385/greenlet-3.3.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:20154044d9085151bc309e7689d6f7ba10027f8f5a8c0676ad398b951913d89e", size = 1680010, upload-time = "2026-02-20T20:21:13.427Z" }, + { url = "https://files.pythonhosted.org/packages/29/4b/45d90626aef8e65336bed690106d1382f7a43665e2249017e9527df8823b/greenlet-3.3.2-cp314-cp314t-win_amd64.whl", hash = "sha256:c04c5e06ec3e022cbfe2cd4a846e1d4e50087444f875ff6d2c2ad8445495cf1a", size = 237086, upload-time = "2026-02-20T20:20:45.786Z" }, +] + [[package]] name = "h11" version = "0.16.0" @@ -690,10 +746,12 @@ name = "opencode-a2a" source = { editable = "." } dependencies = [ { name = "a2a-sdk" }, + { name = "aiosqlite" }, { name = "fastapi" }, { name = "httpx" }, { name = "pydantic" }, { name = "pydantic-settings" }, + { name = "sqlalchemy" }, { name = "sse-starlette" }, { name = "uvicorn" }, ] @@ -712,6 +770,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "a2a-sdk", specifier = "==0.3.25" }, + { name = "aiosqlite", specifier = ">=0.20,<1.0" }, { name = "fastapi", specifier = ">=0.110,<1.0" }, { name = "httpx", specifier = ">=0.27,<1.0" }, { name = "mypy", marker = "extra == 'dev'", specifier = ">=1.19.1,<2.0" }, @@ -723,6 +782,7 @@ requires-dist = [ { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23,<2.0" }, { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=7.0.0,<8.0.0" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.5,<0.16" }, + { name = "sqlalchemy", specifier = ">=2.0,<3.0" }, { name = "sse-starlette", specifier = ">=2.1,<4.0" }, { name = "uvicorn", specifier = ">=0.29,<1.0" }, ] @@ -1229,6 +1289,59 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/32/46/9cb0e58b2deb7f82b84065f37f3bffeb12413f947f9388e4cac22c4621ce/sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0", size = 29575, upload-time = "2021-05-16T22:03:41.177Z" }, ] +[[package]] +name = "sqlalchemy" +version = "2.0.48" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "greenlet", marker = "platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1f/73/b4a9737255583b5fa858e0bb8e116eb94b88c910164ed2ed719147bde3de/sqlalchemy-2.0.48.tar.gz", hash = "sha256:5ca74f37f3369b45e1f6b7b06afb182af1fd5dde009e4ffd831830d98cbe5fe7", size = 9886075, upload-time = "2026-03-02T15:28:51.474Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d7/6d/b8b78b5b80f3c3ab3f7fa90faa195ec3401f6d884b60221260fd4d51864c/sqlalchemy-2.0.48-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1b4c575df7368b3b13e0cebf01d4679f9a28ed2ae6c1cd0b1d5beffb6b2007dc", size = 2157184, upload-time = "2026-03-02T15:38:28.161Z" }, + { url = "https://files.pythonhosted.org/packages/21/4b/4f3d4a43743ab58b95b9ddf5580a265b593d017693df9e08bd55780af5bb/sqlalchemy-2.0.48-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e83e3f959aaa1c9df95c22c528096d94848a1bc819f5d0ebf7ee3df0ca63db6c", size = 3313555, upload-time = "2026-03-02T15:58:57.21Z" }, + { url = "https://files.pythonhosted.org/packages/21/dd/3b7c53f1dbbf736fd27041aee68f8ac52226b610f914085b1652c2323442/sqlalchemy-2.0.48-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6f7b7243850edd0b8b97043f04748f31de50cf426e939def5c16bedb540698f7", size = 3313057, upload-time = "2026-03-02T15:52:29.366Z" }, + { url = "https://files.pythonhosted.org/packages/d9/cc/3e600a90ae64047f33313d7d32e5ad025417f09d2ded487e8284b5e21a15/sqlalchemy-2.0.48-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:82745b03b4043e04600a6b665cb98697c4339b24e34d74b0a2ac0a2488b6f94d", size = 3265431, upload-time = "2026-03-02T15:58:59.096Z" }, + { url = "https://files.pythonhosted.org/packages/8b/19/780138dacfe3f5024f4cf96e4005e91edf6653d53d3673be4844578faf1d/sqlalchemy-2.0.48-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:e5e088bf43f6ee6fec7dbf1ef7ff7774a616c236b5c0cb3e00662dd71a56b571", size = 3287646, upload-time = "2026-03-02T15:52:31.569Z" }, + { url = "https://files.pythonhosted.org/packages/40/fd/f32ced124f01a23151f4777e4c705f3a470adc7bd241d9f36a7c941a33bf/sqlalchemy-2.0.48-cp311-cp311-win32.whl", hash = "sha256:9c7d0a77e36b5f4b01ca398482230ab792061d243d715299b44a0b55c89fe617", size = 2116956, upload-time = "2026-03-02T15:46:54.535Z" }, + { url = "https://files.pythonhosted.org/packages/58/d5/dd767277f6feef12d05651538f280277e661698f617fa4d086cce6055416/sqlalchemy-2.0.48-cp311-cp311-win_amd64.whl", hash = "sha256:583849c743e0e3c9bb7446f5b5addeacedc168d657a69b418063dfdb2d90081c", size = 2141627, upload-time = "2026-03-02T15:46:55.849Z" }, + { url = "https://files.pythonhosted.org/packages/ef/91/a42ae716f8925e9659df2da21ba941f158686856107a61cc97a95e7647a3/sqlalchemy-2.0.48-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:348174f228b99f33ca1f773e85510e08927620caa59ffe7803b37170df30332b", size = 2155737, upload-time = "2026-03-02T15:49:13.207Z" }, + { url = "https://files.pythonhosted.org/packages/b9/52/f75f516a1f3888f027c1cfb5d22d4376f4b46236f2e8669dcb0cddc60275/sqlalchemy-2.0.48-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:53667b5f668991e279d21f94ccfa6e45b4e3f4500e7591ae59a8012d0f010dcb", size = 3337020, upload-time = "2026-03-02T15:50:34.547Z" }, + { url = "https://files.pythonhosted.org/packages/37/9a/0c28b6371e0cdcb14f8f1930778cb3123acfcbd2c95bb9cf6b4a2ba0cce3/sqlalchemy-2.0.48-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:34634e196f620c7a61d18d5cf7dc841ca6daa7961aed75d532b7e58b309ac894", size = 3349983, upload-time = "2026-03-02T15:53:25.542Z" }, + { url = "https://files.pythonhosted.org/packages/1c/46/0aee8f3ff20b1dcbceb46ca2d87fcc3d48b407925a383ff668218509d132/sqlalchemy-2.0.48-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:546572a1793cc35857a2ffa1fe0e58571af1779bcc1ffa7c9fb0839885ed69a9", size = 3279690, upload-time = "2026-03-02T15:50:36.277Z" }, + { url = "https://files.pythonhosted.org/packages/ce/8c/a957bc91293b49181350bfd55e6dfc6e30b7f7d83dc6792d72043274a390/sqlalchemy-2.0.48-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:07edba08061bc277bfdc772dd2a1a43978f5a45994dd3ede26391b405c15221e", size = 3314738, upload-time = "2026-03-02T15:53:27.519Z" }, + { url = "https://files.pythonhosted.org/packages/4b/44/1d257d9f9556661e7bdc83667cc414ba210acfc110c82938cb3611eea58f/sqlalchemy-2.0.48-cp312-cp312-win32.whl", hash = "sha256:908a3fa6908716f803b86896a09a2c4dde5f5ce2bb07aacc71ffebb57986ce99", size = 2115546, upload-time = "2026-03-02T15:54:31.591Z" }, + { url = "https://files.pythonhosted.org/packages/f2/af/c3c7e1f3a2b383155a16454df62ae8c62a30dd238e42e68c24cebebbfae6/sqlalchemy-2.0.48-cp312-cp312-win_amd64.whl", hash = "sha256:68549c403f79a8e25984376480959975212a670405e3913830614432b5daa07a", size = 2142484, upload-time = "2026-03-02T15:54:34.072Z" }, + { url = "https://files.pythonhosted.org/packages/d1/c6/569dc8bf3cd375abc5907e82235923e986799f301cd79a903f784b996fca/sqlalchemy-2.0.48-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e3070c03701037aa418b55d36532ecb8f8446ed0135acb71c678dbdf12f5b6e4", size = 2152599, upload-time = "2026-03-02T15:49:14.41Z" }, + { url = "https://files.pythonhosted.org/packages/6d/ff/f4e04a4bd5a24304f38cb0d4aa2ad4c0fb34999f8b884c656535e1b2b74c/sqlalchemy-2.0.48-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2645b7d8a738763b664a12a1542c89c940daa55196e8d73e55b169cc5c99f65f", size = 3278825, upload-time = "2026-03-02T15:50:38.269Z" }, + { url = "https://files.pythonhosted.org/packages/fe/88/cb59509e4668d8001818d7355d9995be90c321313078c912420603a7cb95/sqlalchemy-2.0.48-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b19151e76620a412c2ac1c6f977ab1b9fa7ad43140178345136456d5265b32ed", size = 3295200, upload-time = "2026-03-02T15:53:29.366Z" }, + { url = "https://files.pythonhosted.org/packages/87/dc/1609a4442aefd750ea2f32629559394ec92e89ac1d621a7f462b70f736ff/sqlalchemy-2.0.48-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:5b193a7e29fd9fa56e502920dca47dffe60f97c863494946bd698c6058a55658", size = 3226876, upload-time = "2026-03-02T15:50:39.802Z" }, + { url = "https://files.pythonhosted.org/packages/37/c3/6ae2ab5ea2fa989fbac4e674de01224b7a9d744becaf59bb967d62e99bed/sqlalchemy-2.0.48-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:36ac4ddc3d33e852da9cb00ffb08cea62ca05c39711dc67062ca2bb1fae35fd8", size = 3265045, upload-time = "2026-03-02T15:53:31.421Z" }, + { url = "https://files.pythonhosted.org/packages/6f/82/ea4665d1bb98c50c19666e672f21b81356bd6077c4574e3d2bbb84541f53/sqlalchemy-2.0.48-cp313-cp313-win32.whl", hash = "sha256:389b984139278f97757ea9b08993e7b9d1142912e046ab7d82b3fbaeb0209131", size = 2113700, upload-time = "2026-03-02T15:54:35.825Z" }, + { url = "https://files.pythonhosted.org/packages/b7/2b/b9040bec58c58225f073f5b0c1870defe1940835549dafec680cbd58c3c3/sqlalchemy-2.0.48-cp313-cp313-win_amd64.whl", hash = "sha256:d612c976cbc2d17edfcc4c006874b764e85e990c29ce9bd411f926bbfb02b9a2", size = 2139487, upload-time = "2026-03-02T15:54:37.079Z" }, + { url = "https://files.pythonhosted.org/packages/f4/f4/7b17bd50244b78a49d22cc63c969d71dc4de54567dc152a9b46f6fae40ce/sqlalchemy-2.0.48-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:69f5bc24904d3bc3640961cddd2523e361257ef68585d6e364166dfbe8c78fae", size = 3558851, upload-time = "2026-03-02T15:57:48.607Z" }, + { url = "https://files.pythonhosted.org/packages/20/0d/213668e9aca61d370f7d2a6449ea4ec699747fac67d4bda1bb3d129025be/sqlalchemy-2.0.48-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fd08b90d211c086181caed76931ecfa2bdfc83eea3cfccdb0f82abc6c4b876cb", size = 3525525, upload-time = "2026-03-02T16:04:38.058Z" }, + { url = "https://files.pythonhosted.org/packages/85/d7/a84edf412979e7d59c69b89a5871f90a49228360594680e667cb2c46a828/sqlalchemy-2.0.48-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:1ccd42229aaac2df431562117ac7e667d702e8e44afdb6cf0e50fa3f18160f0b", size = 3466611, upload-time = "2026-03-02T15:57:50.759Z" }, + { url = "https://files.pythonhosted.org/packages/86/55/42404ce5770f6be26a2b0607e7866c31b9a4176c819e9a7a5e0a055770be/sqlalchemy-2.0.48-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f0dcbc588cd5b725162c076eb9119342f6579c7f7f55057bb7e3c6ff27e13121", size = 3475812, upload-time = "2026-03-02T16:04:40.092Z" }, + { url = "https://files.pythonhosted.org/packages/ae/ae/29b87775fadc43e627cf582fe3bda4d02e300f6b8f2747c764950d13784c/sqlalchemy-2.0.48-cp313-cp313t-win32.whl", hash = "sha256:9764014ef5e58aab76220c5664abb5d47d5bc858d9debf821e55cfdd0f128485", size = 2141335, upload-time = "2026-03-02T15:52:51.518Z" }, + { url = "https://files.pythonhosted.org/packages/91/44/f39d063c90f2443e5b46ec4819abd3d8de653893aae92df42a5c4f5843de/sqlalchemy-2.0.48-cp313-cp313t-win_amd64.whl", hash = "sha256:e2f35b4cccd9ed286ad62e0a3c3ac21e06c02abc60e20aa51a3e305a30f5fa79", size = 2173095, upload-time = "2026-03-02T15:52:52.79Z" }, + { url = "https://files.pythonhosted.org/packages/f7/b3/f437eaa1cf028bb3c927172c7272366393e73ccd104dcf5b6963f4ab5318/sqlalchemy-2.0.48-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:e2d0d88686e3d35a76f3e15a34e8c12d73fc94c1dea1cd55782e695cc14086dd", size = 2154401, upload-time = "2026-03-02T15:49:17.24Z" }, + { url = "https://files.pythonhosted.org/packages/6c/1c/b3abdf0f402aa3f60f0df6ea53d92a162b458fca2321d8f1f00278506402/sqlalchemy-2.0.48-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:49b7bddc1eebf011ea5ab722fdbe67a401caa34a350d278cc7733c0e88fecb1f", size = 3274528, upload-time = "2026-03-02T15:50:41.489Z" }, + { url = "https://files.pythonhosted.org/packages/f2/5e/327428a034407651a048f5e624361adf3f9fbac9d0fa98e981e9c6ff2f5e/sqlalchemy-2.0.48-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:426c5ca86415d9b8945c7073597e10de9644802e2ff502b8e1f11a7a2642856b", size = 3279523, upload-time = "2026-03-02T15:53:32.962Z" }, + { url = "https://files.pythonhosted.org/packages/2a/ca/ece73c81a918add0965b76b868b7b5359e068380b90ef1656ee995940c02/sqlalchemy-2.0.48-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:288937433bd44e3990e7da2402fabc44a3c6c25d3704da066b85b89a85474ae0", size = 3224312, upload-time = "2026-03-02T15:50:42.996Z" }, + { url = "https://files.pythonhosted.org/packages/88/11/fbaf1ae91fa4ee43f4fe79661cead6358644824419c26adb004941bdce7c/sqlalchemy-2.0.48-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:8183dc57ae7d9edc1346e007e840a9f3d6aa7b7f165203a99e16f447150140d2", size = 3246304, upload-time = "2026-03-02T15:53:34.937Z" }, + { url = "https://files.pythonhosted.org/packages/fa/a8/5fb0deb13930b4f2f698c5541ae076c18981173e27dd00376dbaea7a9c82/sqlalchemy-2.0.48-cp314-cp314-win32.whl", hash = "sha256:1182437cb2d97988cfea04cf6cdc0b0bb9c74f4d56ec3d08b81e23d621a28cc6", size = 2116565, upload-time = "2026-03-02T15:54:38.321Z" }, + { url = "https://files.pythonhosted.org/packages/95/7e/e83615cb63f80047f18e61e31e8e32257d39458426c23006deeaf48f463b/sqlalchemy-2.0.48-cp314-cp314-win_amd64.whl", hash = "sha256:144921da96c08feb9e2b052c5c5c1d0d151a292c6135623c6b2c041f2a45f9e0", size = 2142205, upload-time = "2026-03-02T15:54:39.831Z" }, + { url = "https://files.pythonhosted.org/packages/83/e3/69d8711b3f2c5135e9cde5f063bc1605860f0b2c53086d40c04017eb1f77/sqlalchemy-2.0.48-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5aee45fd2c6c0f2b9cdddf48c48535e7471e42d6fb81adfde801da0bd5b93241", size = 3563519, upload-time = "2026-03-02T15:57:52.387Z" }, + { url = "https://files.pythonhosted.org/packages/f8/4f/a7cce98facca73c149ea4578981594aaa5fd841e956834931de503359336/sqlalchemy-2.0.48-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7cddca31edf8b0653090cbb54562ca027c421c58ddde2c0685f49ff56a1690e0", size = 3528611, upload-time = "2026-03-02T16:04:42.097Z" }, + { url = "https://files.pythonhosted.org/packages/cd/7d/5936c7a03a0b0cb0fa0cc425998821c6029756b0855a8f7ee70fba1de955/sqlalchemy-2.0.48-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:7a936f1bb23d370b7c8cc079d5fce4c7d18da87a33c6744e51a93b0f9e97e9b3", size = 3472326, upload-time = "2026-03-02T15:57:54.423Z" }, + { url = "https://files.pythonhosted.org/packages/f4/33/cea7dfc31b52904efe3dcdc169eb4514078887dff1f5ae28a7f4c5d54b3c/sqlalchemy-2.0.48-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:e004aa9248e8cb0a5f9b96d003ca7c1c0a5da8decd1066e7b53f59eb8ce7c62b", size = 3478453, upload-time = "2026-03-02T16:04:44.584Z" }, + { url = "https://files.pythonhosted.org/packages/c8/95/32107c4d13be077a9cae61e9ae49966a35dc4bf442a8852dd871db31f62e/sqlalchemy-2.0.48-cp314-cp314t-win32.whl", hash = "sha256:b8438ec5594980d405251451c5b7ea9aa58dda38eb7ac35fb7e4c696712ee24f", size = 2147209, upload-time = "2026-03-02T15:52:54.274Z" }, + { url = "https://files.pythonhosted.org/packages/d2/d7/1e073da7a4bc645eb83c76067284a0374e643bc4be57f14cc6414656f92c/sqlalchemy-2.0.48-cp314-cp314t-win_amd64.whl", hash = "sha256:d854b3970067297f3a7fbd7a4683587134aa9b3877ee15aa29eea478dc68f933", size = 2182198, upload-time = "2026-03-02T15:52:55.606Z" }, + { url = "https://files.pythonhosted.org/packages/46/2c/9664130905f03db57961b8980b05cab624afd114bf2be2576628a9f22da4/sqlalchemy-2.0.48-py3-none-any.whl", hash = "sha256:a66fe406437dd65cacd96a72689a3aaaecaebbcd62d81c5ac1c0fdbeac835096", size = 1940202, upload-time = "2026-03-02T15:52:43.285Z" }, +] + [[package]] name = "sse-starlette" version = "3.2.0" From e51bc9701d1ee851bc1ab9a47f0287fa7c9060c4 Mon Sep 17 00:00:00 2001 From: "helen@cloud" Date: Mon, 23 Mar 2026 10:30:26 -0400 Subject: [PATCH 2/4] feat(server): persist session and interrupt state in database backend (#304) --- docs/guide.md | 11 +- src/opencode_a2a/execution/executor.py | 3 + src/opencode_a2a/execution/session_manager.py | 67 +- src/opencode_a2a/execution/stream_runtime.py | 4 +- src/opencode_a2a/jsonrpc/application.py | 8 +- src/opencode_a2a/opencode_upstream_client.py | 111 +-- src/opencode_a2a/runtime_state.py | 20 + src/opencode_a2a/server/application.py | 53 +- src/opencode_a2a/server/state_store.py | 666 ++++++++++++++++++ src/opencode_a2a/server/task_store.py | 24 +- .../test_extension_contract_consistency.py | 2 +- tests/execution/test_metrics.py | 4 +- tests/execution/test_multipart_input.py | 6 +- ...t_opencode_session_extension_interrupts.py | 14 +- tests/server/test_state_store.py | 82 +++ tests/server/test_transport_contract.py | 4 + tests/support/helpers.py | 14 +- tests/support/streaming_output.py | 6 +- .../test_opencode_upstream_client_params.py | 42 +- 19 files changed, 972 insertions(+), 169 deletions(-) create mode 100644 src/opencode_a2a/runtime_state.py create mode 100644 src/opencode_a2a/server/state_store.py create mode 100644 tests/server/test_state_store.py diff --git a/docs/guide.md b/docs/guide.md index 026ce9c..1097678 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -175,9 +175,14 @@ A2A_TASK_STORE_DATABASE_URL=sqlite+aiosqlite:///./opencode-a2a.db \ opencode-a2a ``` -At the moment, this database-backed store persists task records only. -Session binding state and interrupt request bindings remain in-process runtime -state and are not yet persisted. +When `A2A_TASK_STORE_BACKEND=database`, the service now persists: + +- task records +- session binding / ownership state +- interrupt request bindings and tombstones + +In-flight asyncio locks, outbound A2A client caches, and stream-local +aggregation buffers remain process-local runtime state. ## Troubleshooting Provider Auth State diff --git a/src/opencode_a2a/execution/executor.py b/src/opencode_a2a/execution/executor.py index 707cc89..672b8bc 100644 --- a/src/opencode_a2a/execution/executor.py +++ b/src/opencode_a2a/execution/executor.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from ..server.application import A2AClientManager + from ..server.state_store import SessionStateRepository import httpx from a2a.server.agent_execution import AgentExecutor, RequestContext @@ -531,6 +532,7 @@ def __init__( session_cache_ttl_seconds: int = 3600, session_cache_maxsize: int = 10_000, a2a_client_manager: A2AClientManager | None = None, + session_state_repository: SessionStateRepository | None = None, ) -> None: self._client = client self._streaming_enabled = streaming_enabled @@ -544,6 +546,7 @@ def __init__( client=client, session_cache_ttl_seconds=session_cache_ttl_seconds, session_cache_maxsize=session_cache_maxsize, + state_repository=session_state_repository, ) self._stream_runtime = StreamRuntime( client=client, diff --git a/src/opencode_a2a/execution/session_manager.py b/src/opencode_a2a/execution/session_manager.py index 8a6ce35..9fc7346 100644 --- a/src/opencode_a2a/execution/session_manager.py +++ b/src/opencode_a2a/execution/session_manager.py @@ -1,8 +1,9 @@ from __future__ import annotations import asyncio +from typing import Any, cast -from .stream_state import _TTLCache +from ..server.state_store import MemorySessionStateRepository, SessionStateRepository class SessionManager: @@ -12,18 +13,21 @@ def __init__( client, session_cache_ttl_seconds: int = 3600, session_cache_maxsize: int = 10_000, + state_repository: SessionStateRepository | None = None, ) -> None: self._client = client - self._sessions = _TTLCache( + self._state_repository = state_repository or MemorySessionStateRepository( ttl_seconds=session_cache_ttl_seconds, maxsize=session_cache_maxsize, ) - self._session_owners = _TTLCache( - ttl_seconds=session_cache_ttl_seconds, - maxsize=session_cache_maxsize, - refresh_on_get=True, - ) - self._pending_session_claims: dict[str, str] = {} + if isinstance(self._state_repository, MemorySessionStateRepository): + self._sessions = self._state_repository.sessions + self._session_owners = self._state_repository.session_owners + self._pending_session_claims = self._state_repository.pending_session_claims + else: + self._sessions = cast("Any", None) + self._session_owners = cast("Any", None) + self._pending_session_claims = cast("Any", None) self._lock = asyncio.Lock() self._inflight_session_creates: dict[tuple[str, str], asyncio.Task[str]] = {} self._session_locks: dict[str, asyncio.Lock] = {} @@ -49,7 +53,10 @@ async def get_or_create_session( task: asyncio.Task[str] | None = None cache_key = (identity, context_id) async with self._lock: - existing = self._sessions.get(cache_key) + existing = await self._state_repository.get_session( + identity=cache_key[0], + context_id=cache_key[1], + ) if existing: return existing, False task = self._inflight_session_creates.get(cache_key) @@ -68,14 +75,18 @@ async def get_or_create_session( raise async with self._lock: - owner = self._session_owners.get(session_id) + owner = await self._state_repository.get_owner(session_id=session_id) if owner and owner != identity: if self._inflight_session_creates.get(cache_key) is task: self._inflight_session_creates.pop(cache_key, None) raise PermissionError(f"Session {session_id} is not owned by you") - self._sessions.set(cache_key, session_id) + await self._state_repository.set_session( + identity=cache_key[0], + context_id=cache_key[1], + session_id=session_id, + ) if not owner: - self._session_owners.set(session_id, identity) + await self._state_repository.set_owner(session_id=session_id, identity=identity) if self._inflight_session_creates.get(cache_key) is task: self._inflight_session_creates.pop(cache_key, None) return session_id, False @@ -89,37 +100,45 @@ async def finalize_preferred_session_binding( ) -> None: await self.finalize_session_claim(identity=identity, session_id=session_id) async with self._lock: - self._sessions.set((identity, context_id), session_id) + await self._state_repository.set_session( + identity=identity, + context_id=context_id, + session_id=session_id, + ) async def claim_preferred_session(self, *, identity: str, session_id: str) -> bool: async with self._lock: - owner = self._session_owners.get(session_id) - pending_owner = self._pending_session_claims.get(session_id) + owner = await self._state_repository.get_owner(session_id=session_id) + pending_owner = await self._state_repository.get_pending_claim(session_id=session_id) if owner and owner != identity: raise PermissionError(f"Session {session_id} is not owned by you") if pending_owner and pending_owner != identity: raise PermissionError(f"Session {session_id} is not owned by you") if owner == identity: return False - self._pending_session_claims[session_id] = identity + await self._state_repository.set_pending_claim(session_id=session_id, identity=identity) return True async def finalize_session_claim(self, *, identity: str, session_id: str) -> None: async with self._lock: - owner = self._session_owners.get(session_id) - pending_owner = self._pending_session_claims.get(session_id) + owner = await self._state_repository.get_owner(session_id=session_id) + pending_owner = await self._state_repository.get_pending_claim(session_id=session_id) if owner and owner != identity: raise PermissionError(f"Session {session_id} is not owned by you") if pending_owner and pending_owner != identity: raise PermissionError(f"Session {session_id} is not owned by you") - self._session_owners.set(session_id, identity) - if self._pending_session_claims.get(session_id) == identity: - self._pending_session_claims.pop(session_id, None) + await self._state_repository.set_owner(session_id=session_id, identity=identity) + await self._state_repository.clear_pending_claim( + session_id=session_id, + identity=identity, + ) async def release_preferred_session_claim(self, *, identity: str, session_id: str) -> None: async with self._lock: - if self._pending_session_claims.get(session_id) == identity: - self._pending_session_claims.pop(session_id, None) + await self._state_repository.clear_pending_claim( + session_id=session_id, + identity=identity, + ) async def get_session_lock(self, session_id: str) -> asyncio.Lock: async with self._lock: @@ -136,5 +155,5 @@ async def pop_cached_session( context_id: str, ) -> asyncio.Task[str] | None: async with self._lock: - self._sessions.pop((identity, context_id)) + await self._state_repository.pop_session(identity=identity, context_id=context_id) return self._inflight_session_creates.pop((identity, context_id), None) diff --git a/src/opencode_a2a/execution/stream_runtime.py b/src/opencode_a2a/execution/stream_runtime.py index 33d30bf..59ad183 100644 --- a/src/opencode_a2a/execution/stream_runtime.py +++ b/src/opencode_a2a/execution/stream_runtime.py @@ -429,7 +429,7 @@ def _tool_chunks( None, ) if callable(remember_request): - remember_request( + await remember_request( request_id=request_id, session_id=session_id, interrupt_type=asked["interrupt_type"], @@ -456,7 +456,7 @@ def _tool_chunks( None, ) if callable(discard_request): - discard_request(resolved_request_id) + await discard_request(resolved_request_id) if cleared_pending: await _emit_interrupt_status( state=TaskState.working, diff --git a/src/opencode_a2a/jsonrpc/application.py b/src/opencode_a2a/jsonrpc/application.py index c8c51bb..f7ef7fc 100644 --- a/src/opencode_a2a/jsonrpc/application.py +++ b/src/opencode_a2a/jsonrpc/application.py @@ -776,7 +776,7 @@ async def _handle_interrupt_callback_request( ) resolve_request = getattr(self._upstream_client, "resolve_interrupt_request", None) if callable(resolve_request): - status, binding = resolve_request(request_id) + status, binding = await resolve_request(request_id) if status != "active" or binding is None: return self._generate_error_response( base_request.id, @@ -818,7 +818,7 @@ async def _handle_interrupt_callback_request( else: resolve_session = getattr(self._upstream_client, "resolve_interrupt_session", None) if callable(resolve_session): - if not resolve_session(request_id): + if not await resolve_session(request_id): return self._generate_error_response( base_request.id, interrupt_not_found_error( @@ -869,7 +869,7 @@ async def _handle_interrupt_callback_request( await self._upstream_client.question_reject(request_id, directory=directory) discard_request = getattr(self._upstream_client, "discard_interrupt_request", None) if callable(discard_request): - discard_request(request_id) + await discard_request(request_id) except ValueError as exc: return self._generate_error_response( base_request.id, @@ -880,7 +880,7 @@ async def _handle_interrupt_callback_request( if upstream_status == 404: discard_request = getattr(self._upstream_client, "discard_interrupt_request", None) if callable(discard_request): - discard_request(request_id) + await discard_request(request_id) return self._generate_error_response( base_request.id, interrupt_not_found_error( diff --git a/src/opencode_a2a/opencode_upstream_client.py b/src/opencode_a2a/opencode_upstream_client.py index ffba049..7b64af5 100644 --- a/src/opencode_a2a/opencode_upstream_client.py +++ b/src/opencode_a2a/opencode_upstream_client.py @@ -12,6 +12,8 @@ from .config import Settings from .parts.text import extract_text_from_parts +from .runtime_state import InterruptRequestBinding +from .server.state_store import InterruptRequestRepository, MemoryInterruptRequestRepository _UNSET = object() logger = logging.getLogger(__name__) @@ -29,25 +31,13 @@ class OpencodeMessage: raw: dict[str, Any] -@dataclass(frozen=True) -class InterruptRequestBinding: - request_id: str - session_id: str - interrupt_type: str - identity: str | None - task_id: str | None - context_id: str | None - expires_at: float - - -@dataclass(frozen=True) -class InterruptRequestTombstone: - request_id: str - expires_at: float - - class OpencodeUpstreamClient: - def __init__(self, settings: Settings) -> None: + def __init__( + self, + settings: Settings, + *, + interrupt_request_repository: InterruptRequestRepository | None = None, + ) -> None: self._settings = settings self._base_url = settings.opencode_base_url.rstrip("/") self._directory = settings.opencode_workspace_root @@ -61,10 +51,20 @@ def __init__(self, settings: Settings) -> None: settings.a2a_interrupt_request_tombstone_ttl_seconds ) self._interrupt_request_clock = time.monotonic - self._interrupt_requests: dict[str, InterruptRequestBinding] = {} - self._interrupt_request_tombstones: dict[str, InterruptRequestTombstone] = {} + self._interrupt_request_repository = ( + interrupt_request_repository + or MemoryInterruptRequestRepository( + request_ttl_seconds=self._interrupt_request_ttl_seconds, + tombstone_ttl_seconds=self._interrupt_request_tombstone_ttl_seconds, + clock=self._interrupt_request_clock, + ) + ) self._client = self._build_http_client(self._base_url) + def _sync_interrupt_clock(self) -> None: + if isinstance(self._interrupt_request_repository, MemoryInterruptRequestRepository): + self._interrupt_request_repository._clock = self._interrupt_request_clock + def _build_http_client(self, base_url: str) -> httpx.AsyncClient: return httpx.AsyncClient( base_url=base_url, @@ -157,36 +157,7 @@ async def _post_boolean( ) return self._require_boolean_response(endpoint=endpoint, payload=data) - def _prune_interrupt_requests(self, *, now: float) -> None: - expired = [ - request_id - for request_id, binding in self._interrupt_requests.items() - if binding.expires_at <= now - ] - for request_id in expired: - self._interrupt_requests.pop(request_id, None) - self._remember_interrupt_request_tombstone(request_id, now=now) - - def _prune_interrupt_request_tombstones(self, *, now: float) -> None: - expired = [ - request_id - for request_id, tombstone in self._interrupt_request_tombstones.items() - if tombstone.expires_at <= now - ] - for request_id in expired: - self._interrupt_request_tombstones.pop(request_id, None) - - def _remember_interrupt_request_tombstone(self, request_id: str, *, now: float) -> None: - ttl = self._interrupt_request_tombstone_ttl_seconds - if ttl <= 0: - self._interrupt_request_tombstones.pop(request_id, None) - return - self._interrupt_request_tombstones[request_id] = InterruptRequestTombstone( - request_id=request_id, - expires_at=now + ttl, - ) - - def remember_interrupt_request( + async def remember_interrupt_request( self, *, request_id: str, @@ -202,12 +173,8 @@ def remember_interrupt_request( kind = interrupt_type.strip() if not request or not session or kind not in {"permission", "question"}: return - now = self._interrupt_request_clock() - self._prune_interrupt_requests(now=now) - self._prune_interrupt_request_tombstones(now=now) - ttl = self._interrupt_request_ttl_seconds if ttl_seconds is None else ttl_seconds - expires_at = now + max(0.0, float(ttl)) - self._interrupt_requests[request] = InterruptRequestBinding( + self._sync_interrupt_clock() + await self._interrupt_request_repository.remember( request_id=request, session_id=session, interrupt_type=kind, @@ -216,44 +183,30 @@ def remember_interrupt_request( context_id=( context_id.strip() if isinstance(context_id, str) and context_id.strip() else None ), - expires_at=expires_at, + ttl_seconds=ttl_seconds, ) - self._interrupt_request_tombstones.pop(request, None) - def resolve_interrupt_request( + async def resolve_interrupt_request( self, request_id: str, ) -> tuple[str, InterruptRequestBinding | None]: request = request_id.strip() if not request: return "missing", None - now = self._interrupt_request_clock() - self._prune_interrupt_request_tombstones(now=now) - binding = self._interrupt_requests.get(request) - if binding is None: - if request in self._interrupt_request_tombstones: - return "expired", None - return "missing", None - if binding.expires_at <= now: - self._interrupt_requests.pop(request, None) - self._prune_interrupt_requests(now=now) - self._remember_interrupt_request_tombstone(request, now=now) - return "expired", None - self._prune_interrupt_requests(now=now) - return "active", binding - - def resolve_interrupt_session(self, request_id: str) -> str | None: - status, binding = self.resolve_interrupt_request(request_id) + self._sync_interrupt_clock() + return await self._interrupt_request_repository.resolve(request_id=request) + + async def resolve_interrupt_session(self, request_id: str) -> str | None: + status, binding = await self.resolve_interrupt_request(request_id) if status != "active" or binding is None: return None return binding.session_id - def discard_interrupt_request(self, request_id: str) -> None: + async def discard_interrupt_request(self, request_id: str) -> None: request = request_id.strip() if not request: return - self._interrupt_requests.pop(request, None) - self._interrupt_request_tombstones.pop(request, None) + await self._interrupt_request_repository.discard(request_id=request) @property def stream_timeout(self) -> float | None: diff --git a/src/opencode_a2a/runtime_state.py b/src/opencode_a2a/runtime_state.py new file mode 100644 index 0000000..39efbc6 --- /dev/null +++ b/src/opencode_a2a/runtime_state.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class InterruptRequestBinding: + request_id: str + session_id: str + interrupt_type: str + identity: str | None + task_id: str | None + context_id: str | None + expires_at: float + + +@dataclass(frozen=True) +class InterruptRequestTombstone: + request_id: str + expires_at: float diff --git a/src/opencode_a2a/server/application.py b/src/opencode_a2a/server/application.py index f14c2aa..39fc7b2 100644 --- a/src/opencode_a2a/server/application.py +++ b/src/opencode_a2a/server/application.py @@ -2,6 +2,7 @@ import asyncio import hashlib +import inspect import logging import secrets from contextlib import asynccontextmanager @@ -78,7 +79,12 @@ _request_body_too_large_response, _RequestBodyTooLargeError, ) -from .task_store import build_task_store, close_task_store, initialize_task_store +from .state_store import ( + build_interrupt_request_repository, + build_session_state_repository, + initialize_state_repository, +) +from .task_store import build_database_engine, build_task_store, initialize_task_store logger = logging.getLogger(__name__) @@ -488,18 +494,52 @@ def __init__( self.pending_eviction = pending_eviction +def _call_with_optional_kwargs(factory, /, *args, **kwargs): # noqa: ANN001 + try: + return factory(*args, **kwargs) + except TypeError as exc: + signature = inspect.signature(factory) + supported_kwargs = { + name: value for name, value in kwargs.items() if name in signature.parameters + } + if supported_kwargs == kwargs: + raise + try: + return factory(*args, **supported_kwargs) + except TypeError: + raise exc from None + + def create_app(settings: Settings) -> FastAPI: - upstream_client = OpencodeUpstreamClient(settings) + database_engine = ( + build_database_engine(settings) if settings.a2a_task_store_backend == "database" else None + ) + session_state_repository = build_session_state_repository(settings, engine=database_engine) + interrupt_request_repository = build_interrupt_request_repository( + settings, + engine=database_engine, + ) + upstream_client = _call_with_optional_kwargs( + OpencodeUpstreamClient, + settings, + interrupt_request_repository=interrupt_request_repository, + ) client_manager = A2AClientManager(settings) - executor = OpencodeAgentExecutor( + executor = _call_with_optional_kwargs( + OpencodeAgentExecutor, upstream_client, streaming_enabled=True, cancel_abort_timeout_seconds=settings.a2a_cancel_abort_timeout_seconds, session_cache_ttl_seconds=settings.a2a_session_cache_ttl_seconds, session_cache_maxsize=settings.a2a_session_cache_maxsize, a2a_client_manager=client_manager, + session_state_repository=session_state_repository, + ) + task_store = _call_with_optional_kwargs( + build_task_store, + settings, + engine=database_engine, ) - task_store = build_task_store(settings) handler = OpencodeRequestHandler( agent_executor=executor, task_store=task_store, @@ -550,8 +590,11 @@ def create_app(settings: Settings) -> FastAPI: @asynccontextmanager async def lifespan(_app: FastAPI): await initialize_task_store(task_store) + await initialize_state_repository(session_state_repository) + await initialize_state_repository(interrupt_request_repository) yield - await close_task_store(task_store) + if database_engine is not None: + await database_engine.dispose() await client_manager.close_all() await upstream_client.close() diff --git a/src/opencode_a2a/server/state_store.py b/src/opencode_a2a/server/state_store.py new file mode 100644 index 0000000..b6502eb --- /dev/null +++ b/src/opencode_a2a/server/state_store.py @@ -0,0 +1,666 @@ +from __future__ import annotations + +import time +from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import TYPE_CHECKING, cast + +from sqlalchemy import ( + Column, + Float, + MetaData, + String, + Table, + and_, + delete, + insert, + select, + update, +) +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from ..config import Settings +from ..execution.stream_state import _TTLCache +from ..runtime_state import InterruptRequestBinding, InterruptRequestTombstone + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncEngine + +_STATE_METADATA = MetaData() + +_SESSION_BINDINGS = Table( + "a2a_session_bindings", + _STATE_METADATA, + Column("identity", String, primary_key=True), + Column("context_id", String, primary_key=True), + Column("session_id", String, nullable=False), + Column("expires_at", Float, nullable=True), + Column("updated_at", Float, nullable=False), +) + +_SESSION_OWNERS = Table( + "a2a_session_owners", + _STATE_METADATA, + Column("session_id", String, primary_key=True), + Column("identity", String, nullable=False), + Column("expires_at", Float, nullable=True), + Column("updated_at", Float, nullable=False), +) + +_PENDING_SESSION_CLAIMS = Table( + "a2a_pending_session_claims", + _STATE_METADATA, + Column("session_id", String, primary_key=True), + Column("identity", String, nullable=False), + Column("updated_at", Float, nullable=False), +) + +_INTERRUPT_REQUESTS = Table( + "a2a_interrupt_requests", + _STATE_METADATA, + Column("request_id", String, primary_key=True), + Column("session_id", String, nullable=True), + Column("interrupt_type", String, nullable=True), + Column("identity", String, nullable=True), + Column("task_id", String, nullable=True), + Column("context_id", String, nullable=True), + Column("expires_at", Float, nullable=True), + Column("tombstone_expires_at", Float, nullable=True), +) + + +class SessionStateRepository(ABC): + @abstractmethod + async def get_session(self, *, identity: str, context_id: str) -> str | None: ... + + @abstractmethod + async def set_session(self, *, identity: str, context_id: str, session_id: str) -> None: ... + + @abstractmethod + async def pop_session(self, *, identity: str, context_id: str) -> None: ... + + @abstractmethod + async def get_owner(self, *, session_id: str) -> str | None: ... + + @abstractmethod + async def set_owner(self, *, session_id: str, identity: str) -> None: ... + + @abstractmethod + async def get_pending_claim(self, *, session_id: str) -> str | None: ... + + @abstractmethod + async def set_pending_claim(self, *, session_id: str, identity: str) -> None: ... + + @abstractmethod + async def clear_pending_claim(self, *, session_id: str, identity: str | None = None) -> None: ... + + +class InterruptRequestRepository(ABC): + @abstractmethod + async def remember( + self, + *, + request_id: str, + session_id: str, + interrupt_type: str, + identity: str | None, + task_id: str | None, + context_id: str | None, + ttl_seconds: float | None, + ) -> None: ... + + @abstractmethod + async def resolve( + self, + *, + request_id: str, + ) -> tuple[str, InterruptRequestBinding | None]: ... + + @abstractmethod + async def discard(self, *, request_id: str) -> None: ... + + +class MemorySessionStateRepository(SessionStateRepository): + def __init__( + self, + *, + ttl_seconds: int, + maxsize: int, + ) -> None: + self.sessions = _TTLCache(ttl_seconds=ttl_seconds, maxsize=maxsize) + self.session_owners = _TTLCache( + ttl_seconds=ttl_seconds, + maxsize=maxsize, + refresh_on_get=True, + ) + self.pending_session_claims: dict[str, str] = {} + + async def get_session(self, *, identity: str, context_id: str) -> str | None: + return self.sessions.get((identity, context_id)) + + async def set_session(self, *, identity: str, context_id: str, session_id: str) -> None: + self.sessions.set((identity, context_id), session_id) + + async def pop_session(self, *, identity: str, context_id: str) -> None: + self.sessions.pop((identity, context_id)) + + async def get_owner(self, *, session_id: str) -> str | None: + return self.session_owners.get(session_id) + + async def set_owner(self, *, session_id: str, identity: str) -> None: + self.session_owners.set(session_id, identity) + + async def get_pending_claim(self, *, session_id: str) -> str | None: + return self.pending_session_claims.get(session_id) + + async def set_pending_claim(self, *, session_id: str, identity: str) -> None: + self.pending_session_claims[session_id] = identity + + async def clear_pending_claim(self, *, session_id: str, identity: str | None = None) -> None: + if identity is None or self.pending_session_claims.get(session_id) == identity: + self.pending_session_claims.pop(session_id, None) + + +class DatabaseSessionStateRepository(SessionStateRepository): + def __init__( + self, + *, + engine: "AsyncEngine", + ttl_seconds: int, + maxsize: int, + clock: Callable[[], float] = time.time, + ) -> None: + self.engine = engine + self._ttl_seconds = int(ttl_seconds) + self._maxsize = int(maxsize) + self._clock = clock + self._initialized = False + self._session_maker = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) + + async def initialize(self) -> None: + if self._initialized: + return + async with self.engine.begin() as conn: + await conn.run_sync(_STATE_METADATA.create_all) + self._initialized = True + + async def _ensure_initialized(self) -> None: + if not self._initialized: + await self.initialize() + + def _expires_at(self, now: float) -> float | None: + if self._ttl_seconds <= 0: + return None + return now + float(self._ttl_seconds) + + async def _prune_expired( + self, + session: AsyncSession, + *, + now: float, + ) -> None: + await session.execute( + delete(_SESSION_BINDINGS).where( + and_(_SESSION_BINDINGS.c.expires_at.is_not(None), _SESSION_BINDINGS.c.expires_at <= now) + ) + ) + await session.execute( + delete(_SESSION_OWNERS).where( + and_(_SESSION_OWNERS.c.expires_at.is_not(None), _SESSION_OWNERS.c.expires_at <= now) + ) + ) + + async def _prune_overflow(self, session: AsyncSession, *, table: Table) -> None: + if self._maxsize <= 0: + return + count = await session.execute(select(table).order_by(table.c.updated_at.asc())) + rows = count.fetchall() + overflow = len(rows) - self._maxsize + if overflow <= 0: + return + if table is _SESSION_BINDINGS: + for row in rows[:overflow]: + await session.execute( + delete(_SESSION_BINDINGS).where( + and_( + _SESSION_BINDINGS.c.identity == row.identity, + _SESSION_BINDINGS.c.context_id == row.context_id, + ) + ) + ) + return + for row in rows[:overflow]: + await session.execute( + delete(_SESSION_OWNERS).where(_SESSION_OWNERS.c.session_id == row.session_id) + ) + + async def get_session(self, *, identity: str, context_id: str) -> str | None: + await self._ensure_initialized() + now = self._clock() + async with self._session_maker.begin() as session: + await self._prune_expired(session, now=now) + result = await session.execute( + select(_SESSION_BINDINGS.c.session_id).where( + and_( + _SESSION_BINDINGS.c.identity == identity, + _SESSION_BINDINGS.c.context_id == context_id, + ) + ) + ) + return cast("str | None", result.scalar_one_or_none()) + + async def set_session(self, *, identity: str, context_id: str, session_id: str) -> None: + await self._ensure_initialized() + now = self._clock() + expires_at = self._expires_at(now) + async with self._session_maker.begin() as session: + await self._prune_expired(session, now=now) + exists = await session.execute( + select(_SESSION_BINDINGS.c.session_id).where( + and_( + _SESSION_BINDINGS.c.identity == identity, + _SESSION_BINDINGS.c.context_id == context_id, + ) + ) + ) + values = { + "session_id": session_id, + "expires_at": expires_at, + "updated_at": now, + } + if exists.scalar_one_or_none() is None: + await session.execute( + insert(_SESSION_BINDINGS).values( + identity=identity, + context_id=context_id, + **values, + ) + ) + else: + await session.execute( + update(_SESSION_BINDINGS) + .where( + and_( + _SESSION_BINDINGS.c.identity == identity, + _SESSION_BINDINGS.c.context_id == context_id, + ) + ) + .values(**values) + ) + await self._prune_overflow(session, table=_SESSION_BINDINGS) + + async def pop_session(self, *, identity: str, context_id: str) -> None: + await self._ensure_initialized() + async with self._session_maker.begin() as session: + await session.execute( + delete(_SESSION_BINDINGS).where( + and_( + _SESSION_BINDINGS.c.identity == identity, + _SESSION_BINDINGS.c.context_id == context_id, + ) + ) + ) + + async def get_owner(self, *, session_id: str) -> str | None: + await self._ensure_initialized() + now = self._clock() + async with self._session_maker.begin() as session: + await self._prune_expired(session, now=now) + result = await session.execute( + select(_SESSION_OWNERS.c.identity).where(_SESSION_OWNERS.c.session_id == session_id) + ) + owner = cast("str | None", result.scalar_one_or_none()) + if owner is not None: + await session.execute( + update(_SESSION_OWNERS) + .where(_SESSION_OWNERS.c.session_id == session_id) + .values(expires_at=self._expires_at(now), updated_at=now) + ) + return owner + + async def set_owner(self, *, session_id: str, identity: str) -> None: + await self._ensure_initialized() + now = self._clock() + expires_at = self._expires_at(now) + async with self._session_maker.begin() as session: + await self._prune_expired(session, now=now) + exists = await session.execute( + select(_SESSION_OWNERS.c.session_id).where(_SESSION_OWNERS.c.session_id == session_id) + ) + values = { + "identity": identity, + "expires_at": expires_at, + "updated_at": now, + } + if exists.scalar_one_or_none() is None: + await session.execute( + insert(_SESSION_OWNERS).values(session_id=session_id, **values) + ) + else: + await session.execute( + update(_SESSION_OWNERS) + .where(_SESSION_OWNERS.c.session_id == session_id) + .values(**values) + ) + await self._prune_overflow(session, table=_SESSION_OWNERS) + + async def get_pending_claim(self, *, session_id: str) -> str | None: + await self._ensure_initialized() + async with self._session_maker.begin() as session: + result = await session.execute( + select(_PENDING_SESSION_CLAIMS.c.identity).where( + _PENDING_SESSION_CLAIMS.c.session_id == session_id + ) + ) + return cast("str | None", result.scalar_one_or_none()) + + async def set_pending_claim(self, *, session_id: str, identity: str) -> None: + await self._ensure_initialized() + now = self._clock() + async with self._session_maker.begin() as session: + exists = await session.execute( + select(_PENDING_SESSION_CLAIMS.c.session_id).where( + _PENDING_SESSION_CLAIMS.c.session_id == session_id + ) + ) + values = {"identity": identity, "updated_at": now} + if exists.scalar_one_or_none() is None: + await session.execute( + insert(_PENDING_SESSION_CLAIMS).values(session_id=session_id, **values) + ) + else: + await session.execute( + update(_PENDING_SESSION_CLAIMS) + .where(_PENDING_SESSION_CLAIMS.c.session_id == session_id) + .values(**values) + ) + + async def clear_pending_claim(self, *, session_id: str, identity: str | None = None) -> None: + await self._ensure_initialized() + async with self._session_maker.begin() as session: + stmt = delete(_PENDING_SESSION_CLAIMS).where( + _PENDING_SESSION_CLAIMS.c.session_id == session_id + ) + if identity is not None: + stmt = stmt.where(_PENDING_SESSION_CLAIMS.c.identity == identity) + await session.execute(stmt) + + +class MemoryInterruptRequestRepository(InterruptRequestRepository): + def __init__( + self, + *, + request_ttl_seconds: float, + tombstone_ttl_seconds: float, + clock: Callable[[], float] = time.monotonic, + ) -> None: + self._request_ttl_seconds = float(request_ttl_seconds) + self._tombstone_ttl_seconds = float(tombstone_ttl_seconds) + self._clock = clock + self._interrupt_requests: dict[str, InterruptRequestBinding] = {} + self._interrupt_request_tombstones: dict[str, InterruptRequestTombstone] = {} + + def _prune_interrupt_requests(self, *, now: float) -> None: + expired = [ + request_id + for request_id, binding in self._interrupt_requests.items() + if binding.expires_at <= now + ] + for request_id in expired: + self._interrupt_requests.pop(request_id, None) + self._remember_interrupt_request_tombstone(request_id, now=now) + + def _prune_interrupt_request_tombstones(self, *, now: float) -> None: + expired = [ + request_id + for request_id, tombstone in self._interrupt_request_tombstones.items() + if tombstone.expires_at <= now + ] + for request_id in expired: + self._interrupt_request_tombstones.pop(request_id, None) + + def _remember_interrupt_request_tombstone(self, request_id: str, *, now: float) -> None: + ttl = self._tombstone_ttl_seconds + if ttl <= 0: + self._interrupt_request_tombstones.pop(request_id, None) + return + self._interrupt_request_tombstones[request_id] = InterruptRequestTombstone( + request_id=request_id, + expires_at=now + ttl, + ) + + async def remember( + self, + *, + request_id: str, + session_id: str, + interrupt_type: str, + identity: str | None, + task_id: str | None, + context_id: str | None, + ttl_seconds: float | None, + ) -> None: + now = self._clock() + self._prune_interrupt_requests(now=now) + self._prune_interrupt_request_tombstones(now=now) + ttl = self._request_ttl_seconds if ttl_seconds is None else ttl_seconds + self._interrupt_requests[request_id] = InterruptRequestBinding( + request_id=request_id, + session_id=session_id, + interrupt_type=interrupt_type, + identity=identity, + task_id=task_id, + context_id=context_id, + expires_at=now + max(0.0, float(ttl)), + ) + self._interrupt_request_tombstones.pop(request_id, None) + + async def resolve( + self, + *, + request_id: str, + ) -> tuple[str, InterruptRequestBinding | None]: + if not request_id: + return "missing", None + now = self._clock() + self._prune_interrupt_request_tombstones(now=now) + binding = self._interrupt_requests.get(request_id) + if binding is None: + if request_id in self._interrupt_request_tombstones: + return "expired", None + return "missing", None + if binding.expires_at <= now: + self._interrupt_requests.pop(request_id, None) + self._prune_interrupt_requests(now=now) + self._remember_interrupt_request_tombstone(request_id, now=now) + return "expired", None + self._prune_interrupt_requests(now=now) + return "active", binding + + async def discard(self, *, request_id: str) -> None: + self._interrupt_requests.pop(request_id, None) + self._interrupt_request_tombstones.pop(request_id, None) + + +class DatabaseInterruptRequestRepository(InterruptRequestRepository): + def __init__( + self, + *, + engine: "AsyncEngine", + request_ttl_seconds: float, + tombstone_ttl_seconds: float, + clock: Callable[[], float] = time.time, + ) -> None: + self.engine = engine + self._request_ttl_seconds = float(request_ttl_seconds) + self._tombstone_ttl_seconds = float(tombstone_ttl_seconds) + self._clock = clock + self._initialized = False + self._session_maker = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) + + async def initialize(self) -> None: + if self._initialized: + return + async with self.engine.begin() as conn: + await conn.run_sync(_STATE_METADATA.create_all) + self._initialized = True + + async def _ensure_initialized(self) -> None: + if not self._initialized: + await self.initialize() + + async def _prune_tombstones(self, session: AsyncSession, *, now: float) -> None: + await session.execute( + delete(_INTERRUPT_REQUESTS).where( + and_( + _INTERRUPT_REQUESTS.c.tombstone_expires_at.is_not(None), + _INTERRUPT_REQUESTS.c.tombstone_expires_at <= now, + ) + ) + ) + + async def _set_tombstone(self, session: AsyncSession, *, request_id: str, now: float) -> None: + tombstone_expires_at = ( + None + if self._tombstone_ttl_seconds <= 0 + else now + self._tombstone_ttl_seconds + ) + await session.execute( + update(_INTERRUPT_REQUESTS) + .where(_INTERRUPT_REQUESTS.c.request_id == request_id) + .values( + session_id=None, + interrupt_type=None, + identity=None, + task_id=None, + context_id=None, + expires_at=None, + tombstone_expires_at=tombstone_expires_at, + ) + ) + + async def remember( + self, + *, + request_id: str, + session_id: str, + interrupt_type: str, + identity: str | None, + task_id: str | None, + context_id: str | None, + ttl_seconds: float | None, + ) -> None: + await self._ensure_initialized() + now = self._clock() + ttl = self._request_ttl_seconds if ttl_seconds is None else ttl_seconds + expires_at = now + max(0.0, float(ttl)) + async with self._session_maker.begin() as session: + await self._prune_tombstones(session, now=now) + exists = await session.execute( + select(_INTERRUPT_REQUESTS.c.request_id).where( + _INTERRUPT_REQUESTS.c.request_id == request_id + ) + ) + values = { + "session_id": session_id, + "interrupt_type": interrupt_type, + "identity": identity, + "task_id": task_id, + "context_id": context_id, + "expires_at": expires_at, + "tombstone_expires_at": None, + } + if exists.scalar_one_or_none() is None: + await session.execute(insert(_INTERRUPT_REQUESTS).values(request_id=request_id, **values)) + else: + await session.execute( + update(_INTERRUPT_REQUESTS) + .where(_INTERRUPT_REQUESTS.c.request_id == request_id) + .values(**values) + ) + + async def resolve( + self, + *, + request_id: str, + ) -> tuple[str, InterruptRequestBinding | None]: + if not request_id: + return "missing", None + await self._ensure_initialized() + now = self._clock() + async with self._session_maker.begin() as session: + await self._prune_tombstones(session, now=now) + result = await session.execute( + select(_INTERRUPT_REQUESTS).where(_INTERRUPT_REQUESTS.c.request_id == request_id) + ) + row = result.mappings().one_or_none() + if row is None: + return "missing", None + tombstone_expires_at = row.get("tombstone_expires_at") + if tombstone_expires_at is not None and tombstone_expires_at > now: + return "expired", None + expires_at = row.get("expires_at") + if expires_at is None: + return "missing", None + if expires_at <= now: + await self._set_tombstone(session, request_id=request_id, now=now) + return "expired", None + return ( + "active", + InterruptRequestBinding( + request_id=request_id, + session_id=cast("str", row["session_id"]), + interrupt_type=cast("str", row["interrupt_type"]), + identity=cast("str | None", row["identity"]), + task_id=cast("str | None", row["task_id"]), + context_id=cast("str | None", row["context_id"]), + expires_at=cast("float", expires_at), + ), + ) + + async def discard(self, *, request_id: str) -> None: + await self._ensure_initialized() + async with self._session_maker.begin() as session: + await session.execute( + delete(_INTERRUPT_REQUESTS).where(_INTERRUPT_REQUESTS.c.request_id == request_id) + ) + + +def build_session_state_repository( + settings: Settings, + *, + engine: "AsyncEngine | None" = None, +) -> SessionStateRepository: + if settings.a2a_task_store_backend == "database": + return DatabaseSessionStateRepository( + engine=cast("AsyncEngine", engine), + ttl_seconds=settings.a2a_session_cache_ttl_seconds, + maxsize=settings.a2a_session_cache_maxsize, + ) + return MemorySessionStateRepository( + ttl_seconds=settings.a2a_session_cache_ttl_seconds, + maxsize=settings.a2a_session_cache_maxsize, + ) + + +def build_interrupt_request_repository( + settings: Settings, + *, + engine: "AsyncEngine | None" = None, +) -> InterruptRequestRepository: + if settings.a2a_task_store_backend == "database": + return DatabaseInterruptRequestRepository( + engine=cast("AsyncEngine", engine), + request_ttl_seconds=settings.a2a_interrupt_request_ttl_seconds, + tombstone_ttl_seconds=settings.a2a_interrupt_request_tombstone_ttl_seconds, + ) + return MemoryInterruptRequestRepository( + request_ttl_seconds=settings.a2a_interrupt_request_ttl_seconds, + tombstone_ttl_seconds=settings.a2a_interrupt_request_tombstone_ttl_seconds, + ) + + +async def initialize_state_repository(repository: object) -> None: + initialize = getattr(repository, "initialize", None) + if callable(initialize): + await initialize() diff --git a/src/opencode_a2a/server/task_store.py b/src/opencode_a2a/server/task_store.py index 05a42e9..7f5f708 100644 --- a/src/opencode_a2a/server/task_store.py +++ b/src/opencode_a2a/server/task_store.py @@ -17,7 +17,7 @@ class _ConfiguredDatabaseTaskStore(TaskStore): def __init__( self, *, - engine: "AsyncEngine", + engine: AsyncEngine, create_table: bool, table_name: str, ) -> None: @@ -44,7 +44,7 @@ def __init__( self._delegate.task_model = task_model @property - def engine(self) -> "AsyncEngine": + def engine(self) -> AsyncEngine: return self._delegate.engine async def initialize(self) -> None: @@ -60,21 +60,29 @@ async def delete(self, task_id, context=None) -> None: # noqa: ANN001 await self._delegate.delete(task_id, context) -def build_task_store(settings: Settings) -> TaskStore: +def build_task_store( + settings: Settings, + *, + engine: AsyncEngine | None = None, +) -> TaskStore: if settings.a2a_task_store_backend == "memory": return InMemoryTaskStore() - from sqlalchemy.ext.asyncio import create_async_engine - - database_url = cast(str, settings.a2a_task_store_database_url) - engine = create_async_engine(database_url) + resolved_engine = engine or build_database_engine(settings) return _ConfiguredDatabaseTaskStore( - engine=engine, + engine=resolved_engine, create_table=settings.a2a_task_store_create_table, table_name=settings.a2a_task_store_table_name, ) +def build_database_engine(settings: Settings) -> AsyncEngine: + from sqlalchemy.ext.asyncio import create_async_engine + + database_url = cast(str, settings.a2a_task_store_database_url) + return create_async_engine(database_url) + + async def initialize_task_store(task_store: TaskStore) -> None: initialize = getattr(task_store, "initialize", None) if callable(initialize): diff --git a/tests/contracts/test_extension_contract_consistency.py b/tests/contracts/test_extension_contract_consistency.py index d5d9e7b..89e78e0 100644 --- a/tests/contracts/test_extension_contract_consistency.py +++ b/tests/contracts/test_extension_contract_consistency.py @@ -301,7 +301,7 @@ async def test_extension_notification_contracts_return_204( if interrupt_type is not None: request_id = params["request_id"] assert isinstance(request_id, str) - dummy.remember_interrupt_request( + await dummy.remember_interrupt_request( request_id=request_id, session_id="s-1", interrupt_type=interrupt_type, diff --git a/tests/execution/test_metrics.py b/tests/execution/test_metrics.py index 23f454e..ceb0acd 100644 --- a/tests/execution/test_metrics.py +++ b/tests/execution/test_metrics.py @@ -111,7 +111,7 @@ async def stream_events(self, stop_event=None, *, directory=None): # noqa: ANN0 }, } - def remember_interrupt_request( + async def remember_interrupt_request( self, *, request_id: str, @@ -125,7 +125,7 @@ def remember_interrupt_request( del interrupt_type, identity, task_id, context_id, ttl_seconds self._interrupt_requests[request_id] = session_id - def discard_interrupt_request(self, request_id: str) -> None: + async def discard_interrupt_request(self, request_id: str) -> None: self._interrupt_requests.pop(request_id, None) executor = OpencodeAgentExecutor(_Client(), streaming_enabled=True) diff --git a/tests/execution/test_multipart_input.py b/tests/execution/test_multipart_input.py index 7f269a1..7c1c58e 100644 --- a/tests/execution/test_multipart_input.py +++ b/tests/execution/test_multipart_input.py @@ -62,14 +62,14 @@ async def stream_events(self, stop_event=None, *, directory: str | None = None): for _ in (): yield {} - def remember_interrupt_request(self, **_kwargs) -> None: + async def remember_interrupt_request(self, **_kwargs) -> None: return None - def resolve_interrupt_session(self, request_id: str) -> str | None: + async def resolve_interrupt_session(self, request_id: str) -> str | None: del request_id return None - def discard_interrupt_request(self, request_id: str) -> None: + async def discard_interrupt_request(self, request_id: str) -> None: del request_id diff --git a/tests/jsonrpc/test_opencode_session_extension_interrupts.py b/tests/jsonrpc/test_opencode_session_extension_interrupts.py index caed29c..49a8ca4 100644 --- a/tests/jsonrpc/test_opencode_session_extension_interrupts.py +++ b/tests/jsonrpc/test_opencode_session_extension_interrupts.py @@ -39,7 +39,7 @@ async def permission_reply( dummy = InterruptClient( make_settings(a2a_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) ) - dummy.remember_interrupt_request( + await dummy.remember_interrupt_request( request_id="perm-1", session_id="ses-1", interrupt_type="permission", @@ -183,12 +183,12 @@ async def question_reject( dummy = InterruptClient( make_settings(a2a_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) ) - dummy.remember_interrupt_request( + await dummy.remember_interrupt_request( request_id="q-1", session_id="ses-1", interrupt_type="question", ) - dummy.remember_interrupt_request( + await dummy.remember_interrupt_request( request_id="q-2", session_id="ses-1", interrupt_type="question", @@ -269,7 +269,7 @@ async def permission_reply( settings = make_settings(a2a_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) dummy = NotFoundInterruptClient(settings) - dummy.remember_interrupt_request( + await dummy.remember_interrupt_request( request_id="perm-404", session_id="ses-1", interrupt_type="permission", @@ -302,7 +302,7 @@ async def test_interrupt_callback_extension_rejects_expired_request(monkeypatch) import opencode_a2a.server.application as app_module class ExpiredInterruptClient(DummyOpencodeUpstreamClient): - def resolve_interrupt_request(self, request_id: str): + async def resolve_interrupt_request(self, request_id: str): del request_id return "expired", None @@ -387,7 +387,7 @@ class InterruptClient(DummyOpencodeUpstreamClient): dummy = InterruptClient( make_settings(a2a_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) ) - dummy.remember_interrupt_request( + await dummy.remember_interrupt_request( request_id="q-only", session_id="ses-1", interrupt_type="question", @@ -425,7 +425,7 @@ class InterruptClient(DummyOpencodeUpstreamClient): dummy = InterruptClient( make_settings(a2a_bearer_token="t-1", a2a_log_payloads=False, **_BASE_SETTINGS) ) - dummy.remember_interrupt_request( + await dummy.remember_interrupt_request( request_id="perm-owned", session_id="ses-1", interrupt_type="permission", diff --git a/tests/server/test_state_store.py b/tests/server/test_state_store.py new file mode 100644 index 0000000..92ebc01 --- /dev/null +++ b/tests/server/test_state_store.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from opencode_a2a.server.state_store import ( + build_interrupt_request_repository, + build_session_state_repository, + initialize_state_repository, +) +from opencode_a2a.server.task_store import build_database_engine +from tests.support.helpers import make_settings + + +@pytest.mark.asyncio +async def test_database_session_state_repository_persists_bindings(tmp_path: Path) -> None: + database_url = f"sqlite+aiosqlite:///{tmp_path / 'state.db'}" + settings = make_settings( + a2a_bearer_token="test-token", + a2a_task_store_backend="database", + a2a_task_store_database_url=database_url, + ) + engine = build_database_engine(settings) + + writer = build_session_state_repository(settings, engine=engine) + await initialize_state_repository(writer) + await writer.set_session(identity="user-1", context_id="ctx-1", session_id="ses-1") + await writer.set_owner(session_id="ses-1", identity="user-1") + await writer.set_pending_claim(session_id="ses-2", identity="user-2") + await engine.dispose() + + engine = build_database_engine(settings) + reader = build_session_state_repository(settings, engine=engine) + await initialize_state_repository(reader) + + assert await reader.get_session(identity="user-1", context_id="ctx-1") == "ses-1" + assert await reader.get_owner(session_id="ses-1") == "user-1" + assert await reader.get_pending_claim(session_id="ses-2") == "user-2" + + await engine.dispose() + + +@pytest.mark.asyncio +async def test_database_interrupt_request_repository_persists_active_binding( + tmp_path: Path, +) -> None: + database_url = f"sqlite+aiosqlite:///{tmp_path / 'interrupt.db'}" + settings = make_settings( + a2a_bearer_token="test-token", + a2a_task_store_backend="database", + a2a_task_store_database_url=database_url, + ) + engine = build_database_engine(settings) + + writer = build_interrupt_request_repository(settings, engine=engine) + await initialize_state_repository(writer) + await writer.remember( + request_id="perm-1", + session_id="ses-1", + interrupt_type="permission", + identity="user-1", + task_id="task-1", + context_id="ctx-1", + ttl_seconds=30.0, + ) + await engine.dispose() + + engine = build_database_engine(settings) + reader = build_interrupt_request_repository(settings, engine=engine) + await initialize_state_repository(reader) + status, binding = await reader.resolve(request_id="perm-1") + + assert status == "active" + assert binding is not None + assert binding.session_id == "ses-1" + assert binding.interrupt_type == "permission" + assert binding.identity == "user-1" + assert binding.task_id == "task-1" + assert binding.context_id == "ctx-1" + + await engine.dispose() diff --git a/tests/server/test_transport_contract.py b/tests/server/test_transport_contract.py index a76d7ec..88384de 100644 --- a/tests/server/test_transport_contract.py +++ b/tests/server/test_transport_contract.py @@ -486,12 +486,14 @@ def __init__( session_cache_ttl_seconds: int, session_cache_maxsize: int, a2a_client_manager: object = None, + session_state_repository: object = None, ) -> None: captured["streaming_enabled"] = streaming_enabled captured["cancel_abort_timeout_seconds"] = cancel_abort_timeout_seconds captured["session_cache_ttl_seconds"] = session_cache_ttl_seconds captured["session_cache_maxsize"] = session_cache_maxsize captured["a2a_client_manager"] = a2a_client_manager + captured["session_state_repository"] = session_state_repository async def execute(self, _context, _event_queue) -> None: # noqa: ANN001 raise NotImplementedError @@ -564,6 +566,7 @@ def __init__( session_cache_ttl_seconds: int, session_cache_maxsize: int, a2a_client_manager: object = None, + session_state_repository: object = None, ) -> None: del ( streaming_enabled, @@ -571,6 +574,7 @@ def __init__( session_cache_ttl_seconds, session_cache_maxsize, a2a_client_manager, + session_state_repository, ) self._session_manager = types.SimpleNamespace( finalize_session_claim=AsyncMock(), diff --git a/tests/support/helpers.py b/tests/support/helpers.py index 9cca385..2782570 100644 --- a/tests/support/helpers.py +++ b/tests/support/helpers.py @@ -165,7 +165,7 @@ async def stream_events(self, stop_event=None, *, directory: str | None = None): for _ in (): yield {} - def remember_interrupt_request( + async def remember_interrupt_request( self, *, request_id: str, @@ -178,11 +178,11 @@ def remember_interrupt_request( ) -> None: del request_id, session_id, interrupt_type, identity, task_id, context_id, ttl_seconds - def resolve_interrupt_session(self, request_id: str) -> str | None: + async def resolve_interrupt_session(self, request_id: str) -> str | None: del request_id return None - def discard_interrupt_request(self, request_id: str) -> None: + async def discard_interrupt_request(self, request_id: str) -> None: del request_id @@ -317,7 +317,7 @@ async def list_provider_catalog(self, *, directory: str | None = None): del directory return self.provider_catalog_payload - def remember_interrupt_request( + async def remember_interrupt_request( self, *, request_id: str, @@ -337,7 +337,7 @@ def remember_interrupt_request( "context_id": context_id, } - def resolve_interrupt_request(self, request_id: str): + async def resolve_interrupt_request(self, request_id: str): payload = self._interrupt_requests.get(request_id) if payload is None: return "missing", None @@ -352,13 +352,13 @@ def __init__(self, data: dict[str, str | None]) -> None: return "active", _Binding(payload) - def resolve_interrupt_session(self, request_id: str) -> str | None: + async def resolve_interrupt_session(self, request_id: str) -> str | None: payload = self._interrupt_requests.get(request_id) if payload is None: return None return payload.get("session_id") - def discard_interrupt_request(self, request_id: str) -> None: + async def discard_interrupt_request(self, request_id: str) -> None: self._interrupt_requests.pop(request_id, None) async def permission_reply( diff --git a/tests/support/streaming_output.py b/tests/support/streaming_output.py index 8c0b312..4cf8727 100644 --- a/tests/support/streaming_output.py +++ b/tests/support/streaming_output.py @@ -88,7 +88,7 @@ async def stream_events(self, stop_event=None, *, directory: str | None = None): ): yield {"type": "session.idle", "properties": {"sessionID": "ses-1"}} - def remember_interrupt_request( + async def remember_interrupt_request( self, *, request_id: str, @@ -102,10 +102,10 @@ def remember_interrupt_request( del interrupt_type, identity, task_id, context_id, ttl_seconds self._interrupt_sessions[request_id] = session_id - def resolve_interrupt_session(self, request_id: str) -> str | None: + async def resolve_interrupt_session(self, request_id: str) -> str | None: return self._interrupt_sessions.get(request_id) - def discard_interrupt_request(self, request_id: str) -> None: + async def discard_interrupt_request(self, request_id: str) -> None: self._interrupt_sessions.pop(request_id, None) diff --git a/tests/upstream/test_opencode_upstream_client_params.py b/tests/upstream/test_opencode_upstream_client_params.py index d294b84..5b64e31 100644 --- a/tests/upstream/test_opencode_upstream_client_params.py +++ b/tests/upstream/test_opencode_upstream_client_params.py @@ -518,7 +518,7 @@ async def test_interrupt_request_binding_expires_after_ttl() -> None: now = 1000.0 client._interrupt_request_clock = lambda: now # type: ignore[method-assign] - client.remember_interrupt_request( + await client.remember_interrupt_request( request_id="perm-1", session_id="ses-1", interrupt_type="permission", @@ -528,21 +528,21 @@ async def test_interrupt_request_binding_expires_after_ttl() -> None: ttl_seconds=5.0, ) - status, binding = client.resolve_interrupt_request("perm-1") + status, binding = await client.resolve_interrupt_request("perm-1") assert status == "active" assert binding is not None assert binding.session_id == "ses-1" assert binding.interrupt_type == "permission" now = 1006.0 - status, binding = client.resolve_interrupt_request("perm-1") + status, binding = await client.resolve_interrupt_request("perm-1") assert status == "expired" assert binding is None - assert client.resolve_interrupt_session("perm-1") is None - assert client.resolve_interrupt_request("perm-1") == ("expired", None) + assert await client.resolve_interrupt_session("perm-1") is None + assert await client.resolve_interrupt_request("perm-1") == ("expired", None) now = 1009.0 - status, binding = client.resolve_interrupt_request("perm-1") + status, binding = await client.resolve_interrupt_request("perm-1") assert status == "missing" assert binding is None @@ -563,7 +563,7 @@ async def test_interrupt_request_prune_keeps_expired_tombstone() -> None: now = 100.0 client._interrupt_request_clock = lambda: now # type: ignore[method-assign] - client.remember_interrupt_request( + await client.remember_interrupt_request( request_id="perm-1", session_id="ses-1", interrupt_type="permission", @@ -571,18 +571,18 @@ async def test_interrupt_request_prune_keeps_expired_tombstone() -> None: ) now = 103.0 - client.remember_interrupt_request( + await client.remember_interrupt_request( request_id="perm-2", session_id="ses-2", interrupt_type="permission", ttl_seconds=10.0, ) - assert client.resolve_interrupt_request("perm-1") == ("expired", None) - assert client.resolve_interrupt_request("perm-2")[0] == "active" + assert await client.resolve_interrupt_request("perm-1") == ("expired", None) + assert (await client.resolve_interrupt_request("perm-2"))[0] == "active" now = 109.0 - assert client.resolve_interrupt_request("perm-1") == ("missing", None) + assert await client.resolve_interrupt_request("perm-1") == ("missing", None) await client.close() @@ -851,25 +851,25 @@ async def test_interrupt_request_helpers_ignore_invalid_and_trim_values() -> Non ) ) - client.remember_interrupt_request( + await client.remember_interrupt_request( request_id=" ", session_id="ses-1", interrupt_type="permission", ) - client.remember_interrupt_request( + await client.remember_interrupt_request( request_id="perm-1", session_id=" ", interrupt_type="permission", ) - client.remember_interrupt_request( + await client.remember_interrupt_request( request_id="perm-2", session_id="ses-2", interrupt_type="unsupported", ) - assert client.resolve_interrupt_request("perm-1") == ("missing", None) + assert await client.resolve_interrupt_request("perm-1") == ("missing", None) - client.remember_interrupt_request( + await client.remember_interrupt_request( request_id=" perm-3 ", session_id=" ses-3 ", interrupt_type=" question ", @@ -877,7 +877,7 @@ async def test_interrupt_request_helpers_ignore_invalid_and_trim_values() -> Non task_id=" task-1 ", context_id=" ctx-1 ", ) - status, binding = client.resolve_interrupt_request("perm-3") + status, binding = await client.resolve_interrupt_request("perm-3") assert status == "active" assert binding is not None assert binding.request_id == "perm-3" @@ -886,9 +886,9 @@ async def test_interrupt_request_helpers_ignore_invalid_and_trim_values() -> Non assert binding.task_id == "task-1" assert binding.context_id == "ctx-1" - assert client.resolve_interrupt_request(" ") == ("missing", None) - client.discard_interrupt_request(" ") - client.discard_interrupt_request("perm-3") - assert client.resolve_interrupt_session("perm-3") is None + assert await client.resolve_interrupt_request(" ") == ("missing", None) + await client.discard_interrupt_request(" ") + await client.discard_interrupt_request("perm-3") + assert await client.resolve_interrupt_session("perm-3") is None await client.close() From 33285231f13040370c37dbc2f4a53f68c14255c0 Mon Sep 17 00:00:00 2001 From: "helen@cloud" Date: Mon, 23 Mar 2026 10:46:20 -0400 Subject: [PATCH 3/4] test(server): cover app restart persistence for database backend (#304) --- src/opencode_a2a/server/application.py | 3 + src/opencode_a2a/server/state_store.py | 52 +++-- tests/server/test_database_app_persistence.py | 211 ++++++++++++++++++ 3 files changed, 251 insertions(+), 15 deletions(-) create mode 100644 tests/server/test_database_app_persistence.py diff --git a/src/opencode_a2a/server/application.py b/src/opencode_a2a/server/application.py index 39fc7b2..5641555 100644 --- a/src/opencode_a2a/server/application.py +++ b/src/opencode_a2a/server/application.py @@ -607,6 +607,9 @@ async def lifespan(_app: FastAPI): for route, callback in rest_adapter.routes().items(): app.add_api_route(route[0], callback, methods=[route[1]]) app.state._jsonrpc_app = jsonrpc_app + app.state.task_store = task_store + app.state.agent_executor = executor + app.state.upstream_client = upstream_client app.state.a2a_client_manager = client_manager _patch_jsonrpc_openapi_contract(app, settings, runtime_profile=runtime_profile) diff --git a/src/opencode_a2a/server/state_store.py b/src/opencode_a2a/server/state_store.py index b6502eb..2f8767c 100644 --- a/src/opencode_a2a/server/state_store.py +++ b/src/opencode_a2a/server/state_store.py @@ -92,7 +92,12 @@ async def get_pending_claim(self, *, session_id: str) -> str | None: ... async def set_pending_claim(self, *, session_id: str, identity: str) -> None: ... @abstractmethod - async def clear_pending_claim(self, *, session_id: str, identity: str | None = None) -> None: ... + async def clear_pending_claim( + self, + *, + session_id: str, + identity: str | None = None, + ) -> None: ... class InterruptRequestRepository(ABC): @@ -165,7 +170,7 @@ class DatabaseSessionStateRepository(SessionStateRepository): def __init__( self, *, - engine: "AsyncEngine", + engine: AsyncEngine, ttl_seconds: int, maxsize: int, clock: Callable[[], float] = time.time, @@ -175,7 +180,9 @@ def __init__( self._maxsize = int(maxsize) self._clock = clock self._initialized = False - self._session_maker = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) + self._session_maker = async_sessionmaker( + engine, expire_on_commit=False, class_=AsyncSession + ) async def initialize(self) -> None: if self._initialized: @@ -201,12 +208,18 @@ async def _prune_expired( ) -> None: await session.execute( delete(_SESSION_BINDINGS).where( - and_(_SESSION_BINDINGS.c.expires_at.is_not(None), _SESSION_BINDINGS.c.expires_at <= now) + and_( + _SESSION_BINDINGS.c.expires_at.is_not(None), + _SESSION_BINDINGS.c.expires_at <= now, + ) ) ) await session.execute( delete(_SESSION_OWNERS).where( - and_(_SESSION_OWNERS.c.expires_at.is_not(None), _SESSION_OWNERS.c.expires_at <= now) + and_( + _SESSION_OWNERS.c.expires_at.is_not(None), + _SESSION_OWNERS.c.expires_at <= now, + ) ) ) @@ -325,7 +338,9 @@ async def set_owner(self, *, session_id: str, identity: str) -> None: async with self._session_maker.begin() as session: await self._prune_expired(session, now=now) exists = await session.execute( - select(_SESSION_OWNERS.c.session_id).where(_SESSION_OWNERS.c.session_id == session_id) + select(_SESSION_OWNERS.c.session_id).where( + _SESSION_OWNERS.c.session_id == session_id + ) ) values = { "identity": identity, @@ -375,7 +390,12 @@ async def set_pending_claim(self, *, session_id: str, identity: str) -> None: .values(**values) ) - async def clear_pending_claim(self, *, session_id: str, identity: str | None = None) -> None: + async def clear_pending_claim( + self, + *, + session_id: str, + identity: str | None = None, + ) -> None: await self._ensure_initialized() async with self._session_maker.begin() as session: stmt = delete(_PENDING_SESSION_CLAIMS).where( @@ -486,7 +506,7 @@ class DatabaseInterruptRequestRepository(InterruptRequestRepository): def __init__( self, *, - engine: "AsyncEngine", + engine: AsyncEngine, request_ttl_seconds: float, tombstone_ttl_seconds: float, clock: Callable[[], float] = time.time, @@ -496,7 +516,9 @@ def __init__( self._tombstone_ttl_seconds = float(tombstone_ttl_seconds) self._clock = clock self._initialized = False - self._session_maker = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession) + self._session_maker = async_sessionmaker( + engine, expire_on_commit=False, class_=AsyncSession + ) async def initialize(self) -> None: if self._initialized: @@ -521,9 +543,7 @@ async def _prune_tombstones(self, session: AsyncSession, *, now: float) -> None: async def _set_tombstone(self, session: AsyncSession, *, request_id: str, now: float) -> None: tombstone_expires_at = ( - None - if self._tombstone_ttl_seconds <= 0 - else now + self._tombstone_ttl_seconds + None if self._tombstone_ttl_seconds <= 0 else now + self._tombstone_ttl_seconds ) await session.execute( update(_INTERRUPT_REQUESTS) @@ -571,7 +591,9 @@ async def remember( "tombstone_expires_at": None, } if exists.scalar_one_or_none() is None: - await session.execute(insert(_INTERRUPT_REQUESTS).values(request_id=request_id, **values)) + await session.execute( + insert(_INTERRUPT_REQUESTS).values(request_id=request_id, **values) + ) else: await session.execute( update(_INTERRUPT_REQUESTS) @@ -629,7 +651,7 @@ async def discard(self, *, request_id: str) -> None: def build_session_state_repository( settings: Settings, *, - engine: "AsyncEngine | None" = None, + engine: AsyncEngine | None = None, ) -> SessionStateRepository: if settings.a2a_task_store_backend == "database": return DatabaseSessionStateRepository( @@ -646,7 +668,7 @@ def build_session_state_repository( def build_interrupt_request_repository( settings: Settings, *, - engine: "AsyncEngine | None" = None, + engine: AsyncEngine | None = None, ) -> InterruptRequestRepository: if settings.a2a_task_store_backend == "database": return DatabaseInterruptRequestRepository( diff --git a/tests/server/test_database_app_persistence.py b/tests/server/test_database_app_persistence.py new file mode 100644 index 0000000..484bc01 --- /dev/null +++ b/tests/server/test_database_app_persistence.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +from pathlib import Path + +import httpx +import pytest +from a2a.types import Task, TaskState, TaskStatus + +from opencode_a2a.opencode_upstream_client import OpencodeMessage +from tests.support.helpers import make_settings + + +def _task(task_id: str, *, context_id: str = "ctx-1") -> Task: + return Task( + id=task_id, + contextId=context_id, + status=TaskStatus(state=TaskState.working), + ) + + +def _task_store_from_app(app): # noqa: ANN001 + return app.state.task_store + + +def _executor_from_app(app): # noqa: ANN001 + return app.state.agent_executor + + +@pytest.mark.asyncio +async def test_database_backend_persists_task_session_and_interrupt_state_across_app_restart( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + import opencode_a2a.server.application as app_module + + class PersistentStateDummyClient: + created_sessions = 0 + permission_reply_calls: list[dict[str, str | None]] = [] + + def __init__(self, settings, *, interrupt_request_repository=None) -> None: # noqa: ANN001 + self.settings = settings + self.directory = settings.opencode_workspace_root + self.stream_timeout = None + self._interrupt_request_repository = interrupt_request_repository + + async def close(self) -> None: + return None + + async def create_session( + self, + title: str | None = None, + *, + directory: str | None = None, + ) -> str: + del title, directory + type(self).created_sessions += 1 + return f"ses-{type(self).created_sessions}" + + async def send_message( + self, + session_id: str, + text: str | None = None, + *, + parts=None, # noqa: ANN001 + directory: str | None = None, + model_override=None, # noqa: ANN001 + timeout_override=None, # noqa: ANN001 + ) -> OpencodeMessage: + del text, parts, directory, model_override, timeout_override + return OpencodeMessage( + text="ok", + session_id=session_id, + message_id="m-1", + raw={}, + ) + + async def remember_interrupt_request( + self, + *, + request_id: str, + session_id: str, + interrupt_type: str, + identity: str | None = None, + task_id: str | None = None, + context_id: str | None = None, + ttl_seconds: float | None = None, + ) -> None: + assert self._interrupt_request_repository is not None + await self._interrupt_request_repository.remember( + request_id=request_id, + session_id=session_id, + interrupt_type=interrupt_type, + identity=identity, + task_id=task_id, + context_id=context_id, + ttl_seconds=ttl_seconds, + ) + + async def resolve_interrupt_request(self, request_id: str): + assert self._interrupt_request_repository is not None + return await self._interrupt_request_repository.resolve(request_id=request_id) + + async def resolve_interrupt_session(self, request_id: str) -> str | None: + status, binding = await self.resolve_interrupt_request(request_id) + if status != "active" or binding is None: + return None + return binding.session_id + + async def discard_interrupt_request(self, request_id: str) -> None: + assert self._interrupt_request_repository is not None + await self._interrupt_request_repository.discard(request_id=request_id) + + async def permission_reply( + self, + request_id: str, + *, + reply: str, + message: str | None = None, + directory: str | None = None, + ) -> bool: + type(self).permission_reply_calls.append( + { + "request_id": request_id, + "reply": reply, + "message": message, + "directory": directory, + } + ) + return True + + PersistentStateDummyClient.created_sessions = 0 + PersistentStateDummyClient.permission_reply_calls = [] + monkeypatch.setattr(app_module, "OpencodeUpstreamClient", PersistentStateDummyClient) + + database_url = f"sqlite+aiosqlite:///{tmp_path / 'app-state.db'}" + settings = make_settings( + a2a_bearer_token="test-token", + a2a_task_store_backend="database", + a2a_task_store_database_url=database_url, + ) + + app1 = app_module.create_app(settings) + async with app1.router.lifespan_context(app1): + task_store = _task_store_from_app(app1) + executor = _executor_from_app(app1) + upstream_client = app1.state._jsonrpc_app._upstream_client + + await task_store.save(_task("task-1")) + session_id, pending = await executor._session_manager.get_or_create_session( + identity="user-1", + context_id="ctx-1", + title="hello", + ) + assert pending is False + assert session_id == "ses-1" + await upstream_client.remember_interrupt_request( + request_id="perm-1", + session_id=session_id, + interrupt_type="permission", + identity=None, + task_id="task-1", + context_id="ctx-1", + ttl_seconds=60.0, + ) + + app2 = app_module.create_app(settings) + async with app2.router.lifespan_context(app2): + task_store = _task_store_from_app(app2) + executor = _executor_from_app(app2) + + restored_task = await task_store.get("task-1") + assert restored_task is not None + assert restored_task.id == "task-1" + + restored_session_id, pending = await executor._session_manager.get_or_create_session( + identity="user-1", + context_id="ctx-1", + title="hello again", + ) + assert pending is False + assert restored_session_id == "ses-1" + assert PersistentStateDummyClient.created_sessions == 1 + + transport = httpx.ASGITransport(app=app2) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/", + headers={"Authorization": "Bearer test-token"}, + json={ + "jsonrpc": "2.0", + "id": 1, + "method": "a2a.interrupt.permission.reply", + "params": { + "request_id": "perm-1", + "reply": "once", + }, + }, + ) + + payload = response.json() + assert payload.get("error") is None + assert payload["result"]["ok"] is True + assert payload["result"]["request_id"] == "perm-1" + assert PersistentStateDummyClient.permission_reply_calls == [ + { + "request_id": "perm-1", + "reply": "once", + "message": None, + "directory": None, + } + ] From 8a5889d1698a20fa6dde557140326a85c60df32e Mon Sep 17 00:00:00 2001 From: "helen@cloud" Date: Mon, 23 Mar 2026 10:54:06 -0400 Subject: [PATCH 4/4] fix(server): align database state tables with create-table config (#304) --- docs/guide.md | 2 +- src/opencode_a2a/server/state_store.py | 16 +++++++++---- src/opencode_a2a/server/task_store.py | 6 ----- tests/server/test_state_store.py | 32 +++++++++++++++++++++++++ tests/server/test_task_store_factory.py | 5 ++-- 5 files changed, 47 insertions(+), 14 deletions(-) diff --git a/docs/guide.md b/docs/guide.md index 1097678..0e66098 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -82,7 +82,7 @@ Key variables to understand protocol behavior: `A2A_TASK_STORE_BACKEND=database`. For local persistence, prefer `sqlite+aiosqlite:///./opencode-a2a.db`. - `A2A_TASK_STORE_TABLE_NAME` / `A2A_TASK_STORE_CREATE_TABLE`: database task - store table name and whether to auto-create it on startup. + store table name and whether to auto-create database tables on startup. - Runtime authentication is bearer-token only via `A2A_BEARER_TOKEN`. - The same outbound client flags are also honored by the server-side embedded A2A client used for peer calls and `a2a_call` tool execution: diff --git a/src/opencode_a2a/server/state_store.py b/src/opencode_a2a/server/state_store.py index 2f8767c..d41f7aa 100644 --- a/src/opencode_a2a/server/state_store.py +++ b/src/opencode_a2a/server/state_store.py @@ -173,11 +173,13 @@ def __init__( engine: AsyncEngine, ttl_seconds: int, maxsize: int, + create_tables: bool = True, clock: Callable[[], float] = time.time, ) -> None: self.engine = engine self._ttl_seconds = int(ttl_seconds) self._maxsize = int(maxsize) + self._create_tables = bool(create_tables) self._clock = clock self._initialized = False self._session_maker = async_sessionmaker( @@ -187,8 +189,9 @@ def __init__( async def initialize(self) -> None: if self._initialized: return - async with self.engine.begin() as conn: - await conn.run_sync(_STATE_METADATA.create_all) + if self._create_tables: + async with self.engine.begin() as conn: + await conn.run_sync(_STATE_METADATA.create_all) self._initialized = True async def _ensure_initialized(self) -> None: @@ -509,11 +512,13 @@ def __init__( engine: AsyncEngine, request_ttl_seconds: float, tombstone_ttl_seconds: float, + create_tables: bool = True, clock: Callable[[], float] = time.time, ) -> None: self.engine = engine self._request_ttl_seconds = float(request_ttl_seconds) self._tombstone_ttl_seconds = float(tombstone_ttl_seconds) + self._create_tables = bool(create_tables) self._clock = clock self._initialized = False self._session_maker = async_sessionmaker( @@ -523,8 +528,9 @@ def __init__( async def initialize(self) -> None: if self._initialized: return - async with self.engine.begin() as conn: - await conn.run_sync(_STATE_METADATA.create_all) + if self._create_tables: + async with self.engine.begin() as conn: + await conn.run_sync(_STATE_METADATA.create_all) self._initialized = True async def _ensure_initialized(self) -> None: @@ -658,6 +664,7 @@ def build_session_state_repository( engine=cast("AsyncEngine", engine), ttl_seconds=settings.a2a_session_cache_ttl_seconds, maxsize=settings.a2a_session_cache_maxsize, + create_tables=settings.a2a_task_store_create_table, ) return MemorySessionStateRepository( ttl_seconds=settings.a2a_session_cache_ttl_seconds, @@ -675,6 +682,7 @@ def build_interrupt_request_repository( engine=cast("AsyncEngine", engine), request_ttl_seconds=settings.a2a_interrupt_request_ttl_seconds, tombstone_ttl_seconds=settings.a2a_interrupt_request_tombstone_ttl_seconds, + create_tables=settings.a2a_task_store_create_table, ) return MemoryInterruptRequestRepository( request_ttl_seconds=settings.a2a_interrupt_request_ttl_seconds, diff --git a/src/opencode_a2a/server/task_store.py b/src/opencode_a2a/server/task_store.py index 7f5f708..9f163d9 100644 --- a/src/opencode_a2a/server/task_store.py +++ b/src/opencode_a2a/server/task_store.py @@ -87,9 +87,3 @@ async def initialize_task_store(task_store: TaskStore) -> None: initialize = getattr(task_store, "initialize", None) if callable(initialize): await initialize() - - -async def close_task_store(task_store: TaskStore) -> None: - engine = cast("AsyncEngine | None", getattr(task_store, "engine", None)) - if engine is not None: - await engine.dispose() diff --git a/tests/server/test_state_store.py b/tests/server/test_state_store.py index 92ebc01..c7bfa9e 100644 --- a/tests/server/test_state_store.py +++ b/tests/server/test_state_store.py @@ -3,6 +3,7 @@ from pathlib import Path import pytest +from sqlalchemy import inspect as sqlalchemy_inspect from opencode_a2a.server.state_store import ( build_interrupt_request_repository, @@ -80,3 +81,34 @@ async def test_database_interrupt_request_repository_persists_active_binding( assert binding.context_id == "ctx-1" await engine.dispose() + + +@pytest.mark.asyncio +async def test_database_state_repositories_skip_auto_create_when_disabled( + tmp_path: Path, +) -> None: + database_url = f"sqlite+aiosqlite:///{tmp_path / 'state-no-create.db'}" + settings = make_settings( + a2a_bearer_token="test-token", + a2a_task_store_backend="database", + a2a_task_store_database_url=database_url, + a2a_task_store_create_table=False, + ) + engine = build_database_engine(settings) + + session_repo = build_session_state_repository(settings, engine=engine) + interrupt_repo = build_interrupt_request_repository(settings, engine=engine) + await initialize_state_repository(session_repo) + await initialize_state_repository(interrupt_repo) + + async with engine.begin() as conn: + table_names = await conn.run_sync( + lambda sync_conn: set(sqlalchemy_inspect(sync_conn).get_table_names()) + ) + + assert "a2a_session_bindings" not in table_names + assert "a2a_session_owners" not in table_names + assert "a2a_pending_session_claims" not in table_names + assert "a2a_interrupt_requests" not in table_names + + await engine.dispose() diff --git a/tests/server/test_task_store_factory.py b/tests/server/test_task_store_factory.py index 2f418a9..52328b4 100644 --- a/tests/server/test_task_store_factory.py +++ b/tests/server/test_task_store_factory.py @@ -8,7 +8,6 @@ from opencode_a2a.server.task_store import ( build_task_store, - close_task_store, initialize_task_store, ) from tests.support.helpers import make_settings @@ -42,7 +41,7 @@ async def test_database_task_store_persists_tasks_across_rebuilds(tmp_path: Path writer = build_task_store(settings) await initialize_task_store(writer) await writer.save(_task("task-1")) - await close_task_store(writer) + await writer.engine.dispose() reader = build_task_store(settings) await initialize_task_store(reader) @@ -53,4 +52,4 @@ async def test_database_task_store_persists_tasks_across_rebuilds(tmp_path: Path assert restored.context_id == "ctx-1" assert restored.status.state == TaskState.working - await close_task_store(reader) + await reader.engine.dispose()