From 2bd87048fec5bb596909bd3448609398bfab2313 Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Sat, 7 Feb 2026 19:07:56 +0000 Subject: [PATCH 01/11] fix(amgi-types): use sequence for message headers --- packages/amgi-types/src/amgi_types/__init__.py | 6 +++--- packages/asyncfast/tests_asyncfast/test_message.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/packages/amgi-types/src/amgi_types/__init__.py b/packages/amgi-types/src/amgi_types/__init__.py index 56b43f1..c052303 100644 --- a/packages/amgi-types/src/amgi_types/__init__.py +++ b/packages/amgi-types/src/amgi_types/__init__.py @@ -1,7 +1,7 @@ import sys from collections.abc import Awaitable from collections.abc import Callable -from collections.abc import Iterable +from collections.abc import Sequence from typing import Any from typing import Literal from typing import TypedDict @@ -96,7 +96,7 @@ class MessageReceiveEvent(TypedDict): type: Literal["message.receive"] id: str - headers: Iterable[tuple[bytes, bytes]] + headers: Sequence[tuple[bytes, bytes]] payload: NotRequired[bytes | None] bindings: NotRequired[dict[str, dict[str, Any]]] more_messages: NotRequired[bool] @@ -136,7 +136,7 @@ class MessageSendEvent(TypedDict): type: Literal["message.send"] address: str - headers: Iterable[tuple[bytes, bytes]] + headers: Sequence[tuple[bytes, bytes]] payload: NotRequired[bytes | None] bindings: NotRequired[dict[str, dict[str, Any]]] diff --git a/packages/asyncfast/tests_asyncfast/test_message.py b/packages/asyncfast/tests_asyncfast/test_message.py index fc8dfa7..51e0c1b 100644 --- a/packages/asyncfast/tests_asyncfast/test_message.py +++ b/packages/asyncfast/tests_asyncfast/test_message.py @@ -1,11 +1,11 @@ from asyncio import Event from collections.abc import AsyncGenerator from collections.abc import Generator -from collections.abc import Iterable from dataclasses import dataclass from typing import Annotated from typing import Any from typing import Optional +from typing import Sequence from unittest.mock import _Call from unittest.mock import AsyncMock from unittest.mock import call @@ -256,7 +256,7 @@ async def topic_handler( ], ) async def test_message_header_optional( - headers: Iterable[tuple[bytes, bytes]], expected_call: _Call + headers: Sequence[tuple[bytes, bytes]], expected_call: _Call ) -> None: app = AsyncFast() @@ -299,7 +299,7 @@ async def topic_handler(id: Annotated[Optional[str], Header()] = None) -> None: ], ) async def test_message_header_default( - headers: Iterable[tuple[bytes, bytes]], expected_call: _Call + headers: Sequence[tuple[bytes, bytes]], expected_call: _Call ) -> None: app = AsyncFast() From b9f96084c002d8e633544540a3e5bb5559efa005 Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Sat, 7 Feb 2026 19:10:57 +0000 Subject: [PATCH 02/11] test(asyncfast): fix test for header underscore to hyphen as its actually using alias --- packages/asyncfast/tests_asyncfast/test_message.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/packages/asyncfast/tests_asyncfast/test_message.py b/packages/asyncfast/tests_asyncfast/test_message.py index 51e0c1b..4869924 100644 --- a/packages/asyncfast/tests_asyncfast/test_message.py +++ b/packages/asyncfast/tests_asyncfast/test_message.py @@ -187,9 +187,7 @@ async def test_message_header_underscore_to_hyphen() -> None: test_mock = Mock() @app.channel("topic") - async def topic_handler( - idempotency_key: Annotated[UUID, Header(alias="Idempotency-Key")], - ) -> None: + async def topic_handler(idempotency_key: Annotated[UUID, Header()]) -> None: test_mock(idempotency_key) message_scope: MessageScope = { @@ -200,7 +198,7 @@ async def topic_handler( message_receive_event: MessageReceiveEvent = { "type": "message.receive", "id": "id-1", - "headers": [(b"Idempotency-Key", b"8e03978e-40d5-43e8-bc93-6894a57f9324")], + "headers": [(b"idempotency-key", b"8e03978e-40d5-43e8-bc93-6894a57f9324")], } await app( message_scope, From 70fa5e9f0d91c8700938329d72b1cff605d48a43 Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Sat, 7 Feb 2026 19:15:55 +0000 Subject: [PATCH 03/11] test(asyncfast): add benchmark tests --- .gitignore | 5 +- packages/asyncfast/pyproject.toml | 9 + .../tests_asyncfast/test_benchmarks.py | 543 ++++++++++++++++++ uv.lock | 24 + 4 files changed, 580 insertions(+), 1 deletion(-) create mode 100644 packages/asyncfast/tests_asyncfast/test_benchmarks.py diff --git a/.gitignore b/.gitignore index e00f533..5755b12 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,7 @@ _build # coverage .coverage .coverage.* -coverage.xml \ No newline at end of file +coverage.xml + +# benchmarks +.benchmarks \ No newline at end of file diff --git a/packages/asyncfast/pyproject.toml b/packages/asyncfast/pyproject.toml index 7fe9e16..48f950c 100644 --- a/packages/asyncfast/pyproject.toml +++ b/packages/asyncfast/pyproject.toml @@ -51,6 +51,7 @@ scripts.asyncfast = "asyncfast.cli:main" dev = [ "pytest>=8.4.1", "pytest-asyncio>=1.1", + "pytest-benchmark>=5.2.3", "pytest-cov>=7.0.0", "pytest-timeout>=2.4.0", ] @@ -60,3 +61,11 @@ workspace = true [tool.uv.sources.asyncfast-cli] workspace = true + +[tool.pytest.ini_options] +asyncio_mode = "auto" +timeout = 10 +timeout_func_only = true +addopts = [ + "--benchmark-max-time=0.1", +] diff --git a/packages/asyncfast/tests_asyncfast/test_benchmarks.py b/packages/asyncfast/tests_asyncfast/test_benchmarks.py new file mode 100644 index 0000000..f7baaf8 --- /dev/null +++ b/packages/asyncfast/tests_asyncfast/test_benchmarks.py @@ -0,0 +1,543 @@ +import asyncio +from dataclasses import dataclass +from typing import Annotated +from typing import Any +from typing import AsyncGenerator +from typing import Callable +from typing import Optional +from typing import Sequence +from uuid import UUID + +import pytest +from amgi_types import AMGIApplication +from amgi_types import AMGISendEvent +from amgi_types import MessageReceiveEvent +from amgi_types import MessageScope +from asyncfast import AsyncFast +from asyncfast import Header +from asyncfast import Message +from asyncfast import MessageSender +from asyncfast import Payload +from asyncfast.bindings import KafkaKey +from pydantic import BaseModel +from pytest_benchmark.fixture import BenchmarkFixture + +AppBenchmark = Callable[ + [AMGIApplication, MessageScope, MessageReceiveEvent], + None, +] + + +@pytest.fixture +def app_benchmark(benchmark: BenchmarkFixture) -> AppBenchmark: + async def _send(message: AMGISendEvent) -> None: + pass + + def _app_benchmark( + app: AMGIApplication, + message_scope: MessageScope, + message_receive_event: MessageReceiveEvent, + ) -> None: + loop = asyncio.new_event_loop() + + async def _receive() -> MessageReceiveEvent: + return message_receive_event + + benchmark( + lambda: loop.run_until_complete( + app( + message_scope, + _receive, + _send, + ) + ) + ) + + return _app_benchmark + + +def test_message_payload(app_benchmark: AppBenchmark) -> None: + app = AsyncFast() + + class MessagePayload(BaseModel): + id: int + + @app.channel("topic") + async def topic_handler(payload: MessagePayload) -> None: + pass + + app_benchmark( + app, + { + "type": "message", + "amgi": {"version": "1.0", "spec_version": "1.0"}, + "address": "topic", + }, + { + "type": "message.receive", + "id": "id-1", + "headers": [], + "payload": b'{"id":1}', + }, + ) + + +def test_message_payload_optional(app_benchmark: AppBenchmark) -> None: + app = AsyncFast() + + class MessagePayload(BaseModel): + id: int + + @app.channel("topic") + async def topic_handler(payload: MessagePayload | None) -> None: + pass + + app_benchmark( + app, + { + "type": "message", + "amgi": {"version": "1.0", "spec_version": "1.0"}, + "address": "topic", + }, + { + "type": "message.receive", + "id": "id-1", + "headers": [], + }, + ) + + +def test_message_payload_sync(app_benchmark: AppBenchmark) -> None: + app = AsyncFast() + + class MessagePayload(BaseModel): + id: int + + @app.channel("topic") + def topic_handler(payload: MessagePayload) -> None: + pass + + app_benchmark( + app, + { + "type": "message", + "amgi": {"version": "1.0", "spec_version": "1.0"}, + "address": "topic", + }, + { + "type": "message.receive", + "id": "id-1", + "headers": [], + "payload": b'{"id":1}', + }, + ) + + +def test_message_header_string(app_benchmark: AppBenchmark) -> None: + app = AsyncFast() + + @app.channel("topic") + async def topic_handler(etag: Annotated[str, Header(alias="ETag")]) -> None: + pass + + app_benchmark( + app, + { + "type": "message", + "amgi": {"version": "1.0", "spec_version": "1.0"}, + "address": "topic", + }, + { + "type": "message.receive", + "id": "id-1", + "headers": [(b"ETag", b"33a64df551425fcc55e4d42a148795d9f25f89d4")], + }, + ) + + +def test_message_header_integer(app_benchmark: AppBenchmark) -> None: + app = AsyncFast() + + @app.channel("topic") + async def topic_handler(id: Annotated[int, Header()]) -> None: + pass + + app_benchmark( + app, + { + "type": "message", + "amgi": {"version": "1.0", "spec_version": "1.0"}, + "address": "topic", + }, + { + "type": "message.receive", + "id": "id-1", + "headers": [(b"id", b"10")], + }, + ) + + +def test_message_header_underscore_to_hyphen(app_benchmark: AppBenchmark) -> None: + app = AsyncFast() + + @app.channel("topic") + async def topic_handler( + idempotency_key: Annotated[UUID, Header()], + ) -> None: + pass + + app_benchmark( + app, + { + "type": "message", + "amgi": {"version": "1.0", "spec_version": "1.0"}, + "address": "topic", + }, + { + "type": "message.receive", + "id": "id-1", + "headers": [(b"idempotency-key", b"8e03978e-40d5-43e8-bc93-6894a57f9324")], + }, + ) + + +def test_message_headers_multiple(app_benchmark: AppBenchmark) -> None: + app = AsyncFast() + + @app.channel("topic") + async def topic_handler( + id: Annotated[int, Header()], + etag: Annotated[str, Header()], + ) -> None: + pass + + app_benchmark( + app, + { + "type": "message", + "amgi": {"version": "1.0", "spec_version": "1.0"}, + "address": "topic", + }, + { + "type": "message.receive", + "id": "id-1", + "headers": [ + (b"id", b"10"), + (b"etag", b"33a64df551425fcc55e4d42a148795d9f25f89d4"), + ], + }, + ) + + +@pytest.mark.parametrize( + "headers", + [[(b"id", b"33a64df551425fcc55e4d42a148795d9f25f89d4")], []], +) +def test_message_header_optional( + headers: Sequence[tuple[bytes, bytes]], app_benchmark: AppBenchmark +) -> None: + app = AsyncFast() + + @app.channel("topic") + async def topic_handler(id: Annotated[Optional[str], Header()] = None) -> None: + pass + + app_benchmark( + app, + { + "type": "message", + "amgi": {"version": "1.0", "spec_version": "1.0"}, + "address": "topic", + }, + { + "type": "message.receive", + "id": "id-1", + "headers": headers, + }, + ) + + +@pytest.mark.parametrize( + "headers", + [ + [(b"id", b"1")], + [(b"id", b"1"), (b"example", b"value")], + ], +) +def test_message_header_default( + headers: Sequence[tuple[bytes, bytes]], app_benchmark: AppBenchmark +) -> None: + app = AsyncFast() + + @app.channel("topic") + async def topic_handler( + id: Annotated[int, Header()], example: Annotated[str, Header()] = "default" + ) -> None: + pass + + app_benchmark( + app, + { + "type": "message", + "amgi": {"version": "1.0", "spec_version": "1.0"}, + "address": "topic", + }, + { + "type": "message.receive", + "id": "id-1", + "headers": headers, + }, + ) + + +def test_message_sending_dict(app_benchmark: AppBenchmark) -> None: + app = AsyncFast() + + @app.channel("topic") + async def topic_handler() -> AsyncGenerator[dict[str, Any], None]: + yield { + "address": "send_topic", + "payload": b'{"key": "KEY-001"}', + "headers": [(b"Id", b"10")], + } + + app_benchmark( + app, + { + "type": "message", + "amgi": {"version": "1.0", "spec_version": "1.0"}, + "address": "topic", + }, + { + "type": "message.receive", + "id": "id-1", + "headers": [], + }, + ) + + +def test_message_payload_dataclass(app_benchmark: AppBenchmark) -> None: + app = AsyncFast() + + @dataclass + class MessagePayload: + id: int + + @app.channel("topic") + async def topic_handler(payload: MessagePayload) -> None: + pass + + app_benchmark( + app, + { + "type": "message", + "amgi": {"version": "1.0", "spec_version": "1.0"}, + "address": "topic", + }, + { + "type": "message.receive", + "id": "id-1", + "headers": [], + "payload": b'{"id":1}', + }, + ) + + +def test_message_payload_simple(app_benchmark: AppBenchmark) -> None: + app = AsyncFast() + + @app.channel("topic") + async def topic_handler(payload: Annotated[int, Payload()]) -> None: + pass + + app_benchmark( + app, + { + "type": "message", + "amgi": {"version": "1.0", "spec_version": "1.0"}, + "address": "topic", + }, + { + "type": "message.receive", + "id": "id-1", + "headers": [], + "payload": b"123", + }, + ) + + +def test_message_payload_address_parameter(app_benchmark: AppBenchmark) -> None: + app = AsyncFast() + + @app.channel("order.{user_id}") + async def order_handler(user_id: str) -> None: + pass + + app_benchmark( + app, + { + "type": "message", + "amgi": {"version": "1.0", "spec_version": "1.0"}, + "address": "order.1234", + }, + { + "type": "message.receive", + "id": "id-1", + "headers": [], + }, + ) + + +def test_message_sending_message(app_benchmark: AppBenchmark) -> None: + app = AsyncFast() + + @dataclass + class SendMessage(Message, address="send_topic"): + payload: int + id: Annotated[int, Header()] + + @app.channel("topic") + async def topic_handler() -> AsyncGenerator[SendMessage, None]: + yield SendMessage(payload=10, id=10) + + app_benchmark( + app, + { + "type": "message", + "amgi": {"version": "1.0", "spec_version": "1.0"}, + "address": "topic", + }, + { + "type": "message.receive", + "id": "id-1", + "headers": [], + }, + ) + + +def test_message_address_parameter(app_benchmark: AppBenchmark) -> None: + app = AsyncFast() + + @dataclass + class SendMessage(Message, address="send.{name}"): + name: str + payload: int + + @app.channel("topic") + async def topic_handler() -> AsyncGenerator[SendMessage, None]: + yield SendMessage( + name="test", + payload=10, + ) + + app_benchmark( + app, + { + "type": "message", + "amgi": {"version": "1.0", "spec_version": "1.0"}, + "address": "topic", + }, + { + "type": "message.receive", + "id": "id-1", + "headers": [], + }, + ) + + +def test_message_nack(app_benchmark: AppBenchmark) -> None: + app = AsyncFast() + + @app.channel("topic") + async def topic_handler() -> None: + raise Exception("test") + + app_benchmark( + app, + { + "type": "message", + "amgi": {"version": "1.0", "spec_version": "1.0"}, + "address": "topic", + }, + { + "type": "message.receive", + "id": "id-1", + "headers": [], + }, + ) + + +def test_message_binding_kafka_key(app_benchmark: AppBenchmark) -> None: + app = AsyncFast() + + @app.channel("topic") + async def topic_handler(key: Annotated[int, KafkaKey()]) -> None: + pass + + app_benchmark( + app, + { + "type": "message", + "amgi": {"version": "1.0", "spec_version": "1.0"}, + "address": "topic", + }, + { + "type": "message.receive", + "id": "id-1", + "headers": [], + "bindings": {"kafka": {"key": b"1234"}}, + }, + ) + + +@pytest.mark.parametrize("bindings", ({}, {"kafka": {"key": b"1234"}})) +def test_message_binding_default_kafka_key( + bindings: dict[str, Any], app_benchmark: AppBenchmark +) -> None: + app = AsyncFast() + + @app.channel("topic") + async def topic_handler(key: Annotated[Optional[int], KafkaKey()] = None) -> None: + pass + + app_benchmark( + app, + { + "type": "message", + "amgi": {"version": "1.0", "spec_version": "1.0"}, + "address": "topic", + }, + { + "type": "message.receive", + "id": "id-1", + "headers": [], + "bindings": bindings, + }, + ) + + +def test_message_sender(app_benchmark: AppBenchmark) -> None: + app = AsyncFast() + + @dataclass + class SendMessage(Message, address="send_topic"): + payload: int + id: Annotated[int, Header()] + + @app.channel("topic") + async def topic_handler(message_sender: MessageSender[SendMessage]) -> None: + await message_sender.send(SendMessage(payload=10, id=10)) + + app_benchmark( + app, + { + "type": "message", + "amgi": {"version": "1.0", "spec_version": "1.0"}, + "address": "topic", + }, + { + "type": "message.receive", + "id": "id-1", + "headers": [], + }, + ) diff --git a/uv.lock b/uv.lock index d74ce58..e3a0c0c 100644 --- a/uv.lock +++ b/uv.lock @@ -624,6 +624,7 @@ standard = [ dev = [ { name = "pytest" }, { name = "pytest-asyncio" }, + { name = "pytest-benchmark" }, { name = "pytest-cov" }, { name = "pytest-timeout" }, ] @@ -640,6 +641,7 @@ provides-extras = ["standard"] dev = [ { name = "pytest", specifier = ">=8.4.1" }, { name = "pytest-asyncio", specifier = ">=1.1" }, + { name = "pytest-benchmark", specifier = ">=5.2.3" }, { name = "pytest-cov", specifier = ">=7.0.0" }, { name = "pytest-timeout", specifier = ">=2.4.0" }, ] @@ -1989,6 +1991,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5b/5a/bc7b4a4ef808fa59a816c17b20c4bef6884daebbdf627ff2a161da67da19/propcache-0.4.1-py3-none-any.whl", hash = "sha256:af2a6052aeb6cf17d3e46ee169099044fd8224cbaf75c76a2ef596e8163e2237", size = 13305, upload-time = "2025-10-08T19:49:00.792Z" }, ] +[[package]] +name = "py-cpuinfo" +version = "9.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/37/a8/d832f7293ebb21690860d2e01d8115e5ff6f2ae8bbdc953f0eb0fa4bd2c7/py-cpuinfo-9.0.0.tar.gz", hash = "sha256:3cdbbf3fac90dc6f118bfd64384f309edeadd902d7c8fb17f02ffa1fc3f49690", size = 104716, upload-time = "2022-10-25T20:38:06.303Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/a9/023730ba63db1e494a271cb018dcd361bd2c917ba7004c3e49d5daf795a2/py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5", size = 22335, upload-time = "2022-10-25T20:38:27.636Z" }, +] + [[package]] name = "pydantic" version = "2.12.5" @@ -2176,6 +2187,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" }, ] +[[package]] +name = "pytest-benchmark" +version = "5.2.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "py-cpuinfo" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/24/34/9f732b76456d64faffbef6232f1f9dbec7a7c4999ff46282fa418bd1af66/pytest_benchmark-5.2.3.tar.gz", hash = "sha256:deb7317998a23c650fd4ff76e1230066a76cb45dcece0aca5607143c619e7779", size = 341340, upload-time = "2025-11-09T18:48:43.215Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/33/29/e756e715a48959f1c0045342088d7ca9762a2f509b945f362a316e9412b7/pytest_benchmark-5.2.3-py3-none-any.whl", hash = "sha256:bc839726ad20e99aaa0d11a127445457b4219bdb9e80a1afc4b51da7f96b0803", size = 45255, upload-time = "2025-11-09T18:48:39.765Z" }, +] + [[package]] name = "pytest-cov" version = "7.0.0" From f9e8c861b0dd98724efc6aac8e591076b0b1203c Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Sat, 7 Feb 2026 20:31:22 +0000 Subject: [PATCH 04/11] perf(asyncfast): refactor core channel/message pipeline for greater speed --- packages/asyncfast/src/asyncfast/__init__.py | 914 +----------------- packages/asyncfast/src/asyncfast/_asyncapi.py | 141 +++ .../asyncfast/src/asyncfast/_asyncfast.py | 329 +++++++ packages/asyncfast/src/asyncfast/_channel.py | 293 ++++++ packages/asyncfast/src/asyncfast/_message.py | 182 ++++ packages/asyncfast/src/asyncfast/_utils.py | 32 + .../tests_asyncfast/test_address_pattern.py | 2 +- .../asyncfast/tests_asyncfast/test_channel.py | 306 ++++++ .../asyncfast/tests_asyncfast/test_message.py | 23 +- 9 files changed, 1309 insertions(+), 913 deletions(-) create mode 100644 packages/asyncfast/src/asyncfast/_asyncapi.py create mode 100644 packages/asyncfast/src/asyncfast/_asyncfast.py create mode 100644 packages/asyncfast/src/asyncfast/_channel.py create mode 100644 packages/asyncfast/src/asyncfast/_message.py create mode 100644 packages/asyncfast/src/asyncfast/_utils.py create mode 100644 packages/asyncfast/tests_asyncfast/test_channel.py diff --git a/packages/asyncfast/src/asyncfast/__init__.py b/packages/asyncfast/src/asyncfast/__init__.py index 2c74dcd..e02920e 100644 --- a/packages/asyncfast/src/asyncfast/__init__.py +++ b/packages/asyncfast/src/asyncfast/__init__.py @@ -1,897 +1,17 @@ -import asyncio -import inspect -import re -from collections import Counter -from collections.abc import AsyncGenerator -from collections.abc import Awaitable -from collections.abc import Callable -from collections.abc import Generator -from collections.abc import Iterable -from collections.abc import Iterator -from collections.abc import Mapping -from collections.abc import Sequence -from contextlib import AbstractAsyncContextManager -from functools import cached_property -from functools import partial -from inspect import Signature -from re import Pattern -from types import UnionType -from typing import Annotated -from typing import Any -from typing import ClassVar -from typing import Generic -from typing import get_args -from typing import get_origin -from typing import TypeVar -from typing import Union - -from amgi_types import AMGIReceiveCallable -from amgi_types import AMGISendCallable -from amgi_types import LifespanShutdownCompleteEvent -from amgi_types import LifespanStartupCompleteEvent -from amgi_types import MessageAckEvent -from amgi_types import MessageNackEvent -from amgi_types import MessageReceiveEvent -from amgi_types import MessageScope -from amgi_types import MessageSendEvent -from amgi_types import Scope -from asyncfast.bindings import Binding -from pydantic import BaseModel -from pydantic import create_model -from pydantic import TypeAdapter -from pydantic.fields import FieldInfo -from pydantic.json_schema import GenerateJsonSchema -from pydantic.json_schema import JsonSchemaMode -from pydantic.json_schema import JsonSchemaValue -from pydantic_core import CoreSchema - -DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any]) -M = TypeVar("M", bound=Mapping[str, Any]) -Lifespan = Callable[["AsyncFast"], AbstractAsyncContextManager[None]] - - -_FIELD_PATTERN = re.compile(r"^[A-Za-z0-9_\-]+$") -_PARAMETER_PATTERN = re.compile(r"{(.*)}") - - -class InvalidChannelDefinitionError(ValueError): - """ - Raised when a channel or message handler is defined with an invalid shape. - """ - - -async def _send_message(send: AMGISendCallable, message: Mapping[str, Any]) -> None: - message_send_event: MessageSendEvent = { - "type": "message.send", - "address": message["address"], - "headers": message["headers"], - "payload": message.get("payload"), - } - await send(message_send_event) - - -class MessageSender(Generic[M]): - def __init__(self, send: AMGISendCallable) -> None: - self._send = send - - async def send(self, message: M) -> None: - await _send_message(self._send, message) - - -class _Field: - def __init__(self, type_: type): - self.type = type_ - self.type_adapter = TypeAdapter[Any](type_) - - def __hash__(self) -> int: - return hash(self.type) - - -class Message(Mapping[str, Any]): - - __address__: ClassVar[str | None] = None - __headers__: ClassVar[dict[str, _Field]] - __headers_model__: ClassVar[type[BaseModel] | None] - __parameters__: ClassVar[dict[str, TypeAdapter[Any]]] - __payload__: ClassVar[tuple[str, _Field] | None] - __bindings__: ClassVar[dict[str, _Field]] - - def __init_subclass__(cls, address: str | None = None, **kwargs: Any) -> None: - cls.__address__ = address - annotations = list(_generate_message_annotations(address, cls.__annotations__)) - - headers = { - name: _Field(annotated) - for name, annotated in annotations - if isinstance(get_args(annotated)[1], Header) - } - - parameters = { - name: TypeAdapter(annotated) - for name, annotated in annotations - if isinstance(get_args(annotated)[1], Parameter) - } - - bindings = { - name: _Field(annotated) - for name, annotated in annotations - if isinstance(get_args(annotated)[1], Binding) - } - - payloads = [ - (name, _Field(annotated)) - for name, annotated in annotations - if isinstance(get_args(annotated)[1], Payload) - ] - - assert len(payloads) <= 1, "Channel must have no more than 1 payload" - - payload = payloads[0] if len(payloads) == 1 else None - - cls.__headers__ = headers - cls.__parameters__ = parameters - cls.__payload__ = payload - cls.__bindings__ = bindings - - def __getitem__(self, key: str, /) -> Any: - if key == "address": - return self._get_address() - elif key == "headers": - return self._get_headers() - elif key == "payload" and self.__payload__: - return self._get_payload() - elif key == "bindings" and self.__bindings__: - return self._get_bindings() - raise KeyError(key) - - def __len__(self) -> int: - payload = 1 if self.__payload__ else 0 - bindings = 1 if self.__bindings__ else 0 - return 2 + payload + bindings - - def __iter__(self) -> Iterator[str]: - yield from ("address", "headers") - if self.__payload__: - yield "payload" - if self.__bindings__: - yield "bindings" - - def _get_address(self) -> str | None: - if self.__address__ is None: - return None - parameters = { - name: type_adapter.dump_python(getattr(self, name)) - for name, type_adapter in self.__parameters__.items() - } - - return self.__address__.format(**parameters) - - def _generate_headers(self) -> Iterable[tuple[str, bytes]]: - for name, field in self.__headers__.items(): - _, annotation = get_args(field.type) - alias = annotation.alias if annotation.alias else name.replace("_", "-") - yield alias, self._get_value(name, field.type_adapter) - - def _get_headers(self) -> Iterable[tuple[bytes, bytes]]: - return [(name.encode(), value) for name, value in self._generate_headers()] - - def _get_value(self, name: str, type_adapter: TypeAdapter[Any]) -> bytes: - value = getattr(self, name) - python = type_adapter.dump_python(value, mode="json") - if isinstance(python, str): - return python.encode() - if isinstance(python, bytes): - return python - return type_adapter.dump_json(value) - - def _get_payload(self) -> bytes | None: - if self.__payload__ is None: - return None - name, field = self.__payload__ - return field.type_adapter.dump_json(getattr(self, name)) - - def _get_bindings(self) -> dict[str, dict[str, Any]]: - bindings: dict[str, dict[str, Any]] = {} - for name, field in self.__bindings__.items(): - binding_type = get_args(field.type)[1] - assert isinstance(binding_type, Binding) - - bindings.setdefault(binding_type.__protocol__, {})[ - binding_type.__field_name__ - ] = self._get_value(name, field.type_adapter) - return bindings - - @classmethod - def _headers_model(cls) -> type[BaseModel] | None: - if not hasattr(cls, "__headers_model__"): - if cls.__headers__: - cls.__headers_model__ = _create_headers_model( - f"{cls.__name__}Headers", cls.__headers__ - ) - else: - cls.__headers_model__ = None - return cls.__headers_model__ - - -def _generate_message_annotations( - address: str | None, - fields: dict[str, Any], -) -> Generator[tuple[str, type[Annotated[Any, Any]]], None, None]: - address_parameters = _get_address_parameters(address) - for name, field in fields.items(): - if get_origin(field) is Annotated: - yield name, field - elif name in address_parameters: - yield name, Annotated[field, Parameter()] # type: ignore[misc] - else: - yield name, Annotated[field, Payload()] # type: ignore[misc] - - -def _is_message(cls: type[Any]) -> bool: - try: - return issubclass(cls, Message) - except TypeError: - return False - - -def _is_union(type_annotation: type) -> bool: - origin = get_origin(type_annotation) - return origin is Union or origin is UnionType - - -class AsyncFast: - def __init__( - self, - title: str = "AsyncFast", - version: str = "0.1.0", - lifespan: Lifespan | None = None, - ) -> None: - self._channels: list[_Channel] = [] - self._title = title - self._version = version - self._lifespan_context = lifespan - self._lifespan: AbstractAsyncContextManager[None] | None = None - - @property - def title(self) -> str: - return self._title - - @property - def version(self) -> str: - return self._version - - def channel(self, address: str) -> Callable[[DecoratedCallable], DecoratedCallable]: - return partial(self._add_channel, address) - - def _add_channel( - self, address: str, function: DecoratedCallable - ) -> DecoratedCallable: - signature = inspect.signature(function) - - messages = [] - return_annotation = signature.return_annotation - if return_annotation is not Signature.empty and ( - get_origin(return_annotation) is AsyncGenerator - or get_origin(return_annotation) is Generator - ): - async_generator_type = get_args(return_annotation)[0] - if _is_union(async_generator_type): - messages = [ - type for type in get_args(async_generator_type) if _is_message(type) - ] - elif _is_message(async_generator_type): - messages = [get_args(return_annotation)[0]] - - annotations = list(_generate_annotations(address, signature)) - - headers = { - name: _Field(annotated) - for name, annotated in annotations - if get_origin(annotated) is Annotated - and isinstance(get_args(annotated)[1], Header) - } - - parameters = { - name: _Field(annotated) - for name, annotated in annotations - if get_origin(annotated) is Annotated - and isinstance(get_args(annotated)[1], Parameter) - } - - payloads = [ - (name, _Field(annotated)) - for name, annotated in annotations - if get_origin(annotated) is Annotated - and isinstance(get_args(annotated)[1], Payload) - ] - - bindings = { - name: _Field(annotated) - for name, annotated in annotations - if get_origin(annotated) is Annotated - and isinstance(get_args(annotated)[1], Binding) - } - - message_senders = [ - name - for name, annotated in annotations - if get_origin(annotated) is MessageSender - ] - for name, annotated in annotations: - if get_origin(annotated) is MessageSender: - (message_sender_type,) = get_args(annotated) - if _is_union(message_sender_type): - messages = [ - type - for type in get_args(message_sender_type) - if _is_message(type) - ] - elif _is_message(message_sender_type): - messages = [message_sender_type] - - if len(payloads) > 1: - raise InvalidChannelDefinitionError( - "Channel must have no more than 1 payload" - ) - - payload = payloads[0] if len(payloads) == 1 else None - - if len(message_senders) > 1: - raise InvalidChannelDefinitionError( - "Channel must have no more than 1 message sender" - ) - - message_sender = message_senders[0] if len(message_senders) == 1 else None - - address_pattern = _address_pattern(address) - - channel = _Channel( - address, - address_pattern, - function, - headers, - parameters, - payload, - messages, - bindings, - message_sender, - ) - - self._channels.append(channel) - return function - - async def __call__( - self, scope: Scope, receive: AMGIReceiveCallable, send: AMGISendCallable - ) -> None: - if scope["type"] == "lifespan": - while True: - message = await receive() - if message["type"] == "lifespan.startup": - if self._lifespan_context is not None: - self._lifespan = self._lifespan_context(self) - await self._lifespan.__aenter__() - lifespan_startup_complete_event: LifespanStartupCompleteEvent = { - "type": "lifespan.startup.complete" - } - await send(lifespan_startup_complete_event) - elif message["type"] == "lifespan.shutdown": - if self._lifespan is not None: - await self._lifespan.__aexit__(None, None, None) - lifespan_shutdown_complete_event: LifespanShutdownCompleteEvent = { - "type": "lifespan.shutdown.complete" - } - await send(lifespan_shutdown_complete_event) - return - elif scope["type"] == "message": - address = scope["address"] - for channel in self._channels: - parameters = channel.match(address) - if parameters is not None: - await channel(scope, receive, send, parameters) - break - - def asyncapi(self) -> dict[str, Any]: - schema_generator = GenerateJsonSchema( - ref_template="#/components/schemas/{model}" - ) - - field_mapping, definitions = schema_generator.generate_definitions( - inputs=list(self._generate_inputs()) - ) - return { - "asyncapi": "3.0.0", - "info": { - "title": self.title, - "version": self.version, - }, - "channels": dict(_generate_channels(self._channels)), - "operations": dict(_generate_operations(self._channels)), - "components": { - "messages": dict(_generate_messages(self._channels, field_mapping)), - **({"schemas": definitions} if definitions else {}), - }, - } - - def _generate_inputs( - self, - ) -> Generator[tuple[int, JsonSchemaMode, CoreSchema], None, None]: - for channel in self._channels: - for field in channel._bindings.values(): - yield hash(field), "validation", field.type_adapter.core_schema - - headers_model = channel.headers_model - if headers_model: - yield hash(headers_model), "validation", TypeAdapter( - headers_model - ).core_schema - payload = channel.payload - if payload: - _, field = payload - yield hash(field), "validation", field.type_adapter.core_schema - - for message in channel.messages: - if message.__payload__: - _, field = message.__payload__ - - yield hash(field), "serialization", field.type_adapter.core_schema - - for field in message.__bindings__.values(): - yield hash( - field.type - ), "serialization", field.type_adapter.core_schema - - message_headers_model = message._headers_model() - if message_headers_model: - yield hash(message_headers_model), "serialization", TypeAdapter( - message_headers_model - ).core_schema - - -def _generate_annotations( - address: str, - signature: Signature, -) -> Generator[tuple[str, type[Any]], None, None]: - - address_parameters = _get_address_parameters(address) - - for name, parameter in signature.parameters.items(): - annotation = parameter.annotation - if get_origin(annotation) is Annotated: - if parameter.default != parameter.empty: - args = get_args(annotation) - args[1].default = parameter.default - yield name, annotation - elif get_origin(annotation) is MessageSender: - yield name, annotation - elif name in address_parameters: - yield name, Annotated[annotation, Parameter()] # type: ignore[misc] - else: - yield name, Annotated[annotation, Payload()] # type: ignore[misc] - - -async def _handle_async_generator( - handler: Callable[..., AsyncGenerator[Any, None]], - arguments: dict[str, Any], - send: AMGISendCallable, -) -> None: - agen = handler(**arguments) - exception: Exception | None = None - while True: - try: - if exception is None: - send_message = await agen.__anext__() - else: - send_message = await agen.athrow(exception) - try: - await _send_message(send, send_message) - except Exception as e: - exception = e - else: - exception = None - except StopAsyncIteration: - break - - -def _throw_or_none(gen: Generator[Any, None, None], exception: Exception) -> Any: - try: - return gen.throw(exception) - except StopIteration: - return None - - -async def _handle_generator( - handler: Callable[..., Generator[Any, None, None]], - arguments: dict[str, Any], - send: AMGISendCallable, -) -> None: - gen = handler(**arguments) - exception: Exception | None = None - while True: - if exception is None: - send_message = await asyncio.to_thread(next, gen, None) - else: - send_message = await asyncio.to_thread(_throw_or_none, gen, exception) - if send_message is None: - break - try: - await _send_message(send, send_message) - except Exception as e: - exception = e - else: - exception = None - - -async def _receive_messages( - receive: AMGIReceiveCallable, -) -> AsyncGenerator[MessageReceiveEvent, None]: - more_messages = True - while more_messages: - message = await receive() - assert message["type"] == "message.receive" - yield message - more_messages = message.get("more_messages", False) - - -def _generate_field_definitions( - headers: Mapping[str, _Field], -) -> Iterator[tuple[str, Any]]: - for name, field in headers.items(): - type_, annotation = get_args(field.type) - alias = annotation.alias if annotation.alias else name.replace("_", "-") - yield alias, (type_, annotation) - - -def _create_headers_model( - headers_name: str, headers: Mapping[str, _Field] -) -> type[BaseModel]: - return create_model( - headers_name, __base__=BaseModel, **dict(_generate_field_definitions(headers)) - ) - - -class _Channel: - - def __init__( - self, - address: str, - address_pattern: Pattern[str], - handler: Callable[..., Awaitable[None]], - headers: Mapping[str, _Field], - parameters: Mapping[str, _Field], - payload: tuple[str, _Field] | None, - messages: Sequence[type[Message]], - bindings: Mapping[str, _Field], - message_sender: str | None, - ) -> None: - self._address = address - self._address_pattern = address_pattern - self._handler = handler - self._headers = headers - self._parameters = parameters - self._payload = payload - self._messages = messages - self._bindings = bindings - self._message_sender = message_sender - - @property - def address(self) -> str: - return self._address - - @property - def name(self) -> str: - return self._handler.__name__ - - @cached_property - def title(self) -> str: - return "".join(part.title() for part in self.name.split("_")) - - @property - def headers(self) -> Mapping[str, _Field]: - return self._headers - - @cached_property - def headers_model(self) -> type[BaseModel] | None: - if self._headers: - return _create_headers_model(f"{self.title}Headers", self._headers) - return None - - @property - def payload(self) -> tuple[str, _Field] | None: - return self._payload - - @property - def parameters(self) -> Mapping[str, _Field]: - return self._parameters - - @property - def messages(self) -> Sequence[type[Message]]: - return self._messages - - def match(self, address: str) -> dict[str, str] | None: - match = self._address_pattern.match(address) - if match: - return match.groupdict() - return None - - async def __call__( - self, - scope: MessageScope, - receive: AMGIReceiveCallable, - send: AMGISendCallable, - parameters: dict[str, str], - ) -> None: - ack_out_of_order = "message.ack.out_of_order" in scope.get("extensions", {}) - if ack_out_of_order: - await asyncio.gather( - *[ - self._handle_message(message, parameters, send) - async for message in _receive_messages(receive) - ] - ) - else: - async for message in _receive_messages(receive): - await self._handle_message(message, parameters, send) - - async def _handle_message( - self, - message: MessageReceiveEvent, - parameters: dict[str, str], - send: AMGISendCallable, - ) -> None: - try: - arguments = dict(self._generate_arguments(message, parameters, send)) - - if inspect.isasyncgenfunction(self._handler): - await _handle_async_generator(self._handler, arguments, send) - elif inspect.isgeneratorfunction(self._handler): - await _handle_generator(self._handler, arguments, send) - elif inspect.iscoroutinefunction(self._handler): - await self._handler(**arguments) - else: - await asyncio.to_thread(self._handler, **arguments) - - message_ack_event: MessageAckEvent = { - "type": "message.ack", - "id": message["id"], - } - await send(message_ack_event) - except Exception as e: - message_nack_event: MessageNackEvent = { - "type": "message.nack", - "id": message["id"], - "message": str(e), - } - await send(message_nack_event) - - def _generate_arguments( - self, - message_receive_event: MessageReceiveEvent, - parameters: dict[str, str], - send: AMGISendCallable, - ) -> Generator[tuple[str, Any], None, None]: - yield from self._generate_headers(message_receive_event) - yield from self._generate_payload(message_receive_event) - yield from self._generate_parameters(parameters) - yield from self._generate_bindings(message_receive_event) - if self._message_sender: - yield self._message_sender, MessageSender(send) - - def _generate_headers( - self, message_receive_event: MessageReceiveEvent - ) -> Generator[tuple[str, Any], None, None]: - if self.headers: - headers = _Headers(message_receive_event["headers"]) - for name, field in self.headers.items(): - annotated_args = get_args(field.type) - header_alias = annotated_args[1].alias - alias = header_alias if header_alias else name.replace("_", "-") - header = headers.get( - alias, annotated_args[1].get_default(call_default_factory=True) - ) - value = TypeAdapter(annotated_args[0]).validate_python( - header, from_attributes=True - ) - yield name, value - - def _generate_payload( - self, message_receive_event: MessageReceiveEvent - ) -> Generator[tuple[str, Any], None, None]: - if self.payload: - name, field = self.payload - payload: bytes | None = message_receive_event.get("payload") - if payload is None: - value = field.type_adapter.validate_python(None) - else: - value = field.type_adapter.validate_json(payload) - yield name, value - - def _generate_bindings( - self, message_receive_event: MessageReceiveEvent - ) -> Generator[tuple[str, Any], None, None]: - if self._bindings: - bindings = message_receive_event.get("bindings", {}) - for name, field in self._bindings.items(): - binding_type = get_args(field.type)[1] - assert isinstance(binding_type, Binding) - - yield name, field.type_adapter.validate_python( - bindings.get(binding_type.__protocol__, {}).get( - binding_type.__field_name__ - ) - ) - - def _generate_parameters( - self, parameters: dict[str, str] - ) -> Generator[tuple[str, Any], None, None]: - if self._parameters: - for name, field in self._parameters.items(): - yield name, field.type_adapter.validate_python(parameters[name]) - - -def _generate_messages( - channels: Iterable[_Channel], - field_mapping: dict[tuple[int, JsonSchemaMode], JsonSchemaValue], -) -> Generator[tuple[str, dict[str, Any]], None, None]: - for channel in channels: - message = {} - - headers_model = channel.headers_model - if headers_model: - message["headers"] = field_mapping[ - hash(channel.headers_model), "validation" - ] - - payload = channel.payload - if payload: - _, field = payload - message["payload"] = field_mapping[hash(field), "validation"] - - bindings: dict[str, dict[str, Any]] - if channel._bindings: - bindings = {} - for field in channel._bindings.values(): - binding_type = get_args(field.type)[1] - assert isinstance(binding_type, Binding) - - bindings.setdefault(binding_type.__protocol__, {})[ - binding_type.__field_name__ - ] = field_mapping[hash(field), "validation"] - message["bindings"] = bindings - - yield f"{channel.title}Message", message - - for channel_message in channel.messages: - message_message = {} - - if channel_message.__payload__: - _, field = channel_message.__payload__ - message_message["payload"] = field_mapping[hash(field), "serialization"] - - message_headers_model = channel_message._headers_model() - if message_headers_model: - message_message["headers"] = field_mapping[ - hash(message_headers_model), "serialization" - ] - - if channel_message.__bindings__: - bindings = {} - for field in channel_message.__bindings__.values(): - binding_type = get_args(field.type)[1] - assert isinstance(binding_type, Binding) - - bindings.setdefault(binding_type.__protocol__, {})[ - binding_type.__field_name__ - ] = field_mapping[hash(field), "serialization"] - message_message["bindings"] = bindings - - yield channel_message.__name__, message_message - - -def _generate_channels( - channels: Iterable[_Channel], -) -> Generator[tuple[str, dict[str, Any]], None, None]: - for channel in channels: - message_name = f"{channel.title}Message" - channel_definition = { - "address": channel.address, - "messages": { - message_name: {"$ref": f"#/components/messages/{message_name}"} - }, - } - - if channel.parameters: - channel_definition["parameters"] = {name: {} for name in channel.parameters} - - yield channel.title, channel_definition - - for message in channel.messages: - message_channel_definition = { - "address": message.__address__, - "messages": { - message.__name__: { - "$ref": f"#/components/messages/{message.__name__}" - } - }, - } - - if message.__parameters__: - message_channel_definition["parameters"] = { - name: {} for name in message.__parameters__ - } - - yield message.__name__, message_channel_definition - - -def _generate_operations( - channels: Iterable[_Channel], -) -> Generator[tuple[str, dict[str, Any]], None, None]: - for channel in channels: - yield f"receive{channel.title}", { - "action": "receive", - "channel": {"$ref": f"#/channels/{channel.title}"}, - } - - for message in channel.messages: - yield f"send{message.__name__}", { - "action": "send", - "channel": {"$ref": f"#/channels/{message.__name__}"}, - } - - -class Header(FieldInfo): - pass - - -class Payload(FieldInfo): - pass - - -class Parameter(FieldInfo): - pass - - -def _get_address_parameters(address: str | None) -> set[str]: - if address is None: - return set() - parameters = _PARAMETER_PATTERN.findall(address) - for parameter in parameters: - assert _FIELD_PATTERN.match(parameter), f"Parameter '{parameter}' is not valid" - - duplicates = {item for item, count in Counter(parameters).items() if count > 1} - assert len(duplicates) == 0, f"Address contains duplicate parameters: {duplicates}" - return set(parameters) - - -class _Headers(Mapping[str, str]): - - def __init__(self, raw_list: Iterable[tuple[bytes, bytes]]) -> None: - self.raw_list = list(raw_list) - - def __getitem__(self, key: str, /) -> str: - for header_key, header_value in self.raw_list: - if header_key.decode() == key: - return header_value.decode() - raise KeyError(key) - - def __len__(self) -> int: - return len(self.raw_list) - - def __iter__(self) -> Iterator[str]: - return iter(self.keys()) - - def keys(self) -> list[str]: # type: ignore[override] - return [key.decode() for key, _ in self.raw_list] - - -def _address_pattern(address: str) -> Pattern[str]: - index = 0 - address_regex = "^" - for match in _PARAMETER_PATTERN.finditer(address): - (name,) = match.groups() - address_regex += re.escape(address[index : match.start()]) - address_regex += f"(?P<{name}>.*)" - - index = match.end() - - address_regex += re.escape(address[index:]) + "$" - return re.compile(address_regex) +from asyncfast._asyncfast import AsyncFast +from asyncfast._channel import Header +from asyncfast._channel import InvalidChannelDefinitionError +from asyncfast._channel import MessageSender +from asyncfast._channel import Parameter +from asyncfast._channel import Payload +from asyncfast._message import Message + +__all__ = [ + "AsyncFast", + "Message", + "Header", + "InvalidChannelDefinitionError", + "MessageSender", + "Parameter", + "Payload", +] diff --git a/packages/asyncfast/src/asyncfast/_asyncapi.py b/packages/asyncfast/src/asyncfast/_asyncapi.py new file mode 100644 index 0000000..69de244 --- /dev/null +++ b/packages/asyncfast/src/asyncfast/_asyncapi.py @@ -0,0 +1,141 @@ +import inspect +from collections.abc import AsyncGenerator +from collections.abc import Generator +from collections.abc import Sequence +from dataclasses import dataclass +from functools import cached_property +from inspect import Signature +from types import UnionType +from typing import Any +from typing import get_args +from typing import get_origin +from typing import Union + +from asyncfast._channel import BindingResolver +from asyncfast._channel import Channel +from asyncfast._channel import HeaderResolver +from asyncfast._channel import MessageSenderResolver +from asyncfast._channel import PayloadResolver +from asyncfast._message import Message +from pydantic import BaseModel +from pydantic import create_model +from pydantic.fields import FieldInfo + + +@dataclass(frozen=True) +class ChannelDefinition: + channel: Channel + + @cached_property + def bindings(self) -> Sequence[BindingResolver[Any]]: + return [ + resolver + for resolver in self.channel.resolvers.values() + if isinstance(resolver, BindingResolver) + ] + + @cached_property + def parameters(self) -> set[str]: + return self.channel.parameters + + @cached_property + def name(self) -> str: + return self.channel.func.__name__ + + @cached_property + def title(self) -> str: + return "".join(part.title() for part in self.name.split("_")) + + @cached_property + def headers_model( + self, + ) -> type[BaseModel] | None: + headers = [ + resolver + for resolver in self.channel.resolvers.values() + if isinstance(resolver, HeaderResolver) + ] + if not headers: + return None + field_definitions: dict[str, Any] = { + resolver.name: ( + resolver.type, + FieldInfo( + default=... if resolver.required else resolver.default, + ), + ) + for resolver in headers + } + return create_model( + f"{self.title}Headers", __base__=BaseModel, **field_definitions + ) + + @cached_property + def payload(self) -> PayloadResolver[Any] | None: + payloads = [ + resolver + for resolver in self.channel.resolvers.values() + if isinstance(resolver, PayloadResolver) + ] + if payloads: + return payloads[0] + return None + + def generate_messages(self) -> Generator[type[Message], None, None]: + signature = inspect.signature(self.channel.func) + + return_annotation = signature.return_annotation + if return_annotation is not Signature.empty and ( + get_origin(return_annotation) is AsyncGenerator + or get_origin(return_annotation) is Generator + ): + generator_type = get_args(return_annotation)[0] + if _is_union(generator_type): + for type in get_args(generator_type): + if _is_message(type): + yield type + elif _is_message(generator_type): + yield generator_type + + for resolver in self.channel.resolvers.values(): + if isinstance(resolver, MessageSenderResolver): + message_sender_type = get_args(resolver.type)[0] + if _is_union(message_sender_type): + for type in get_args(message_sender_type): + if _is_message(type): + yield type + elif _is_message(message_sender_type): + yield message_sender_type + + @cached_property + def messages(self) -> Sequence[type[Message]]: + return tuple(self.generate_messages()) + + +@dataclass(frozen=True) +class MessageDefinition: + address: str | None + name: str + parameters: set[str] + + @property + def definition(self) -> dict[str, Any]: + definition = { + "address": self.address, + "messages": {self.name: {"$ref": f"#/components/messages/{self.name}"}}, + } + if self.parameters: + definition["parameters"] = {name: {} for name in self.parameters} + return definition + + +def _is_union(type_annotation: type) -> bool: + origin = get_origin(type_annotation) + return origin is Union or origin is UnionType + + +def _is_message(cls: type[Any]) -> bool: + try: + return issubclass(cls, Message) + except TypeError: # pragma: no cover + return False diff --git a/packages/asyncfast/src/asyncfast/_asyncfast.py b/packages/asyncfast/src/asyncfast/_asyncfast.py new file mode 100644 index 0000000..f4eeeb2 --- /dev/null +++ b/packages/asyncfast/src/asyncfast/_asyncfast.py @@ -0,0 +1,329 @@ +import asyncio +from collections.abc import AsyncGenerator +from collections.abc import Callable +from collections.abc import Generator +from collections.abc import Iterable +from collections.abc import Mapping +from contextlib import AbstractAsyncContextManager +from functools import partial +from re import Pattern +from typing import Any +from typing import get_args +from typing import TypeVar + +from amgi_types import AMGIReceiveCallable +from amgi_types import AMGISendCallable +from amgi_types import LifespanShutdownCompleteEvent +from amgi_types import LifespanStartupCompleteEvent +from amgi_types import MessageAckEvent +from amgi_types import MessageNackEvent +from amgi_types import MessageReceiveEvent +from amgi_types import MessageScope +from amgi_types import Scope +from asyncfast._asyncapi import ChannelDefinition +from asyncfast._asyncapi import MessageDefinition +from asyncfast._channel import Channel +from asyncfast._channel import channel as make_channel +from asyncfast._channel import MessageReceive +from asyncfast._utils import _address_pattern +from asyncfast._utils import _get_address_parameters +from asyncfast.bindings import Binding +from pydantic import TypeAdapter +from pydantic.json_schema import GenerateJsonSchema +from pydantic.json_schema import JsonSchemaMode +from pydantic.json_schema import JsonSchemaValue +from pydantic_core import CoreSchema + +DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any]) +M = TypeVar("M", bound=Mapping[str, Any]) +Lifespan = Callable[["AsyncFast"], AbstractAsyncContextManager[None]] + + +class AsyncFast: + def __init__( + self, + title: str = "AsyncFast", + version: str = "0.1.0", + lifespan: Lifespan | None = None, + ) -> None: + self._channels: list[_Channel] = [] + self._title = title + self._version = version + self._lifespan_context = lifespan + self._lifespan: AbstractAsyncContextManager[None] | None = None + + @property + def title(self) -> str: + return self._title + + @property + def version(self) -> str: + return self._version + + def channel(self, address: str) -> Callable[[DecoratedCallable], DecoratedCallable]: + return partial(self._add_channel, address) + + def _add_channel( + self, address: str, function: DecoratedCallable + ) -> DecoratedCallable: + address_pattern = _address_pattern(address) + + channel = _Channel( + address, + address_pattern, + make_channel(function, _get_address_parameters(address)), + ) + + self._channels.append(channel) + return function + + async def __call__( + self, scope: Scope, receive: AMGIReceiveCallable, send: AMGISendCallable + ) -> None: + if scope["type"] == "lifespan": + while True: + message = await receive() + if message["type"] == "lifespan.startup": + if self._lifespan_context is not None: + self._lifespan = self._lifespan_context(self) + await self._lifespan.__aenter__() + lifespan_startup_complete_event: LifespanStartupCompleteEvent = { + "type": "lifespan.startup.complete" + } + await send(lifespan_startup_complete_event) + elif message["type"] == "lifespan.shutdown": + if self._lifespan is not None: + await self._lifespan.__aexit__(None, None, None) + lifespan_shutdown_complete_event: LifespanShutdownCompleteEvent = { + "type": "lifespan.shutdown.complete" + } + await send(lifespan_shutdown_complete_event) + return + elif scope["type"] == "message": + address = scope["address"] + for channel in self._channels: + parameters = channel.match(address) + if parameters is not None: + await channel(scope, receive, send, parameters) + break + + def asyncapi(self) -> dict[str, Any]: + schema_generator = GenerateJsonSchema( + ref_template="#/components/schemas/{model}" + ) + + field_mapping, definitions = schema_generator.generate_definitions( + inputs=list(self._generate_inputs()) + ) + return { + "asyncapi": "3.0.0", + "info": { + "title": self.title, + "version": self.version, + }, + "channels": dict(_generate_channels(self._channels)), + "operations": dict(_generate_operations(self._channels)), + "components": { + "messages": dict(_generate_messages(self._channels, field_mapping)), + **({"schemas": definitions} if definitions else {}), + }, + } + + def _generate_inputs( + self, + ) -> Generator[tuple[int, JsonSchemaMode, CoreSchema], None, None]: + for channel in self._channels: + for binding_resolver in channel._channel_definition.bindings: + yield hash( + binding_resolver.type + ), "validation", binding_resolver.type_adapter.core_schema + + headers_model = channel._channel_definition.headers_model + if headers_model: + yield hash(headers_model), "validation", TypeAdapter( + headers_model + ).core_schema + payload = channel._channel_definition.payload + if payload: + yield hash(payload.type), "validation", payload.type_adapter.core_schema + + for message in channel._channel_definition.messages: + if message.__payload__: + _, field = message.__payload__ + + yield hash(field), "serialization", field.type_adapter.core_schema + + for field in message.__bindings__.values(): + yield hash( + field.type + ), "serialization", field.type_adapter.core_schema + + message_headers_model = message._headers_model() + if message_headers_model: + yield hash(message_headers_model), "serialization", TypeAdapter( + message_headers_model + ).core_schema + + +async def _receive_messages( + receive: AMGIReceiveCallable, +) -> AsyncGenerator[MessageReceiveEvent, None]: + more_messages = True + while more_messages: + message = await receive() + assert message["type"] == "message.receive" + yield message + more_messages = message.get("more_messages", False) + + +class _Channel: + def __init__( + self, + address: str, + address_pattern: Pattern[str], + channel_invoker: Channel, + ) -> None: + self._address = address + self._address_pattern = address_pattern + self._channel_invoker = channel_invoker + self._channel_definition = ChannelDefinition(channel_invoker) + + @property + def address(self) -> str: + return self._address + + def match(self, address: str) -> dict[str, str] | None: + match = self._address_pattern.match(address) + if match: + return match.groupdict() + return None + + async def __call__( + self, + scope: MessageScope, + receive: AMGIReceiveCallable, + send: AMGISendCallable, + parameters: dict[str, str], + ) -> None: + ack_out_of_order = "message.ack.out_of_order" in scope.get("extensions", {}) + if ack_out_of_order: + await asyncio.gather( + *[ + self._handle_message(message, parameters, send) + async for message in _receive_messages(receive) + ] + ) + else: + async for message in _receive_messages(receive): + await self._handle_message(message, parameters, send) + + async def _handle_message( + self, + message: MessageReceiveEvent, + parameters: dict[str, str], + send: AMGISendCallable, + ) -> None: + try: + + await self._channel_invoker.call(MessageReceive(message, parameters), send) + + message_ack_event: MessageAckEvent = { + "type": "message.ack", + "id": message["id"], + } + await send(message_ack_event) + except Exception as e: + message_nack_event: MessageNackEvent = { + "type": "message.nack", + "id": message["id"], + "message": str(e), + } + await send(message_nack_event) + + +def _generate_messages( + channels: Iterable[_Channel], + field_mapping: dict[tuple[int, JsonSchemaMode], JsonSchemaValue], +) -> Generator[tuple[str, dict[str, Any]], None, None]: + for channel in channels: + message = {} + + headers_model = channel._channel_definition.headers_model + if headers_model: + message["headers"] = field_mapping[hash(headers_model), "validation"] + + payload = channel._channel_definition.payload + if payload: + message["payload"] = field_mapping[hash(payload.type), "validation"] + + bindings: dict[str, dict[str, Any]] + if channel._channel_definition.bindings: + bindings = {} + for binding_resolver in channel._channel_definition.bindings: + + bindings.setdefault(binding_resolver.protocol, {})[ + binding_resolver.field_name + ] = field_mapping[hash(binding_resolver.type), "validation"] + message["bindings"] = bindings + + yield f"{channel._channel_definition.title}Message", message + + for channel_message in channel._channel_definition.messages: + message_message = {} + + if channel_message.__payload__: + _, field = channel_message.__payload__ + message_message["payload"] = field_mapping[hash(field), "serialization"] + + message_headers_model = channel_message._headers_model() + if message_headers_model: + message_message["headers"] = field_mapping[ + hash(message_headers_model), "serialization" + ] + + if channel_message.__bindings__: + bindings = {} + for field in channel_message.__bindings__.values(): + binding_type = get_args(field.type)[1] + assert isinstance(binding_type, Binding) + + bindings.setdefault(binding_type.__protocol__, {})[ + binding_type.__field_name__ + ] = field_mapping[hash(field), "serialization"] + message_message["bindings"] = bindings + + yield channel_message.__name__, message_message + + +def _generate_channels( + channels: Iterable[_Channel], +) -> Generator[tuple[str, dict[str, Any]], None, None]: + for channel in channels: + yield channel._channel_definition.title, MessageDefinition( + channel.address, + f"{channel._channel_definition.title}Message", + channel._channel_definition.parameters, + ).definition + + for message in channel._channel_definition.messages: + yield message.__name__, MessageDefinition( + message.__address__, + message.__name__, + {name for name in message.__parameters__}, + ).definition + + +def _generate_operations( + channels: Iterable[_Channel], +) -> Generator[tuple[str, dict[str, Any]], None, None]: + for channel in channels: + yield f"receive{channel._channel_definition.title}", { + "action": "receive", + "channel": {"$ref": f"#/channels/{channel._channel_definition.title}"}, + } + + for message in channel._channel_definition.messages: + yield f"send{message.__name__}", { + "action": "send", + "channel": {"$ref": f"#/channels/{message.__name__}"}, + } diff --git a/packages/asyncfast/src/asyncfast/_channel.py b/packages/asyncfast/src/asyncfast/_channel.py new file mode 100644 index 0000000..ed9aa27 --- /dev/null +++ b/packages/asyncfast/src/asyncfast/_channel.py @@ -0,0 +1,293 @@ +import asyncio +import inspect +import re +from abc import ABC +from abc import abstractmethod +from collections.abc import Callable +from collections.abc import Generator +from collections.abc import Mapping +from dataclasses import dataclass +from functools import cached_property +from typing import Annotated +from typing import Any +from typing import Generic +from typing import get_args +from typing import get_origin +from typing import TypeVar + +from amgi_types import AMGISendCallable +from amgi_types import MessageReceiveEvent +from amgi_types import MessageSendEvent +from asyncfast.bindings import Binding +from pydantic import TypeAdapter +from pydantic.fields import FieldInfo + +_FIELD_PATTERN = re.compile(r"^[A-Za-z0-9_\-]+$") +_PARAMETER_PATTERN = re.compile(r"{(.*)}") + + +T = TypeVar("T") +M = TypeVar("M", bound=Mapping[str, Any]) + + +class InvalidChannelDefinitionError(ValueError): + """ + Raised when a channel or message handler is defined with an invalid shape. + """ + + +class Header(FieldInfo): + pass + + +class Payload(FieldInfo): + pass + + +class Parameter(FieldInfo): + pass + + +@dataclass(frozen=True) +class MessageReceive: + message: MessageReceiveEvent + address_parameters: dict[str, str] + + @cached_property + def headers(self) -> dict[str, bytes]: + return {key.decode(): value for key, value in reversed(self.message["headers"])} + + +class Resolver(Generic[T], ABC): + @abstractmethod + def resolve(self, message_receive: MessageReceive, send: AMGISendCallable) -> T: ... + + +@dataclass(frozen=True) +class TypeResolve(Resolver[T], ABC): + type: type[T] + + @cached_property + def type_adapter(self) -> TypeAdapter[T]: + return TypeAdapter(self.type) + + +@dataclass(frozen=True) +class PayloadResolver(TypeResolve[T]): + def resolve(self, message_receive: MessageReceive, send: AMGISendCallable) -> T: + payload: bytes | None = message_receive.message.get("payload") + if payload is None: + return self.type_adapter.validate_python(None) + return self.type_adapter.validate_json(payload) + + +@dataclass(frozen=True) +class BindingResolver(TypeResolve[T]): + protocol: str + field_name: str + default: T + + def resolve(self, message_receive: MessageReceive, send: AMGISendCallable) -> T: + bindings = message_receive.message.get("bindings", {}) + + return self.type_adapter.validate_python( + bindings.get(self.protocol, {}).get(self.field_name, self.default) + ) + + +@dataclass(frozen=True) +class AddressParameterResolver(Resolver[str]): + name: str + + def resolve(self, message_receive: MessageReceive, send: AMGISendCallable) -> str: + return message_receive.address_parameters[self.name] + + +@dataclass(frozen=True) +class HeaderResolver(TypeResolve[T]): + name: str + default: T + required: bool + + def resolve(self, message_receive: MessageReceive, send: AMGISendCallable) -> T: + if not self.required: + value = message_receive.headers.get(self.name, self.default) + else: + value = message_receive.headers[self.name] + return self.type_adapter.validate_python(value) + + +async def send_message(send: AMGISendCallable, message: Mapping[str, Any]) -> None: + message_send_event: MessageSendEvent = { + "type": "message.send", + "address": message["address"], + "headers": message["headers"], + "payload": message.get("payload"), + } + await send(message_send_event) + + +class MessageSender(Generic[M]): + def __init__(self, send: AMGISendCallable) -> None: + self._send = send + + async def send(self, message: M) -> None: + await send_message(self._send, message) + + +@dataclass(frozen=True) +class MessageSenderResolver(Resolver[MessageSender[M]]): + type: type[MessageSender[M]] + + def resolve( + self, message_receive: MessageReceive, send: AMGISendCallable + ) -> MessageSender[M]: + return MessageSender(send) + + +@dataclass(frozen=True) +class Channel(ABC): + func: Callable[..., Any] + parameters: set[str] + resolvers: dict[str, Resolver[Any]] + + def resolve( + self, message_receive: MessageReceive, send: AMGISendCallable + ) -> dict[str, Any]: + return { + name: resolver.resolve(message_receive, send) + for name, resolver in self.resolvers.items() + } + + @abstractmethod + async def call( + self, message_receive: MessageReceive, send: AMGISendCallable + ) -> None: ... + + +@dataclass(frozen=True) +class SyncChannel(Channel): + async def call( + self, message_receive: MessageReceive, send: AMGISendCallable + ) -> None: + await asyncio.to_thread(self.func, **self.resolve(message_receive, send)) + + +@dataclass(frozen=True) +class AsyncChannel(Channel): + async def call( + self, message_receive: MessageReceive, send: AMGISendCallable + ) -> None: + await self.func(**self.resolve(message_receive, send)) + + +class AsyncGeneratorChannel(Channel): + async def call( + self, message_receive: MessageReceive, send: AMGISendCallable + ) -> None: + agen = self.func(**self.resolve(message_receive, send)) + exception: Exception | None = None + while True: + try: + if exception is None: + message = await agen.__anext__() + else: + message = await agen.athrow(exception) + try: + await send_message(send, message) + except Exception as e: + exception = e + else: + exception = None + except StopAsyncIteration: + break + + +def _throw_or_none(gen: Generator[Any, None, None], exception: Exception) -> Any: + try: + return gen.throw(exception) + except StopIteration: + return None + + +class SyncGeneratorChannel(Channel): + async def call( + self, message_receive: MessageReceive, send: AMGISendCallable + ) -> None: + + gen = self.func(**self.resolve(message_receive, send)) + exception: Exception | None = None + while True: + if exception is None: + message = await asyncio.to_thread(next, gen, None) + else: + message = await asyncio.to_thread(_throw_or_none, gen, exception) + if message is None: + break + try: + await send_message(send, message) + except Exception as e: + exception = e + else: + exception = None + + +def parameter_resolver( + name: str, parameter: inspect.Parameter, address_parameters: set[str] +) -> Resolver[Any]: + if name in address_parameters: + return AddressParameterResolver(name) + if get_origin(parameter.annotation) is Annotated: + _, annotation, *_ = get_args(parameter.annotation) + if isinstance(annotation, Header): + header_name = ( + annotation.alias if annotation.alias else name.replace("_", "-") + ) + + return HeaderResolver( + parameter.annotation, + header_name, + parameter.default, + parameter.default is parameter.empty, + ) + if isinstance(annotation, Binding): + return BindingResolver( + parameter.annotation, + annotation.__protocol__, + annotation.__field_name__, + parameter.default, + ) + if get_origin(parameter.annotation) is MessageSender: + return MessageSenderResolver(parameter.annotation) + + return PayloadResolver(parameter.annotation) + + +def channel(func: Callable[..., Any], address_parameters: set[str]) -> Channel: + signature = inspect.signature(func) + resolvers = { + name: parameter_resolver(name, parameter, address_parameters) + for name, parameter in signature.parameters.items() + } + + payloads = sum( + isinstance(resolver, PayloadResolver) for resolver in resolvers.values() + ) + if payloads > 1: + raise InvalidChannelDefinitionError("Channel must have no more than 1 payload") + + message_senders = sum( + isinstance(resolver, MessageSenderResolver) for resolver in resolvers.values() + ) + if message_senders > 1: + raise InvalidChannelDefinitionError( + "Channel must have no more than 1 message sender" + ) + + if inspect.iscoroutinefunction(func): + return AsyncChannel(func, address_parameters, resolvers) + if inspect.isasyncgenfunction(func): + return AsyncGeneratorChannel(func, address_parameters, resolvers) + if inspect.isgeneratorfunction(func): + return SyncGeneratorChannel(func, address_parameters, resolvers) + return SyncChannel(func, address_parameters, resolvers) diff --git a/packages/asyncfast/src/asyncfast/_message.py b/packages/asyncfast/src/asyncfast/_message.py new file mode 100644 index 0000000..2e4dc3f --- /dev/null +++ b/packages/asyncfast/src/asyncfast/_message.py @@ -0,0 +1,182 @@ +from collections.abc import Generator +from collections.abc import Iterable +from collections.abc import Iterator +from collections.abc import Mapping +from typing import Annotated +from typing import Any +from typing import ClassVar +from typing import get_args +from typing import get_origin + +from asyncfast._channel import Header +from asyncfast._channel import Parameter +from asyncfast._channel import Payload +from asyncfast._utils import _get_address_parameters +from asyncfast.bindings import Binding +from pydantic import BaseModel +from pydantic import create_model +from pydantic import TypeAdapter + + +class _Field: + def __init__(self, type_: type): + self.type = type_ + self.type_adapter = TypeAdapter[Any](type_) + + def __hash__(self) -> int: + return hash(self.type) + + +class Message(Mapping[str, Any]): + + __address__: ClassVar[str | None] = None + __headers__: ClassVar[dict[str, _Field]] + __headers_model__: ClassVar[type[BaseModel] | None] + __parameters__: ClassVar[dict[str, TypeAdapter[Any]]] + __payload__: ClassVar[tuple[str, _Field] | None] + __bindings__: ClassVar[dict[str, _Field]] + + def __init_subclass__(cls, address: str | None = None, **kwargs: Any) -> None: + cls.__address__ = address + annotations = list(_generate_message_annotations(address, cls.__annotations__)) + + headers = { + name: _Field(annotated) + for name, annotated in annotations + if isinstance(get_args(annotated)[1], Header) + } + + parameters = { + name: TypeAdapter(annotated) + for name, annotated in annotations + if isinstance(get_args(annotated)[1], Parameter) + } + + bindings = { + name: _Field(annotated) + for name, annotated in annotations + if isinstance(get_args(annotated)[1], Binding) + } + + payloads = [ + (name, _Field(annotated)) + for name, annotated in annotations + if isinstance(get_args(annotated)[1], Payload) + ] + + assert len(payloads) <= 1, "Channel must have no more than 1 payload" + + payload = payloads[0] if len(payloads) == 1 else None + + cls.__headers__ = headers + cls.__parameters__ = parameters + cls.__payload__ = payload + cls.__bindings__ = bindings + + def __getitem__(self, key: str, /) -> Any: + if key == "address": + return self._get_address() + elif key == "headers": + return self._get_headers() + elif key == "payload" and self.__payload__: + return self._get_payload() + elif key == "bindings" and self.__bindings__: + return self._get_bindings() + raise KeyError(key) + + def __len__(self) -> int: + payload = 1 if self.__payload__ else 0 + bindings = 1 if self.__bindings__ else 0 + return 2 + payload + bindings + + def __iter__(self) -> Iterator[str]: + yield from ("address", "headers") + if self.__payload__: + yield "payload" + if self.__bindings__: + yield "bindings" + + def _get_address(self) -> str | None: + if self.__address__ is None: + return None + parameters = { + name: type_adapter.dump_python(getattr(self, name)) + for name, type_adapter in self.__parameters__.items() + } + + return self.__address__.format(**parameters) + + def _generate_headers(self) -> Iterable[tuple[str, bytes]]: + for name, field in self.__headers__.items(): + _, annotation = get_args(field.type) + alias = annotation.alias if annotation.alias else name.replace("_", "-") + yield alias, self._get_value(name, field.type_adapter) + + def _get_headers(self) -> Iterable[tuple[bytes, bytes]]: + return [(name.encode(), value) for name, value in self._generate_headers()] + + def _get_value(self, name: str, type_adapter: TypeAdapter[Any]) -> bytes: + value = getattr(self, name) + python = type_adapter.dump_python(value, mode="json") + if isinstance(python, str): + return python.encode() + return type_adapter.dump_json(value) + + def _get_payload(self) -> bytes | None: + if self.__payload__ is None: + return None + name, field = self.__payload__ + return field.type_adapter.dump_json(getattr(self, name)) + + def _get_bindings(self) -> dict[str, dict[str, Any]]: + bindings: dict[str, dict[str, Any]] = {} + for name, field in self.__bindings__.items(): + binding_type = get_args(field.type)[1] + assert isinstance(binding_type, Binding) + + bindings.setdefault(binding_type.__protocol__, {})[ + binding_type.__field_name__ + ] = self._get_value(name, field.type_adapter) + return bindings + + @classmethod + def _headers_model(cls) -> type[BaseModel] | None: + if not hasattr(cls, "__headers_model__"): + if cls.__headers__: + cls.__headers_model__ = _create_headers_model( + f"{cls.__name__}Headers", cls.__headers__ + ) + else: + cls.__headers_model__ = None + return cls.__headers_model__ + + +def _generate_field_definitions( + headers: Mapping[str, _Field], +) -> Iterator[tuple[str, Any]]: + for name, field in headers.items(): + type_, annotation = get_args(field.type) + alias = annotation.alias if annotation.alias else name.replace("_", "-") + yield alias, (type_, annotation) + + +def _create_headers_model( + headers_name: str, headers: Mapping[str, _Field] +) -> type[BaseModel]: + return create_model( + headers_name, __base__=BaseModel, **dict(_generate_field_definitions(headers)) + ) + + +def _generate_message_annotations( + address: str | None, + fields: dict[str, Any], +) -> Generator[tuple[str, type[Annotated[Any, Any]]], None, None]: + address_parameters = _get_address_parameters(address) + for name, field in fields.items(): + if get_origin(field) is Annotated: + yield name, field + elif name in address_parameters: + yield name, Annotated[field, Parameter()] # type: ignore[misc] + else: + yield name, Annotated[field, Payload()] # type: ignore[misc] diff --git a/packages/asyncfast/src/asyncfast/_utils.py b/packages/asyncfast/src/asyncfast/_utils.py new file mode 100644 index 0000000..3de5110 --- /dev/null +++ b/packages/asyncfast/src/asyncfast/_utils.py @@ -0,0 +1,32 @@ +import re +from collections import Counter +from re import Pattern + +_FIELD_PATTERN = re.compile(r"^[A-Za-z0-9_\-]+$") +_PARAMETER_PATTERN = re.compile(r"{(.*)}") + + +def _get_address_parameters(address: str | None) -> set[str]: + if address is None: + return set() + parameters = _PARAMETER_PATTERN.findall(address) + for parameter in parameters: + assert _FIELD_PATTERN.match(parameter), f"Parameter '{parameter}' is not valid" + + duplicates = {item for item, count in Counter(parameters).items() if count > 1} + assert len(duplicates) == 0, f"Address contains duplicate parameters: {duplicates}" + return set(parameters) + + +def _address_pattern(address: str) -> Pattern[str]: + index = 0 + address_regex = "^" + for match in _PARAMETER_PATTERN.finditer(address): + (name,) = match.groups() + address_regex += re.escape(address[index : match.start()]) + address_regex += f"(?P<{name}>.*)" + + index = match.end() + + address_regex += re.escape(address[index:]) + "$" + return re.compile(address_regex) diff --git a/packages/asyncfast/tests_asyncfast/test_address_pattern.py b/packages/asyncfast/tests_asyncfast/test_address_pattern.py index 3fe14c5..7cfe543 100644 --- a/packages/asyncfast/tests_asyncfast/test_address_pattern.py +++ b/packages/asyncfast/tests_asyncfast/test_address_pattern.py @@ -1,7 +1,7 @@ import re import pytest -from asyncfast import _address_pattern +from asyncfast._utils import _address_pattern @pytest.mark.parametrize( diff --git a/packages/asyncfast/tests_asyncfast/test_channel.py b/packages/asyncfast/tests_asyncfast/test_channel.py new file mode 100644 index 0000000..f87b911 --- /dev/null +++ b/packages/asyncfast/tests_asyncfast/test_channel.py @@ -0,0 +1,306 @@ +from typing import Annotated +from typing import Any +from typing import AsyncGenerator +from typing import Generator +from typing import Mapping +from unittest.mock import AsyncMock +from unittest.mock import Mock + +from asyncfast._channel import channel +from asyncfast._channel import Header +from asyncfast._channel import MessageReceive +from asyncfast._channel import MessageSender +from asyncfast.bindings import KafkaKey + + +async def test_payload_basic() -> None: + mock = Mock() + + def func(i: int) -> None: + mock(i) + + await channel(func, set()).call( + MessageReceive( + { + "type": "message.receive", + "id": "id", + "headers": [], + "payload": b"1", + }, + {}, + ), + Mock(), + ) + + mock.assert_called_with(1) + + +async def test_header_basic() -> None: + mock = Mock() + + def func(header: Annotated[str, Header()]) -> None: + mock(header) + + await channel(func, set()).call( + MessageReceive( + { + "type": "message.receive", + "id": "id", + "headers": [(b"header", b"value")], + }, + {}, + ), + Mock(), + ) + + mock.assert_called_with("value") + + +async def test_header_default() -> None: + mock = Mock() + + def func(header: Annotated[str, Header()] = "value") -> None: + mock(header) + + await channel(func, set()).call( + MessageReceive( + { + "type": "message.receive", + "id": "id", + "headers": [], + }, + {}, + ), + Mock(), + ) + + mock.assert_called_with("value") + + +async def test_header_underscore_to_hyphen() -> None: + mock = Mock() + + def func(header_name: Annotated[str, Header()]) -> None: + mock(header_name) + + await channel(func, set()).call( + MessageReceive( + { + "type": "message.receive", + "id": "id", + "headers": [(b"header-name", b"value")], + }, + {}, + ), + Mock(), + ) + + mock.assert_called_with("value") + + +async def test_header_alias() -> None: + mock = Mock() + + def func(etag: Annotated[str, Header(alias="ETag")]) -> None: + mock(etag) + + await channel(func, set()).call( + MessageReceive( + { + "type": "message.receive", + "id": "id", + "headers": [(b"ETag", b"9e30981e-02d5-11f1-9648-e323315723e1")], + }, + {}, + ), + Mock(), + ) + + mock.assert_called_with("9e30981e-02d5-11f1-9648-e323315723e1") + + +async def test_address_parameter() -> None: + mock = Mock() + + def func(user: str) -> None: + mock(user) + + await channel(func, {"user"}).call( + MessageReceive( + { + "type": "message.receive", + "id": "id", + "headers": [], + }, + {"user": "54a08cc6-02db-11f1-afbf-f3f4688d5de4"}, + ), + Mock(), + ) + + mock.assert_called_with("54a08cc6-02db-11f1-afbf-f3f4688d5de4") + + +async def test_binding() -> None: + mock = Mock() + + def func(key: Annotated[int, KafkaKey()]) -> None: + mock(key) + + await channel(func, set()).call( + MessageReceive( + { + "type": "message.receive", + "id": "id", + "headers": [], + "bindings": {"kafka": {"key": b"123"}}, + }, + {}, + ), + Mock(), + ) + + mock.assert_called_with(123) + + +async def test_binding_default() -> None: + mock = Mock() + + def func(key: Annotated[int, KafkaKey()] = 123) -> None: + mock(key) + + await channel(func, set()).call( + MessageReceive( + { + "type": "message.receive", + "id": "id", + "headers": [], + }, + {}, + ), + Mock(), + ) + + mock.assert_called_with(123) + + +async def test_async_func() -> None: + mock = AsyncMock() + + async def func(i: int) -> None: + await mock(i) + + await channel(func, set()).call( + MessageReceive( + { + "type": "message.receive", + "id": "id", + "headers": [], + "payload": b"1", + }, + {}, + ), + Mock(), + ) + + mock.assert_awaited_once_with(1) + + +async def test_async_generator_func() -> None: + send_mock = AsyncMock() + + async def func() -> AsyncGenerator[Mapping[str, Any], None]: + yield { + "address": "send_topic", + "payload": b'{"key": "KEY-001"}', + "headers": [(b"Id", b"10")], + } + + await channel(func, set()).call( + MessageReceive( + { + "type": "message.receive", + "id": "id", + "headers": [], + "payload": b"1", + }, + {}, + ), + send_mock, + ) + + send_mock.assert_awaited_once_with( + { + "type": "message.send", + "address": "send_topic", + "headers": [(b"Id", b"10")], + "payload": b'{"key": "KEY-001"}', + } + ) + + +async def test_sync_generator_func() -> None: + send_mock = AsyncMock() + + def func() -> Generator[Mapping[str, Any], None, None]: + yield { + "address": "send_topic", + "payload": b'{"key": "KEY-001"}', + "headers": [(b"Id", b"10")], + } + + await channel(func, set()).call( + MessageReceive( + { + "type": "message.receive", + "id": "id", + "headers": [], + "payload": b"1", + }, + {}, + ), + send_mock, + ) + + send_mock.assert_awaited_once_with( + { + "type": "message.send", + "address": "send_topic", + "headers": [(b"Id", b"10")], + "payload": b'{"key": "KEY-001"}', + } + ) + + +async def test_message_sender() -> None: + send_mock = AsyncMock() + + async def func(message_sender: MessageSender[Mapping[str, Any]]) -> None: + await message_sender.send( + { + "address": "send_topic", + "payload": b'{"key": "KEY-001"}', + "headers": [(b"Id", b"10")], + } + ) + + await channel(func, set()).call( + MessageReceive( + { + "type": "message.receive", + "id": "id", + "headers": [], + "payload": b"1", + }, + {}, + ), + send_mock, + ) + + send_mock.assert_awaited_once_with( + { + "type": "message.send", + "address": "send_topic", + "headers": [(b"Id", b"10")], + "payload": b'{"key": "KEY-001"}', + } + ) diff --git a/packages/asyncfast/tests_asyncfast/test_message.py b/packages/asyncfast/tests_asyncfast/test_message.py index 4869924..9d8954b 100644 --- a/packages/asyncfast/tests_asyncfast/test_message.py +++ b/packages/asyncfast/tests_asyncfast/test_message.py @@ -720,9 +720,13 @@ async def topic_handler() -> AsyncGenerator[dict[str, Any], None]: async def test_message_sending_dict_post_error_sync() -> None: app = AsyncFast() + send_exception = Exception("test") + def send_mock(event: AMGISendEvent) -> None: if event["type"] == "message.send" and event["address"] == "error": - raise Exception("test") + raise send_exception + + after_exception_mock = Mock() @app.channel("topic") def topic_handler() -> Generator[dict[str, Any], None, None]: @@ -732,12 +736,8 @@ def topic_handler() -> Generator[dict[str, Any], None, None]: "payload": b"1", "headers": [], } - except Exception: - yield { - "address": "not_error", - "payload": b"1", - "headers": [], - } + except Exception as e: + after_exception_mock(send_exception) message_scope: MessageScope = { "type": "message", @@ -766,17 +766,10 @@ def topic_handler() -> Generator[dict[str, Any], None, None]: "payload": b"1", } ), - call( - { - "type": "message.send", - "address": "not_error", - "headers": [], - "payload": b"1", - } - ), call({"type": "message.ack", "id": "id-1"}), ] ) + after_exception_mock.assert_called_with(send_exception) async def test_message_invalid_payload_nack() -> None: From c36138519f969e39ea4ca9e55d4e3656bf91325d Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Sun, 8 Feb 2026 16:06:21 +0000 Subject: [PATCH 05/11] test(asyncfast): update message class tests so they are benchmarks --- .../tests_asyncfast/test_message_object.py | 50 +++++++++++++------ 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/packages/asyncfast/tests_asyncfast/test_message_object.py b/packages/asyncfast/tests_asyncfast/test_message_object.py index 9f736e9..0a4a357 100644 --- a/packages/asyncfast/tests_asyncfast/test_message_object.py +++ b/packages/asyncfast/tests_asyncfast/test_message_object.py @@ -1,14 +1,32 @@ from dataclasses import dataclass from typing import Annotated +from typing import Any +from typing import Callable +from typing import cast from uuid import UUID +import pytest from asyncfast import Header from asyncfast import Message from asyncfast.bindings import KafkaKey from pydantic import BaseModel +from pytest_benchmark.fixture import BenchmarkFixture +MessageBenchmark = Callable[[Message], dict[str, Any]] -def test_message_payload() -> None: + +@pytest.fixture +def message_benchmark(benchmark: BenchmarkFixture) -> MessageBenchmark: + def _message_benchmark( + message: Message, + ) -> dict[str, Any]: + result = benchmark(lambda: dict(message)) + return cast(dict[str, Any], result) + + return _message_benchmark + + +def test_message_payload(message_benchmark: MessageBenchmark) -> None: class Data(BaseModel): id: str @@ -18,99 +36,99 @@ class Response(Message, address="response_channel"): response = Response(data=Data(id="test")) - assert dict(response) == { + assert message_benchmark(response) == { "address": "response_channel", "headers": [], "payload": b'{"id":"test"}', } -def test_message_header() -> None: +def test_message_header(message_benchmark: MessageBenchmark) -> None: @dataclass class Response(Message, address="response_channel"): id: Annotated[int, Header()] response = Response(id=100) - assert dict(response) == { + assert message_benchmark(response) == { "address": "response_channel", "headers": [(b"id", b"100")], } -def test_message_header_alias() -> None: +def test_message_header_alias(message_benchmark: MessageBenchmark) -> None: @dataclass class Response(Message, address="response_channel"): id: Annotated[int, Header(alias="Id")] response = Response(id=100) - assert dict(response) == { + assert message_benchmark(response) == { "address": "response_channel", "headers": [(b"Id", b"100")], } -def test_message_header_underscore() -> None: +def test_message_header_underscore(message_benchmark: MessageBenchmark) -> None: @dataclass class Response(Message, address="response_channel"): request_id: Annotated[str, Header()] response = Response(request_id="46345148-fafe-11f0-af7a-975c1791ef56") - assert dict(response) == { + assert message_benchmark(response) == { "address": "response_channel", "headers": [(b"request-id", b"46345148-fafe-11f0-af7a-975c1791ef56")], } -def test_message_parameter() -> None: +def test_message_parameter(message_benchmark: MessageBenchmark) -> None: @dataclass class Response(Message, address="register.{user_id}"): user_id: str response = Response(user_id="ec5e9f87-c896-4fb1-b028-8352ef654e05") - assert dict(response) == { + assert message_benchmark(response) == { "address": "register.ec5e9f87-c896-4fb1-b028-8352ef654e05", "headers": [], } -def test_message_header_string() -> None: +def test_message_header_string(message_benchmark: MessageBenchmark) -> None: @dataclass class Response(Message, address="response_channel"): id: Annotated[UUID, Header()] response = Response(id=UUID("ec5e9f87-c896-4fb1-b028-8352ef654e05")) - assert dict(response) == { + assert message_benchmark(response) == { "address": "response_channel", "headers": [(b"id", b"ec5e9f87-c896-4fb1-b028-8352ef654e05")], } -def test_message_header_bytes() -> None: +def test_message_header_bytes(message_benchmark: MessageBenchmark) -> None: @dataclass class Response(Message, address="response_channel"): id: Annotated[bytes, Header()] response = Response(id=b"1234") - assert dict(response) == { + assert message_benchmark(response) == { "address": "response_channel", "headers": [(b"id", b"1234")], } -def test_message_binding_kafka_key() -> None: +def test_message_binding_kafka_key(message_benchmark: MessageBenchmark) -> None: @dataclass class Response(Message, address="response_channel"): key: Annotated[UUID, KafkaKey()] response = Response(key=UUID("ec5e9f87-c896-4fb1-b028-8352ef654e05")) - assert dict(response) == { + assert message_benchmark(response) == { "address": "response_channel", "bindings": {"kafka": {"key": b"ec5e9f87-c896-4fb1-b028-8352ef654e05"}}, "headers": [], From 6f4a080c10caa6ccb991bea4ca96ff055ab1b6cf Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Sun, 8 Feb 2026 16:08:01 +0000 Subject: [PATCH 06/11] perf(asyncfast): cache header and binding metadata to reduce lookups --- .../asyncfast/src/asyncfast/_asyncfast.py | 4 +- packages/asyncfast/src/asyncfast/_message.py | 41 +++++++++++-------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/packages/asyncfast/src/asyncfast/_asyncfast.py b/packages/asyncfast/src/asyncfast/_asyncfast.py index f4eeeb2..c45dad0 100644 --- a/packages/asyncfast/src/asyncfast/_asyncfast.py +++ b/packages/asyncfast/src/asyncfast/_asyncfast.py @@ -153,7 +153,7 @@ def _generate_inputs( yield hash(field), "serialization", field.type_adapter.core_schema - for field in message.__bindings__.values(): + for _, _, field in message.__bindings__.values(): yield hash( field.type ), "serialization", field.type_adapter.core_schema @@ -283,7 +283,7 @@ def _generate_messages( if channel_message.__bindings__: bindings = {} - for field in channel_message.__bindings__.values(): + for _, _, field in channel_message.__bindings__.values(): binding_type = get_args(field.type)[1] assert isinstance(binding_type, Binding) diff --git a/packages/asyncfast/src/asyncfast/_message.py b/packages/asyncfast/src/asyncfast/_message.py index 2e4dc3f..3760c30 100644 --- a/packages/asyncfast/src/asyncfast/_message.py +++ b/packages/asyncfast/src/asyncfast/_message.py @@ -30,18 +30,25 @@ def __hash__(self) -> int: class Message(Mapping[str, Any]): __address__: ClassVar[str | None] = None - __headers__: ClassVar[dict[str, _Field]] + __headers__: ClassVar[dict[str, tuple[str, _Field]]] __headers_model__: ClassVar[type[BaseModel] | None] __parameters__: ClassVar[dict[str, TypeAdapter[Any]]] __payload__: ClassVar[tuple[str, _Field] | None] - __bindings__: ClassVar[dict[str, _Field]] + __bindings__: ClassVar[dict[str, tuple[str, str, _Field]]] def __init_subclass__(cls, address: str | None = None, **kwargs: Any) -> None: cls.__address__ = address annotations = list(_generate_message_annotations(address, cls.__annotations__)) headers = { - name: _Field(annotated) + name: ( + ( + get_args(annotated)[1].alias + if get_args(annotated)[1].alias + else name.replace("_", "-") + ), + _Field(annotated), + ) for name, annotated in annotations if isinstance(get_args(annotated)[1], Header) } @@ -53,7 +60,11 @@ def __init_subclass__(cls, address: str | None = None, **kwargs: Any) -> None: } bindings = { - name: _Field(annotated) + name: ( + get_args(annotated)[1].__protocol__, + get_args(annotated)[1].__field_name__, + _Field(annotated), + ) for name, annotated in annotations if isinstance(get_args(annotated)[1], Binding) } @@ -107,9 +118,7 @@ def _get_address(self) -> str | None: return self.__address__.format(**parameters) def _generate_headers(self) -> Iterable[tuple[str, bytes]]: - for name, field in self.__headers__.items(): - _, annotation = get_args(field.type) - alias = annotation.alias if annotation.alias else name.replace("_", "-") + for name, (alias, field) in self.__headers__.items(): yield alias, self._get_value(name, field.type_adapter) def _get_headers(self) -> Iterable[tuple[bytes, bytes]]: @@ -130,13 +139,10 @@ def _get_payload(self) -> bytes | None: def _get_bindings(self) -> dict[str, dict[str, Any]]: bindings: dict[str, dict[str, Any]] = {} - for name, field in self.__bindings__.items(): - binding_type = get_args(field.type)[1] - assert isinstance(binding_type, Binding) - - bindings.setdefault(binding_type.__protocol__, {})[ - binding_type.__field_name__ - ] = self._get_value(name, field.type_adapter) + for name, (protocol, field_name, field) in self.__bindings__.items(): + bindings.setdefault(protocol, {})[field_name] = self._get_value( + name, field.type_adapter + ) return bindings @classmethod @@ -152,16 +158,15 @@ def _headers_model(cls) -> type[BaseModel] | None: def _generate_field_definitions( - headers: Mapping[str, _Field], + headers: Mapping[str, tuple[str, _Field]], ) -> Iterator[tuple[str, Any]]: - for name, field in headers.items(): + for name, (alias, field) in headers.items(): type_, annotation = get_args(field.type) - alias = annotation.alias if annotation.alias else name.replace("_", "-") yield alias, (type_, annotation) def _create_headers_model( - headers_name: str, headers: Mapping[str, _Field] + headers_name: str, headers: Mapping[str, tuple[str, _Field]] ) -> type[BaseModel]: return create_model( headers_name, __base__=BaseModel, **dict(_generate_field_definitions(headers)) From 2bb53fef0cc34eea5d24fa95993e79fd8f93de30 Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Sun, 8 Feb 2026 16:38:22 +0000 Subject: [PATCH 07/11] test(asyncfast): improve message class test coverage --- .../tests_asyncfast/test_message_object.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/packages/asyncfast/tests_asyncfast/test_message_object.py b/packages/asyncfast/tests_asyncfast/test_message_object.py index 0a4a357..c9270d7 100644 --- a/packages/asyncfast/tests_asyncfast/test_message_object.py +++ b/packages/asyncfast/tests_asyncfast/test_message_object.py @@ -133,3 +133,41 @@ class Response(Message, address="response_channel"): "bindings": {"kafka": {"key": b"ec5e9f87-c896-4fb1-b028-8352ef654e05"}}, "headers": [], } + + +def test_message_unsupported_key() -> None: + @dataclass + class Response(Message, address="response_channel"): + key: Annotated[UUID, KafkaKey()] + + response = Response(key=UUID("ec5e9f87-c896-4fb1-b028-8352ef654e05")) + + with pytest.raises(KeyError): + response["other_key"] + + +def test_message_no_address(message_benchmark: MessageBenchmark) -> None: + class Data(BaseModel): + id: str + + @dataclass + class Response(Message): + data: Data + + response = Response(data=Data(id="test")) + + assert message_benchmark(response) == { + "address": None, + "headers": [], + "payload": b'{"id":"test"}', + } + + +def test_message_len() -> None: + @dataclass + class Response(Message, address="response_channel"): + key: Annotated[UUID, KafkaKey()] + + response = Response(key=UUID("ec5e9f87-c896-4fb1-b028-8352ef654e05")) + + assert len(response) == 3 From eaf59c9f9c155bc1d71699b43d158980ae2b7342 Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Sun, 8 Feb 2026 16:48:58 +0000 Subject: [PATCH 08/11] fix(asyncfast): raise channel not found error for unknown channels --- packages/asyncfast/src/asyncfast/__init__.py | 2 ++ .../asyncfast/src/asyncfast/_asyncfast.py | 9 +++++- .../asyncfast/tests_asyncfast/test_message.py | 28 +++++++++++++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/packages/asyncfast/src/asyncfast/__init__.py b/packages/asyncfast/src/asyncfast/__init__.py index e02920e..24057d6 100644 --- a/packages/asyncfast/src/asyncfast/__init__.py +++ b/packages/asyncfast/src/asyncfast/__init__.py @@ -1,4 +1,5 @@ from asyncfast._asyncfast import AsyncFast +from asyncfast._asyncfast import ChannelNotFoundError from asyncfast._channel import Header from asyncfast._channel import InvalidChannelDefinitionError from asyncfast._channel import MessageSender @@ -8,6 +9,7 @@ __all__ = [ "AsyncFast", + "ChannelNotFoundError", "Message", "Header", "InvalidChannelDefinitionError", diff --git a/packages/asyncfast/src/asyncfast/_asyncfast.py b/packages/asyncfast/src/asyncfast/_asyncfast.py index c45dad0..cab27af 100644 --- a/packages/asyncfast/src/asyncfast/_asyncfast.py +++ b/packages/asyncfast/src/asyncfast/_asyncfast.py @@ -39,6 +39,12 @@ Lifespan = Callable[["AsyncFast"], AbstractAsyncContextManager[None]] +class ChannelNotFoundError(LookupError): + def __init__(self, address: str) -> None: + super().__init__(f"Couldn't resolve address: {address}") + self.address = address + + class AsyncFast: def __init__( self, @@ -105,7 +111,8 @@ async def __call__( parameters = channel.match(address) if parameters is not None: await channel(scope, receive, send, parameters) - break + return + raise ChannelNotFoundError(address) def asyncapi(self) -> dict[str, Any]: schema_generator = GenerateJsonSchema( diff --git a/packages/asyncfast/tests_asyncfast/test_message.py b/packages/asyncfast/tests_asyncfast/test_message.py index 9d8954b..cd47207 100644 --- a/packages/asyncfast/tests_asyncfast/test_message.py +++ b/packages/asyncfast/tests_asyncfast/test_message.py @@ -17,6 +17,7 @@ from amgi_types import MessageReceiveEvent from amgi_types import MessageScope from asyncfast import AsyncFast +from asyncfast import ChannelNotFoundError from asyncfast import Header from asyncfast import Message from asyncfast import MessageSender @@ -946,3 +947,30 @@ async def topic_handler(i: int) -> None: AsyncMock(side_effect=[message_receive_event1, message_receive_event2]), AsyncMock(), ) + + +async def test_message_non_existant_channel() -> None: + app = AsyncFast() + + @app.channel("topic") + async def topic_handler(id: int) -> None: + pass # pragma: no cover + + message_scope: MessageScope = { + "type": "message", + "amgi": {"version": "1.0", "spec_version": "1.0"}, + "address": "not_topic", + } + message_receive_event: MessageReceiveEvent = { + "type": "message.receive", + "id": "id-1", + "headers": [], + } + with pytest.raises( + ChannelNotFoundError, match="Couldn't resolve address: not_topic" + ): + await app( + message_scope, + AsyncMock(side_effect=[message_receive_event]), + AsyncMock(), + ) From 643c29f09c2748d0c2bdf2d9a346d8219f5ea615 Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Sun, 8 Feb 2026 16:51:25 +0000 Subject: [PATCH 09/11] refactor(asyncfast): remove unreachable when getting message payload --- packages/asyncfast/src/asyncfast/_message.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/packages/asyncfast/src/asyncfast/_message.py b/packages/asyncfast/src/asyncfast/_message.py index 3760c30..50dfa7d 100644 --- a/packages/asyncfast/src/asyncfast/_message.py +++ b/packages/asyncfast/src/asyncfast/_message.py @@ -90,7 +90,8 @@ def __getitem__(self, key: str, /) -> Any: elif key == "headers": return self._get_headers() elif key == "payload" and self.__payload__: - return self._get_payload() + name, field = self.__payload__ + return field.type_adapter.dump_json(getattr(self, name)) elif key == "bindings" and self.__bindings__: return self._get_bindings() raise KeyError(key) @@ -131,12 +132,6 @@ def _get_value(self, name: str, type_adapter: TypeAdapter[Any]) -> bytes: return python.encode() return type_adapter.dump_json(value) - def _get_payload(self) -> bytes | None: - if self.__payload__ is None: - return None - name, field = self.__payload__ - return field.type_adapter.dump_json(getattr(self, name)) - def _get_bindings(self) -> dict[str, dict[str, Any]]: bindings: dict[str, dict[str, Any]] = {} for name, (protocol, field_name, field) in self.__bindings__.items(): From c7b818a4644e3e2102323e1b02aba25804a97599 Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Sun, 8 Feb 2026 23:33:04 +0000 Subject: [PATCH 10/11] feat(asyncfast): add dependency injection for channel parameters --- packages/asyncfast/src/asyncfast/__init__.py | 2 + packages/asyncfast/src/asyncfast/_asyncapi.py | 20 +- .../asyncfast/src/asyncfast/_asyncfast.py | 4 +- packages/asyncfast/src/asyncfast/_channel.py | 383 +++++++++++++++--- .../tests_asyncfast/test_asyncapi.py | 54 +++ .../asyncfast/tests_asyncfast/test_channel.py | 224 +++++++++- 6 files changed, 613 insertions(+), 74 deletions(-) diff --git a/packages/asyncfast/src/asyncfast/__init__.py b/packages/asyncfast/src/asyncfast/__init__.py index 24057d6..c41a60e 100644 --- a/packages/asyncfast/src/asyncfast/__init__.py +++ b/packages/asyncfast/src/asyncfast/__init__.py @@ -1,5 +1,6 @@ from asyncfast._asyncfast import AsyncFast from asyncfast._asyncfast import ChannelNotFoundError +from asyncfast._channel import Depends from asyncfast._channel import Header from asyncfast._channel import InvalidChannelDefinitionError from asyncfast._channel import MessageSender @@ -11,6 +12,7 @@ "AsyncFast", "ChannelNotFoundError", "Message", + "Depends", "Header", "InvalidChannelDefinitionError", "MessageSender", diff --git a/packages/asyncfast/src/asyncfast/_asyncapi.py b/packages/asyncfast/src/asyncfast/_asyncapi.py index 69de244..1d9ba4c 100644 --- a/packages/asyncfast/src/asyncfast/_asyncapi.py +++ b/packages/asyncfast/src/asyncfast/_asyncapi.py @@ -12,16 +12,26 @@ from typing import Union from asyncfast._channel import BindingResolver +from asyncfast._channel import CallableResolver from asyncfast._channel import Channel from asyncfast._channel import HeaderResolver from asyncfast._channel import MessageSenderResolver from asyncfast._channel import PayloadResolver +from asyncfast._channel import Resolver from asyncfast._message import Message from pydantic import BaseModel from pydantic import create_model from pydantic.fields import FieldInfo +def generate_resolvers( + callable_resolver: CallableResolver, +) -> Generator[Resolver[Any], None, None]: + yield from callable_resolver.resolvers.values() + for dependency in callable_resolver.dependencies.values(): + yield from generate_resolvers(dependency) + + @dataclass(frozen=True) class ChannelDefinition: channel: Channel @@ -46,13 +56,17 @@ def name(self) -> str: def title(self) -> str: return "".join(part.title() for part in self.name.split("_")) + @cached_property + def resolvers(self) -> Sequence[Resolver[Any]]: + return tuple(generate_resolvers(self.channel)) + @cached_property def headers_model( self, ) -> type[BaseModel] | None: headers = [ resolver - for resolver in self.channel.resolvers.values() + for resolver in self.resolvers if isinstance(resolver, HeaderResolver) ] if not headers: @@ -74,7 +88,7 @@ def headers_model( def payload(self) -> PayloadResolver[Any] | None: payloads = [ resolver - for resolver in self.channel.resolvers.values() + for resolver in self.resolvers if isinstance(resolver, PayloadResolver) ] if payloads: @@ -97,7 +111,7 @@ def generate_messages(self) -> Generator[type[Message], None, None]: elif _is_message(generator_type): yield generator_type - for resolver in self.channel.resolvers.values(): + for resolver in self.resolvers: if isinstance(resolver, MessageSenderResolver): message_sender_type = get_args(resolver.type)[0] if _is_union(message_sender_type): diff --git a/packages/asyncfast/src/asyncfast/_asyncfast.py b/packages/asyncfast/src/asyncfast/_asyncfast.py index cab27af..92dac75 100644 --- a/packages/asyncfast/src/asyncfast/_asyncfast.py +++ b/packages/asyncfast/src/asyncfast/_asyncfast.py @@ -232,7 +232,9 @@ async def _handle_message( ) -> None: try: - await self._channel_invoker.call(MessageReceive(message, parameters), send) + await self._channel_invoker.invoke( + MessageReceive(message, parameters), send + ) message_ack_event: MessageAckEvent = { "type": "message.ack", diff --git a/packages/asyncfast/src/asyncfast/_channel.py b/packages/asyncfast/src/asyncfast/_channel.py index ed9aa27..cd12726 100644 --- a/packages/asyncfast/src/asyncfast/_channel.py +++ b/packages/asyncfast/src/asyncfast/_channel.py @@ -1,18 +1,29 @@ +from __future__ import annotations + import asyncio import inspect import re from abc import ABC from abc import abstractmethod +from asyncio import AbstractEventLoop +from asyncio import Task +from collections.abc import AsyncGenerator from collections.abc import Callable from collections.abc import Generator from collections.abc import Mapping +from contextlib import AbstractAsyncContextManager +from contextlib import asynccontextmanager +from contextlib import AsyncExitStack from dataclasses import dataclass +from dataclasses import KW_ONLY from functools import cached_property +from functools import wraps from typing import Annotated from typing import Any from typing import Generic from typing import get_args from typing import get_origin +from typing import ParamSpec from typing import TypeVar from amgi_types import AMGISendCallable @@ -25,11 +36,56 @@ _FIELD_PATTERN = re.compile(r"^[A-Za-z0-9_\-]+$") _PARAMETER_PATTERN = re.compile(r"{(.*)}") - +P = ParamSpec("P") T = TypeVar("T") M = TypeVar("M", bound=Mapping[str, Any]) +def _next_or_stop(generator: Generator[T, None, Any]) -> T | StopIteration: + try: + return next(generator) + except StopIteration as e: + return e + + +def _throw_or_stop( + generator: Generator[T, None, Any], exc: BaseException +) -> T | StopIteration: + try: + return generator.throw(exc) + except StopIteration as e: + return e + + +def asyncify_generator( + func: Callable[P, Generator[T, None, Any]], +) -> Callable[P, AsyncGenerator[T]]: + @wraps(func) + async def wrapped_generator(*args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[T]: + generator = func(*args, **kwargs) + try: + result: T | StopIteration = await asyncio.to_thread( + _next_or_stop, generator + ) + + while True: + if isinstance(result, StopIteration): + return + + try: + yield result + except BaseException as exc: + if isinstance(exc, (SystemExit, KeyboardInterrupt)): + raise + result = await asyncio.to_thread(_throw_or_stop, generator, exc) + else: + result = await asyncio.to_thread(_next_or_stop, generator) + finally: + generator.close() + + return wrapped_generator + + class InvalidChannelDefinitionError(ValueError): """ Raised when a channel or message handler is defined with an invalid shape. @@ -48,6 +104,13 @@ class Parameter(FieldInfo): pass +@dataclass(frozen=True) +class Depends: + func: Callable[..., Any] + _: KW_ONLY + use_cache: bool = True + + @dataclass(frozen=True) class MessageReceive: message: MessageReceiveEvent @@ -146,95 +209,267 @@ def resolve( @dataclass(frozen=True) -class Channel(ABC): +class CallableResolver(ABC): func: Callable[..., Any] - parameters: set[str] resolvers: dict[str, Resolver[Any]] - - def resolve( - self, message_receive: MessageReceive, send: AMGISendCallable + dependencies: dict[str, DependencyResolver] + + async def resolve( + self, + message_receive: MessageReceive, + send: AMGISendCallable, + dependency_cache: DependencyCache, + async_exit_stack: AsyncExitStack, ) -> dict[str, Any]: - return { + resolver_result = { name: resolver.resolve(message_receive, send) for name, resolver in self.resolvers.items() } + dependency_result = dict( + zip( + self.dependencies.keys(), + await asyncio.gather( + *( + dependency_cache.resolve( + dependency, message_receive, send, async_exit_stack + ) + for dependency in self.dependencies.values() + ) + ), + ) + ) + + return {**resolver_result, **dependency_result} + @abstractmethod async def call( - self, message_receive: MessageReceive, send: AMGISendCallable - ) -> None: ... + self, + message_receive: MessageReceive, + send: AMGISendCallable, + dependency_cache: DependencyCache, + async_exit_stack: AsyncExitStack, + ) -> Any: ... @dataclass(frozen=True) -class SyncChannel(Channel): +class DependencyResolver(CallableResolver, ABC): + use_cache: bool + + +@dataclass(frozen=True) +class SyncDependencyResolver(DependencyResolver): + async def call( + self, + message_receive: MessageReceive, + send: AMGISendCallable, + dependency_cache: DependencyCache, + async_exit_stack: AsyncExitStack, + ) -> Any: + return await asyncio.to_thread( + self.func, + **await self.resolve( + message_receive, send, dependency_cache, async_exit_stack + ), + ) + + +@dataclass(frozen=True) +class AsyncDependencyResolver(DependencyResolver): + async def call( + self, + message_receive: MessageReceive, + send: AMGISendCallable, + dependency_cache: DependencyCache, + async_exit_stack: AsyncExitStack, + ) -> Any: + return await self.func( + **await self.resolve( + message_receive, send, dependency_cache, async_exit_stack + ) + ) + + +class AsyncYieldingDependencyResolver(DependencyResolver): + @cached_property + def async_context_manager(self) -> Callable[..., AbstractAsyncContextManager[Any]]: + return asynccontextmanager(self.func) + async def call( + self, + message_receive: MessageReceive, + send: AMGISendCallable, + dependency_cache: DependencyCache, + async_exit_stack: AsyncExitStack, + ) -> Any: + return await async_exit_stack.enter_async_context( + self.async_context_manager( + **await self.resolve( + message_receive, send, dependency_cache, async_exit_stack + ) + ) + ) + + +class SyncYieldingDependencyResolver(DependencyResolver): + @cached_property + def async_generator_func(self) -> Callable[..., AsyncGenerator[Any]]: + return asyncify_generator(self.func) + + @cached_property + def async_context_manager(self) -> Callable[..., AbstractAsyncContextManager[Any]]: + return asynccontextmanager(self.async_generator_func) + + async def call( + self, + message_receive: MessageReceive, + send: AMGISendCallable, + dependency_cache: DependencyCache, + async_exit_stack: AsyncExitStack, + ) -> Any: + return await async_exit_stack.enter_async_context( + self.async_context_manager( + **await self.resolve( + message_receive, send, dependency_cache, async_exit_stack + ) + ) + ) + + +class DependencyCache: + def __init__(self, loop: AbstractEventLoop) -> None: + self.loop = loop + self.cache: dict[Callable[..., Any], Task[Any]] = {} + + def resolve( + self, + dependency: DependencyResolver, + message_receive: MessageReceive, + send: AMGISendCallable, + async_exit_stack: AsyncExitStack, + ) -> Task[Any]: + if dependency.use_cache: + if dependency.func not in self.cache: + self.cache[dependency.func] = self.loop.create_task( + dependency.call(message_receive, send, self, async_exit_stack) + ) + return self.cache[dependency.func] + return self.loop.create_task( + dependency.call(message_receive, send, self, async_exit_stack) + ) + + +@dataclass(frozen=True) +class Channel(CallableResolver, ABC): + parameters: set[str] + + async def invoke( self, message_receive: MessageReceive, send: AMGISendCallable ) -> None: - await asyncio.to_thread(self.func, **self.resolve(message_receive, send)) + dependency_cache = DependencyCache(asyncio.get_event_loop()) + async with AsyncExitStack() as async_exit_stack: + await self.call(message_receive, send, dependency_cache, async_exit_stack) @dataclass(frozen=True) -class AsyncChannel(Channel): +class SyncChannel(Channel): async def call( - self, message_receive: MessageReceive, send: AMGISendCallable + self, + message_receive: MessageReceive, + send: AMGISendCallable, + dependency_cache: DependencyCache, + async_exit_stack: AsyncExitStack, ) -> None: - await self.func(**self.resolve(message_receive, send)) + await asyncio.to_thread( + self.func, + **await self.resolve( + message_receive, send, dependency_cache, async_exit_stack + ), + ) -class AsyncGeneratorChannel(Channel): +@dataclass(frozen=True) +class AsyncChannel(Channel): async def call( - self, message_receive: MessageReceive, send: AMGISendCallable + self, + message_receive: MessageReceive, + send: AMGISendCallable, + dependency_cache: DependencyCache, + async_exit_stack: AsyncExitStack, ) -> None: - agen = self.func(**self.resolve(message_receive, send)) - exception: Exception | None = None + await self.func( + **await self.resolve( + message_receive, send, dependency_cache, async_exit_stack + ) + ) + + +async def handle_async_generator( + agen: AsyncGenerator[Any], send: AMGISendCallable +) -> None: + try: while True: try: - if exception is None: - message = await agen.__anext__() - else: - message = await agen.athrow(exception) + message = await agen.__anext__() + except StopAsyncIteration: + return + + while True: try: await send_message(send, message) - except Exception as e: - exception = e + except Exception as exc: + try: + message = await agen.athrow(exc) + except StopAsyncIteration: + return else: - exception = None - except StopAsyncIteration: - break + break + finally: + await agen.aclose() -def _throw_or_none(gen: Generator[Any, None, None], exception: Exception) -> Any: - try: - return gen.throw(exception) - except StopIteration: - return None +class AsyncGeneratorChannel(Channel): + async def call( + self, + message_receive: MessageReceive, + send: AMGISendCallable, + dependency_cache: DependencyCache, + async_exit_stack: AsyncExitStack, + ) -> None: + agen = self.func( + **await self.resolve( + message_receive, send, dependency_cache, async_exit_stack + ) + ) + + await handle_async_generator(agen, send) class SyncGeneratorChannel(Channel): + @cached_property + def async_generator_func(self) -> Callable[..., AsyncGenerator[Any]]: + return asyncify_generator(self.func) + async def call( - self, message_receive: MessageReceive, send: AMGISendCallable + self, + message_receive: MessageReceive, + send: AMGISendCallable, + dependency_cache: DependencyCache, + async_exit_stack: AsyncExitStack, ) -> None: - gen = self.func(**self.resolve(message_receive, send)) - exception: Exception | None = None - while True: - if exception is None: - message = await asyncio.to_thread(next, gen, None) - else: - message = await asyncio.to_thread(_throw_or_none, gen, exception) - if message is None: - break - try: - await send_message(send, message) - except Exception as e: - exception = e - else: - exception = None + agen = self.async_generator_func( + **await self.resolve( + message_receive, send, dependency_cache, async_exit_stack + ) + ) + + await handle_async_generator(agen, send) def parameter_resolver( name: str, parameter: inspect.Parameter, address_parameters: set[str] -) -> Resolver[Any]: +) -> Resolver[Any] | DependencyResolver: if name in address_parameters: return AddressParameterResolver(name) if get_origin(parameter.annotation) is Annotated: @@ -257,18 +492,50 @@ def parameter_resolver( annotation.__field_name__, parameter.default, ) + if isinstance(annotation, Depends): + resolvers, dependencies = resolvers_dependencies( + annotation.func, address_parameters + ) + + if inspect.iscoroutinefunction(annotation.func): + return AsyncDependencyResolver( + annotation.func, resolvers, dependencies, annotation.use_cache + ) + if inspect.isasyncgenfunction(annotation.func): + return AsyncYieldingDependencyResolver( + annotation.func, resolvers, dependencies, annotation.use_cache + ) + if inspect.isgeneratorfunction(annotation.func): + return SyncYieldingDependencyResolver( + annotation.func, resolvers, dependencies, annotation.use_cache + ) + return SyncDependencyResolver( + annotation.func, resolvers, dependencies, annotation.use_cache + ) + if get_origin(parameter.annotation) is MessageSender: return MessageSenderResolver(parameter.annotation) return PayloadResolver(parameter.annotation) -def channel(func: Callable[..., Any], address_parameters: set[str]) -> Channel: +def resolvers_dependencies( + func: Callable[..., Any], address_parameters: set[str] +) -> tuple[dict[str, Resolver[Any]], dict[str, DependencyResolver]]: signature = inspect.signature(func) - resolvers = { - name: parameter_resolver(name, parameter, address_parameters) - for name, parameter in signature.parameters.items() - } + resolvers = {} + dependencies = {} + for name, parameter in signature.parameters.items(): + resolver = parameter_resolver(name, parameter, address_parameters) + if isinstance(resolver, Resolver): + resolvers[name] = resolver + else: + dependencies[name] = resolver + return resolvers, dependencies + + +def channel(func: Callable[..., Any], address_parameters: set[str]) -> Channel: + resolvers, dependencies = resolvers_dependencies(func, address_parameters) payloads = sum( isinstance(resolver, PayloadResolver) for resolver in resolvers.values() @@ -285,9 +552,9 @@ def channel(func: Callable[..., Any], address_parameters: set[str]) -> Channel: ) if inspect.iscoroutinefunction(func): - return AsyncChannel(func, address_parameters, resolvers) + return AsyncChannel(func, resolvers, dependencies, address_parameters) if inspect.isasyncgenfunction(func): - return AsyncGeneratorChannel(func, address_parameters, resolvers) + return AsyncGeneratorChannel(func, resolvers, dependencies, address_parameters) if inspect.isgeneratorfunction(func): - return SyncGeneratorChannel(func, address_parameters, resolvers) - return SyncChannel(func, address_parameters, resolvers) + return SyncGeneratorChannel(func, resolvers, dependencies, address_parameters) + return SyncChannel(func, resolvers, dependencies, address_parameters) diff --git a/packages/asyncfast/tests_asyncfast/test_asyncapi.py b/packages/asyncfast/tests_asyncfast/test_asyncapi.py index bdf459b..57a50e7 100644 --- a/packages/asyncfast/tests_asyncfast/test_asyncapi.py +++ b/packages/asyncfast/tests_asyncfast/test_asyncapi.py @@ -6,6 +6,7 @@ from uuid import UUID from asyncfast import AsyncFast +from asyncfast import Depends from asyncfast import Header from asyncfast import Message from asyncfast import MessageSender @@ -1134,3 +1135,56 @@ async def receive_handler(id: int) -> AsyncGenerator[Send, None]: }, }, } + + +def test_dependencies() -> None: + app = AsyncFast() + + def dependency( + header1: Annotated[int, Header()], header2: Annotated[int, Header()] + ) -> dict[str, int]: # pragma: no cover + return { + "header1": header1, + "header2": header2, + } + + @app.channel("hello") + async def on_hello(headers: Annotated[dict[str, int], Depends(dependency)]) -> None: + pass # pragma: no cover + + assert app.asyncapi() == { + "asyncapi": "3.0.0", + "channels": { + "OnHello": { + "address": "hello", + "messages": { + "OnHelloMessage": {"$ref": "#/components/messages/OnHelloMessage"} + }, + } + }, + "components": { + "messages": { + "OnHelloMessage": { + "headers": {"$ref": "#/components/schemas/OnHelloHeaders"} + } + }, + "schemas": { + "OnHelloHeaders": { + "properties": { + "header1": {"title": "Header1", "type": "integer"}, + "header2": {"title": "Header2", "type": "integer"}, + }, + "required": ["header1", "header2"], + "title": "OnHelloHeaders", + "type": "object", + } + }, + }, + "info": {"title": "AsyncFast", "version": "0.1.0"}, + "operations": { + "receiveOnHello": { + "action": "receive", + "channel": {"$ref": "#/channels/OnHello"}, + } + }, + } diff --git a/packages/asyncfast/tests_asyncfast/test_channel.py b/packages/asyncfast/tests_asyncfast/test_channel.py index f87b911..68b629c 100644 --- a/packages/asyncfast/tests_asyncfast/test_channel.py +++ b/packages/asyncfast/tests_asyncfast/test_channel.py @@ -4,9 +4,11 @@ from typing import Generator from typing import Mapping from unittest.mock import AsyncMock +from unittest.mock import call from unittest.mock import Mock from asyncfast._channel import channel +from asyncfast._channel import Depends from asyncfast._channel import Header from asyncfast._channel import MessageReceive from asyncfast._channel import MessageSender @@ -19,7 +21,7 @@ async def test_payload_basic() -> None: def func(i: int) -> None: mock(i) - await channel(func, set()).call( + await channel(func, set()).invoke( MessageReceive( { "type": "message.receive", @@ -41,7 +43,7 @@ async def test_header_basic() -> None: def func(header: Annotated[str, Header()]) -> None: mock(header) - await channel(func, set()).call( + await channel(func, set()).invoke( MessageReceive( { "type": "message.receive", @@ -62,7 +64,7 @@ async def test_header_default() -> None: def func(header: Annotated[str, Header()] = "value") -> None: mock(header) - await channel(func, set()).call( + await channel(func, set()).invoke( MessageReceive( { "type": "message.receive", @@ -83,7 +85,7 @@ async def test_header_underscore_to_hyphen() -> None: def func(header_name: Annotated[str, Header()]) -> None: mock(header_name) - await channel(func, set()).call( + await channel(func, set()).invoke( MessageReceive( { "type": "message.receive", @@ -104,7 +106,7 @@ async def test_header_alias() -> None: def func(etag: Annotated[str, Header(alias="ETag")]) -> None: mock(etag) - await channel(func, set()).call( + await channel(func, set()).invoke( MessageReceive( { "type": "message.receive", @@ -125,7 +127,7 @@ async def test_address_parameter() -> None: def func(user: str) -> None: mock(user) - await channel(func, {"user"}).call( + await channel(func, {"user"}).invoke( MessageReceive( { "type": "message.receive", @@ -146,7 +148,7 @@ async def test_binding() -> None: def func(key: Annotated[int, KafkaKey()]) -> None: mock(key) - await channel(func, set()).call( + await channel(func, set()).invoke( MessageReceive( { "type": "message.receive", @@ -168,7 +170,7 @@ async def test_binding_default() -> None: def func(key: Annotated[int, KafkaKey()] = 123) -> None: mock(key) - await channel(func, set()).call( + await channel(func, set()).invoke( MessageReceive( { "type": "message.receive", @@ -189,7 +191,7 @@ async def test_async_func() -> None: async def func(i: int) -> None: await mock(i) - await channel(func, set()).call( + await channel(func, set()).invoke( MessageReceive( { "type": "message.receive", @@ -215,7 +217,7 @@ async def func() -> AsyncGenerator[Mapping[str, Any], None]: "headers": [(b"Id", b"10")], } - await channel(func, set()).call( + await channel(func, set()).invoke( MessageReceive( { "type": "message.receive", @@ -248,7 +250,7 @@ def func() -> Generator[Mapping[str, Any], None, None]: "headers": [(b"Id", b"10")], } - await channel(func, set()).call( + await channel(func, set()).invoke( MessageReceive( { "type": "message.receive", @@ -283,7 +285,7 @@ async def func(message_sender: MessageSender[Mapping[str, Any]]) -> None: } ) - await channel(func, set()).call( + await channel(func, set()).invoke( MessageReceive( { "type": "message.receive", @@ -304,3 +306,201 @@ async def func(message_sender: MessageSender[Mapping[str, Any]]) -> None: "payload": b'{"key": "KEY-001"}', } ) + + +async def test_async_depends() -> None: + mock = Mock() + + async def dependency( + header1: Annotated[int, Header()], header2: Annotated[int, Header()] + ) -> dict[str, int]: + return { + "header1": header1, + "header2": header2, + } + + def func(headers: Annotated[dict[str, int], Depends(dependency)]) -> None: + mock(headers) + + await channel(func, set()).invoke( + MessageReceive( + { + "type": "message.receive", + "id": "id", + "headers": [(b"header1", b"1"), (b"header2", b"2")], + }, + {}, + ), + Mock(), + ) + + mock.assert_called_with({"header1": 1, "header2": 2}) + + +async def test_sync_depends() -> None: + mock = Mock() + + def dependency( + header1: Annotated[int, Header()], header2: Annotated[int, Header()] + ) -> dict[str, int]: + return { + "header1": header1, + "header2": header2, + } + + def func(headers: Annotated[dict[str, int], Depends(dependency)]) -> None: + mock(headers) + + await channel(func, set()).invoke( + MessageReceive( + { + "type": "message.receive", + "id": "id", + "headers": [(b"header1", b"1"), (b"header2", b"2")], + }, + {}, + ), + Mock(), + ) + + mock.assert_called_with({"header1": 1, "header2": 2}) + + +async def test_async_depends_use_cache() -> None: + mock_func = Mock() + mock_dependency = Mock() + + async def dependency() -> Any: + return mock_dependency() + + def func( + dependency1: Annotated[Any, Depends(dependency)], + dependency2: Annotated[Any, Depends(dependency)], + ) -> None: + mock_func(dependency1, dependency2) + + await channel(func, set()).invoke( + MessageReceive( + { + "type": "message.receive", + "id": "id", + "headers": [], + }, + {}, + ), + Mock(), + ) + + mock_func.assert_called_with( + mock_dependency.return_value, mock_dependency.return_value + ) + mock_dependency.assert_called_once() + + +async def test_async_depends_use_cache_false() -> None: + mock_func = Mock() + mock_dependency = Mock() + + async def dependency() -> Any: + return mock_dependency() + + def func( + dependency1: Annotated[Any, Depends(dependency, use_cache=False)], + dependency2: Annotated[Any, Depends(dependency, use_cache=False)], + ) -> None: + mock_func(dependency1, dependency2) + + await channel(func, set()).invoke( + MessageReceive( + { + "type": "message.receive", + "id": "id", + "headers": [], + }, + {}, + ), + Mock(), + ) + + mock_func.assert_called_with( + mock_dependency.return_value, mock_dependency.return_value + ) + assert mock_dependency.call_count == 2 + + +async def test_async_yielding_depends() -> None: + parent = Mock() + + mock = Mock() + mock_close = Mock() + + parent.attach_mock(mock, "func") + parent.attach_mock(mock_close, "close") + + async def dependency( + header1: Annotated[int, Header()], header2: Annotated[int, Header()] + ) -> AsyncGenerator[dict[str, int], None]: + yield { + "header1": header1, + "header2": header2, + } + mock_close() + + def func(headers: Annotated[dict[str, int], Depends(dependency)]) -> None: + mock(headers) + + await channel(func, set()).invoke( + MessageReceive( + { + "type": "message.receive", + "id": "id", + "headers": [(b"header1", b"1"), (b"header2", b"2")], + }, + {}, + ), + Mock(), + ) + + assert parent.mock_calls == [ + call.func({"header1": 1, "header2": 2}), + call.close(), + ] + + +async def test_sync_yielding_depends() -> None: + parent = Mock() + + mock = Mock() + mock_close = Mock() + + parent.attach_mock(mock, "func") + parent.attach_mock(mock_close, "close") + + def dependency( + header1: Annotated[int, Header()], header2: Annotated[int, Header()] + ) -> Generator[dict[str, int], None, None]: + yield { + "header1": header1, + "header2": header2, + } + mock_close() + + def func(headers: Annotated[dict[str, int], Depends(dependency)]) -> None: + mock(headers) + + await channel(func, set()).invoke( + MessageReceive( + { + "type": "message.receive", + "id": "id", + "headers": [(b"header1", b"1"), (b"header2", b"2")], + }, + {}, + ), + Mock(), + ) + + assert parent.mock_calls == [ + call.func({"header1": 1, "header2": 2}), + call.close(), + ] From 216297543e185ebded6c4a2c8c34e8e4f9f57dfb Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Thu, 12 Feb 2026 19:04:30 +0000 Subject: [PATCH 11/11] docs(asyncfast): add dependency docs and examples --- packages/asyncfast/docs/dependencies.rst | 57 +++++++++++++++++++ .../docs/examples/dependency_basic.py | 20 +++++++ .../docs/examples/dependency_cache.py | 19 +++++++ .../examples/dependency_sub_dependency.py | 24 ++++++++ .../docs/examples/dependency_yield.py | 23 ++++++++ packages/asyncfast/docs/index.rst | 1 + 6 files changed, 144 insertions(+) create mode 100644 packages/asyncfast/docs/dependencies.rst create mode 100644 packages/asyncfast/docs/examples/dependency_basic.py create mode 100644 packages/asyncfast/docs/examples/dependency_cache.py create mode 100644 packages/asyncfast/docs/examples/dependency_sub_dependency.py create mode 100644 packages/asyncfast/docs/examples/dependency_yield.py diff --git a/packages/asyncfast/docs/dependencies.rst b/packages/asyncfast/docs/dependencies.rst new file mode 100644 index 0000000..ec82281 --- /dev/null +++ b/packages/asyncfast/docs/dependencies.rst @@ -0,0 +1,57 @@ +############## + Dependencies +############## + +Dependencies let you share common logic across channel handlers. They are declared with ``Depends`` and injected using +``typing.Annotated``. + +A dependency function can be ``def`` or ``async def``. It can also be a generator (sync or async) to provide setup and +cleanup logic. + +****************** + Basic Dependency +****************** + +Use ``Depends`` inside ``Annotated`` to tell AsyncFast to resolve a value and pass it to your handler: + +.. async-fast-example:: examples/dependency_basic.py + +********************************** + Dependencies Use The Same Inputs +********************************** + +Dependency functions can declare the same kinds of parameters as a channel handler: + +- ``Payload`` (the message body) +- ``Header`` values +- channel parameters (from the address template) +- bindings (via protocol binding types) +- ``MessageSender`` for sending follow-up messages + +These inputs are resolved for dependencies exactly the same way they are for handlers. + +****************** + Sub-dependencies +****************** + +Dependencies can depend on other dependencies using the same ``Depends`` pattern: + +.. async-fast-example:: examples/dependency_sub_dependency.py + +******************** + Cleanup With Yield +******************** + +If a dependency is a generator, AsyncFast treats it like a context manager and runs cleanup after the handler returns. +This works for both ``def`` generators and ``async def`` generators: + +.. async-fast-example:: examples/dependency_yield.py + +********* + Caching +********* + +By default, a dependency result is cached per message and reused if requested multiple times. To disable caching, set +``use_cache=False`` on ``Depends``: + +.. async-fast-example:: examples/dependency_cache.py diff --git a/packages/asyncfast/docs/examples/dependency_basic.py b/packages/asyncfast/docs/examples/dependency_basic.py new file mode 100644 index 0000000..f6805a2 --- /dev/null +++ b/packages/asyncfast/docs/examples/dependency_basic.py @@ -0,0 +1,20 @@ +from typing import Annotated + +from asyncfast import AsyncFast +from asyncfast import Depends +from asyncfast import Header + +app = AsyncFast() + + +def get_context( + request_id: Annotated[str, Header(alias="request-id")], +) -> dict[str, str]: + return {"request_id": request_id} + + +@app.channel("orders.created") +async def handle_orders( + context: Annotated[dict[str, str], Depends(get_context)], +) -> None: + print(context["request_id"]) diff --git a/packages/asyncfast/docs/examples/dependency_cache.py b/packages/asyncfast/docs/examples/dependency_cache.py new file mode 100644 index 0000000..754d194 --- /dev/null +++ b/packages/asyncfast/docs/examples/dependency_cache.py @@ -0,0 +1,19 @@ +from typing import Annotated + +from asyncfast import AsyncFast +from asyncfast import Depends +from asyncfast import Header + +app = AsyncFast() + + +def get_request_id(request_id: Annotated[str, Header(alias="request-id")]) -> str: + return request_id + + +@app.channel("events") +async def handle_events( + request_id: Annotated[str, Depends(get_request_id, use_cache=False)], + request_id_again: Annotated[str, Depends(get_request_id, use_cache=False)], +) -> None: + print(request_id, request_id_again) diff --git a/packages/asyncfast/docs/examples/dependency_sub_dependency.py b/packages/asyncfast/docs/examples/dependency_sub_dependency.py new file mode 100644 index 0000000..bec5e5e --- /dev/null +++ b/packages/asyncfast/docs/examples/dependency_sub_dependency.py @@ -0,0 +1,24 @@ +from typing import Annotated + +from asyncfast import AsyncFast +from asyncfast import Depends +from asyncfast import Header + +app = AsyncFast() + + +def get_tenant_id(tenant_id: Annotated[str, Header(alias="tenant-id")]) -> str: + return tenant_id + + +def get_context( + tenant_id: Annotated[str, Depends(get_tenant_id)], +) -> dict[str, str]: + return {"tenant_id": tenant_id} + + +@app.channel("billing") +async def handle_billing( + context: Annotated[dict[str, str], Depends(get_context)], +) -> None: + print(context["tenant_id"]) diff --git a/packages/asyncfast/docs/examples/dependency_yield.py b/packages/asyncfast/docs/examples/dependency_yield.py new file mode 100644 index 0000000..e705f02 --- /dev/null +++ b/packages/asyncfast/docs/examples/dependency_yield.py @@ -0,0 +1,23 @@ +from collections.abc import AsyncGenerator +from typing import Annotated + +from asyncfast import AsyncFast +from asyncfast import Depends + +app = AsyncFast() + + +async def get_resource() -> AsyncGenerator[str, None]: + resource = "connected" + try: + yield resource + finally: + # Cleanup happens after the handler returns. + pass + + +@app.channel("ping") +async def handle_ping( + resource: Annotated[str, Depends(get_resource)], +) -> None: + print(resource) diff --git a/packages/asyncfast/docs/index.rst b/packages/asyncfast/docs/index.rst index 8fd643f..b520652 100644 --- a/packages/asyncfast/docs/index.rst +++ b/packages/asyncfast/docs/index.rst @@ -77,6 +77,7 @@ Taking ideas from: receiving sending + dependencies lifespan .. _amgi: https://amgi.readthedocs.io/en/latest/