From abb747bff6f3e5344936da55dc02b630e27aa651 Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Fri, 20 Feb 2026 08:42:53 +0000 Subject: [PATCH 1/6] refactor(asyncfast): move asyncapi generation into its own module --- packages/asyncfast/src/asyncfast/_asyncapi.py | 296 ++++++++++++++---- .../asyncfast/src/asyncfast/_asyncfast.py | 179 +---------- packages/asyncfast/src/asyncfast/_channel.py | 18 +- packages/asyncfast/src/asyncfast/_message.py | 34 +- packages/asyncfast/src/asyncfast/_utils.py | 4 +- .../tests_asyncfast/test_address_pattern.py | 4 +- .../tests_asyncfast/test_asyncapi.py | 36 +-- .../asyncfast/tests_asyncfast/test_channel.py | 36 +-- 8 files changed, 313 insertions(+), 294 deletions(-) diff --git a/packages/asyncfast/src/asyncfast/_asyncapi.py b/packages/asyncfast/src/asyncfast/_asyncapi.py index 1d9ba4c..0eac789 100644 --- a/packages/asyncfast/src/asyncfast/_asyncapi.py +++ b/packages/asyncfast/src/asyncfast/_asyncapi.py @@ -1,6 +1,7 @@ import inspect from collections.abc import AsyncGenerator from collections.abc import Generator +from collections.abc import Iterable from collections.abc import Sequence from dataclasses import dataclass from functools import cached_property @@ -21,7 +22,12 @@ from asyncfast._message import Message from pydantic import BaseModel from pydantic import create_model +from pydantic import TypeAdapter from pydantic.fields import FieldInfo +from pydantic.json_schema import GenerateJsonSchema +from pydantic.json_schema import JsonSchemaMode +from pydantic.json_schema import JsonSchemaValue +from pydantic_core import CoreSchema def generate_resolvers( @@ -32,6 +38,92 @@ def generate_resolvers( yield from generate_resolvers(dependency) +@dataclass(frozen=True) +class MessageDefinition: + name: str + address: str | None + parameters: set[str] + headers: Sequence[tuple[str, type[Any], Any]] + bindings: Sequence[tuple[str, str, type[Any], CoreSchema]] + payload: tuple[type[Any], CoreSchema] | None + + @property + def channel_definition(self) -> dict[str, Any]: + definition = { + "address": self.address, + "messages": {self.name: {"$ref": f"#/components/messages/{self.name}"}}, + } + if self.parameters: + definition["parameters"] = {name: {} for name in self.parameters} + return definition + + @cached_property + def headers_model( + self, + ) -> type[BaseModel] | None: + + if not self.headers: + return None + field_definitions: dict[str, Any] = { + name: ( + type_, + FieldInfo( + default=default, + ), + ) + for name, type_, default in self.headers + } + return create_model( + f"{self.name}Headers", __base__=BaseModel, **field_definitions + ) + + def message( + self, + field_mapping: dict[tuple[int, JsonSchemaMode], JsonSchemaValue], + json_schema_mode: JsonSchemaMode, + ) -> dict[str, Any]: + message = {} + + headers_model = self.headers_model + if headers_model: + message["headers"] = field_mapping[hash(headers_model), json_schema_mode] + + payload = self.payload + if payload: + type_, _ = payload + message["payload"] = field_mapping[hash(type_), json_schema_mode] + + if self.bindings: + bindings: dict[str, dict[str, Any]] = {} + for ( + protocol, + field_name, + type_, + _, + ) in self.bindings: + bindings.setdefault(protocol, {})[field_name] = field_mapping[ + hash(type_), json_schema_mode + ] + message["bindings"] = bindings + return message + + def generate_inputs( + self, json_schema_mode: JsonSchemaMode + ) -> Generator[tuple[int, JsonSchemaMode, CoreSchema], None, None]: + for _, _, type_, core_schema in self.bindings: + yield hash(type_), json_schema_mode, core_schema + + headers_model = self.headers_model + if headers_model: + yield hash(headers_model), json_schema_mode, TypeAdapter( + headers_model + ).core_schema + payload = self.payload + if payload: + type_, core_schema = payload + yield hash(type_), json_schema_mode, core_schema + + @dataclass(frozen=True) class ChannelDefinition: channel: Channel @@ -60,30 +152,6 @@ def title(self) -> str: def resolvers(self) -> Sequence[Resolver[Any]]: return tuple(generate_resolvers(self.channel)) - @cached_property - def headers_model( - self, - ) -> type[BaseModel] | None: - headers = [ - resolver - for resolver in self.resolvers - if isinstance(resolver, HeaderResolver) - ] - if not headers: - return None - field_definitions: dict[str, Any] = { - resolver.name: ( - resolver.type, - FieldInfo( - default=... if resolver.required else resolver.default, - ), - ) - for resolver in headers - } - return create_model( - f"{self.title}Headers", __base__=BaseModel, **field_definitions - ) - @cached_property def payload(self) -> PayloadResolver[Any] | None: payloads = [ @@ -95,7 +163,7 @@ def payload(self) -> PayloadResolver[Any] | None: return payloads[0] return None - def generate_messages(self) -> Generator[type[Message], None, None]: + def generate_send_messages(self) -> Generator[type[Message], None, None]: signature = inspect.signature(self.channel.func) return_annotation = signature.return_annotation @@ -104,52 +172,172 @@ def generate_messages(self) -> Generator[type[Message], None, None]: or get_origin(return_annotation) is Generator ): generator_type = get_args(return_annotation)[0] - if _is_union(generator_type): - for type in get_args(generator_type): - if _is_message(type): - yield type - elif _is_message(generator_type): + if is_union(generator_type): + for type_ in get_args(generator_type): + if is_message(type_): + yield type_ + elif is_message(generator_type): yield generator_type for resolver in self.resolvers: if isinstance(resolver, MessageSenderResolver): message_sender_type = get_args(resolver.type)[0] - if _is_union(message_sender_type): - for type in get_args(message_sender_type): - if _is_message(type): - yield type - elif _is_message(message_sender_type): + if is_union(message_sender_type): + for type_ in get_args(message_sender_type): + if is_message(type_): + yield type_ + elif is_message(message_sender_type): yield message_sender_type @cached_property - def messages(self) -> Sequence[type[Message]]: - return tuple(self.generate_messages()) + def send_messages(self) -> Sequence[type[Message]]: + return tuple(self.generate_send_messages()) + @cached_property + def send_message_definitions(self) -> Sequence[MessageDefinition]: + return tuple( + MessageDefinition( + message.__name__, + message.__address__, + {name for name in message.__parameters__}, + [ + (alias, field.type, ...) + for name, (alias, field) in message.__headers__.items() + ], + [ + (protocol, field_name, field.type, field.type_adapter.core_schema) + for protocol, field_name, field in message.__bindings__.values() + ], + ( + ( + message.__payload__[1].type, + message.__payload__[1].type_adapter.core_schema, + ) + if message.__payload__ + else None + ), + ) + for message in self.send_messages + ) -@dataclass(frozen=True) -class MessageDefinition: - address: str | None - name: str - parameters: set[str] - - @property - def definition(self) -> dict[str, Any]: - definition = { - "address": self.address, - "messages": {self.name: {"$ref": f"#/components/messages/{self.name}"}}, - } - if self.parameters: - definition["parameters"] = {name: {} for name in self.parameters} - return definition + @cached_property + def message_definition(self) -> MessageDefinition: + return MessageDefinition( + f"{self.title}Message", + self.channel.address, + self.parameters, + [ + ( + resolver.name, + resolver.type, + ... if resolver.required else resolver.default, + ) + for resolver in self.resolvers + if isinstance(resolver, HeaderResolver) + ], + [ + ( + binding_resolver.protocol, + binding_resolver.field_name, + binding_resolver.type, + binding_resolver.type_adapter.core_schema, + ) + for binding_resolver in self.bindings + ], + ( + (self.payload.type, self.payload.type_adapter.core_schema) + if self.payload + else None + ), + ) -def _is_union(type_annotation: type) -> bool: +def is_union(type_annotation: type) -> bool: origin = get_origin(type_annotation) return origin is Union or origin is UnionType -def _is_message(cls: type[Any]) -> bool: +def is_message(cls: type[Any]) -> bool: try: return issubclass(cls, Message) except TypeError: # pragma: no cover return False + + +def generate_inputs( + channel_definitions: Iterable[ChannelDefinition], +) -> Generator[tuple[int, JsonSchemaMode, CoreSchema], None, None]: + for channel_definition in channel_definitions: + yield from channel_definition.message_definition.generate_inputs("validation") + + for send_message_definition in channel_definition.send_message_definitions: + yield from send_message_definition.generate_inputs("serialization") + + +def generate_messages( + channel_definitions: Iterable[ChannelDefinition], + field_mapping: dict[tuple[int, JsonSchemaMode], JsonSchemaValue], +) -> Generator[tuple[str, dict[str, Any]], None, None]: + for channel_definition in channel_definitions: + + yield channel_definition.message_definition.name, channel_definition.message_definition.message( + field_mapping, "validation" + ) + + for send_message_definition in channel_definition.send_message_definitions: + yield send_message_definition.name, send_message_definition.message( + field_mapping, "serialization" + ) + + +def generate_channels( + channel_definitions: Iterable[ChannelDefinition], +) -> Generator[tuple[str, dict[str, Any]], None, None]: + for channel_definition in channel_definitions: + yield channel_definition.title, channel_definition.message_definition.channel_definition + + for send_message_definition in channel_definition.send_message_definitions: + yield send_message_definition.name, send_message_definition.channel_definition + + +def generate_operations( + channel_definitions: Iterable[ChannelDefinition], +) -> Generator[tuple[str, dict[str, Any]], None, None]: + for channel_definition in channel_definitions: + yield f"receive{channel_definition.title}", { + "action": "receive", + "channel": {"$ref": f"#/channels/{channel_definition.title}"}, + } + + for message in channel_definition.send_messages: + yield f"send{message.__name__}", { + "action": "send", + "channel": {"$ref": f"#/channels/{message.__name__}"}, + } + + +def get_asyncapi( + *, + title: str, + version: str, + channel_definitions: Sequence[ChannelDefinition], +) -> dict[str, Any]: + schema_generator = GenerateJsonSchema(ref_template="#/components/schemas/{model}") + + field_mapping, definitions = schema_generator.generate_definitions( + inputs=list(generate_inputs(channel_definitions)) + ) + + return { + "asyncapi": "3.0.0", + "info": { + "title": title, + "version": version, + }, + "channels": dict(generate_channels(channel_definitions)), + "operations": dict(generate_operations(channel_definitions)), + "components": { + "messages": dict(generate_messages(channel_definitions, field_mapping)), + **({"schemas": definitions} if definitions else {}), + }, + } diff --git a/packages/asyncfast/src/asyncfast/_asyncfast.py b/packages/asyncfast/src/asyncfast/_asyncfast.py index e19beb9..570b87c 100644 --- a/packages/asyncfast/src/asyncfast/_asyncfast.py +++ b/packages/asyncfast/src/asyncfast/_asyncfast.py @@ -1,12 +1,8 @@ from collections.abc import Callable -from collections.abc import Generator -from collections.abc import Iterable -from collections.abc import Mapping from contextlib import AbstractAsyncContextManager from functools import partial from re import Pattern from typing import Any -from typing import get_args from typing import TypeVar from amgi_types import AMGIReceiveCallable @@ -18,21 +14,13 @@ from amgi_types import MessageScope from amgi_types import Scope from asyncfast._asyncapi import ChannelDefinition -from asyncfast._asyncapi import MessageDefinition +from asyncfast._asyncapi import get_asyncapi from asyncfast._channel import Channel from asyncfast._channel import channel as make_channel from asyncfast._channel import MessageReceive -from asyncfast._utils import _address_pattern -from asyncfast._utils import _get_address_parameters -from asyncfast.bindings import Binding -from pydantic import TypeAdapter -from pydantic.json_schema import GenerateJsonSchema -from pydantic.json_schema import JsonSchemaMode -from pydantic.json_schema import JsonSchemaValue -from pydantic_core import CoreSchema +from asyncfast._utils import get_address_pattern DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any]) -M = TypeVar("M", bound=Mapping[str, Any]) Lifespan = Callable[["AsyncFast"], AbstractAsyncContextManager[None]] @@ -54,6 +42,7 @@ def __init__( self._version = version self._lifespan_context = lifespan self._lifespan: AbstractAsyncContextManager[None] | None = None + self._asyncapi_schema: dict[str, Any] | None = None @property def title(self) -> str: @@ -69,12 +58,12 @@ def channel(self, address: str) -> Callable[[DecoratedCallable], DecoratedCallab def _add_channel( self, address: str, function: DecoratedCallable ) -> DecoratedCallable: - address_pattern = _address_pattern(address) + address_pattern = get_address_pattern(address) channel = _Channel( address, address_pattern, - make_channel(function, _get_address_parameters(address)), + make_channel(function, address), ) self._channels.append(channel) @@ -112,61 +101,18 @@ async def __call__( raise ChannelNotFoundError(address) def asyncapi(self) -> dict[str, Any]: - schema_generator = GenerateJsonSchema( - ref_template="#/components/schemas/{model}" - ) - - field_mapping, definitions = schema_generator.generate_definitions( - inputs=list(self._generate_inputs()) - ) - return { - "asyncapi": "3.0.0", - "info": { - "title": self.title, - "version": self.version, - }, - "channels": dict(_generate_channels(self._channels)), - "operations": dict(_generate_operations(self._channels)), - "components": { - "messages": dict(_generate_messages(self._channels, field_mapping)), - **({"schemas": definitions} if definitions else {}), - }, - } - - def _generate_inputs( - self, - ) -> Generator[tuple[int, JsonSchemaMode, CoreSchema], None, None]: - for channel in self._channels: - for binding_resolver in channel._channel_definition.bindings: - yield hash( - binding_resolver.type - ), "validation", binding_resolver.type_adapter.core_schema - - headers_model = channel._channel_definition.headers_model - if headers_model: - yield hash(headers_model), "validation", TypeAdapter( - headers_model - ).core_schema - payload = channel._channel_definition.payload - if payload: - yield hash(payload.type), "validation", payload.type_adapter.core_schema - - for message in channel._channel_definition.messages: - if message.__payload__: - _, field = message.__payload__ - - yield hash(field), "serialization", field.type_adapter.core_schema + if not self._asyncapi_schema: + channel_definitions = tuple( + ChannelDefinition(channel._channel_invoker) + for channel in self._channels + ) + self._asyncapi_schema = get_asyncapi( + title=self.title, + version=self.version, + channel_definitions=channel_definitions, + ) - for _, _, field in message.__bindings__.values(): - yield hash( - field.type - ), "serialization", field.type_adapter.core_schema - - message_headers_model = message._headers_model() - if message_headers_model: - yield hash(message_headers_model), "serialization", TypeAdapter( - message_headers_model - ).core_schema + return self._asyncapi_schema class _Channel: @@ -179,11 +125,6 @@ def __init__( self._address = address self._address_pattern = address_pattern self._channel_invoker = channel_invoker - self._channel_definition = ChannelDefinition(channel_invoker) - - @property - def address(self) -> str: - return self._address def match(self, address: str) -> dict[str, str] | None: match = self._address_pattern.match(address) @@ -210,91 +151,3 @@ async def __call__( "message": str(e), } await send(message_nack_event) - - -def _generate_messages( - channels: Iterable[_Channel], - field_mapping: dict[tuple[int, JsonSchemaMode], JsonSchemaValue], -) -> Generator[tuple[str, dict[str, Any]], None, None]: - for channel in channels: - message = {} - - headers_model = channel._channel_definition.headers_model - if headers_model: - message["headers"] = field_mapping[hash(headers_model), "validation"] - - payload = channel._channel_definition.payload - if payload: - message["payload"] = field_mapping[hash(payload.type), "validation"] - - bindings: dict[str, dict[str, Any]] - if channel._channel_definition.bindings: - bindings = {} - for binding_resolver in channel._channel_definition.bindings: - - bindings.setdefault(binding_resolver.protocol, {})[ - binding_resolver.field_name - ] = field_mapping[hash(binding_resolver.type), "validation"] - message["bindings"] = bindings - - yield f"{channel._channel_definition.title}Message", message - - for channel_message in channel._channel_definition.messages: - message_message = {} - - if channel_message.__payload__: - _, field = channel_message.__payload__ - message_message["payload"] = field_mapping[hash(field), "serialization"] - - message_headers_model = channel_message._headers_model() - if message_headers_model: - message_message["headers"] = field_mapping[ - hash(message_headers_model), "serialization" - ] - - if channel_message.__bindings__: - bindings = {} - for _, _, field in channel_message.__bindings__.values(): - binding_type = get_args(field.type)[1] - assert isinstance(binding_type, Binding) - - bindings.setdefault(binding_type.__protocol__, {})[ - binding_type.__field_name__ - ] = field_mapping[hash(field), "serialization"] - message_message["bindings"] = bindings - - yield channel_message.__name__, message_message - - -def _generate_channels( - channels: Iterable[_Channel], -) -> Generator[tuple[str, dict[str, Any]], None, None]: - for channel in channels: - yield channel._channel_definition.title, MessageDefinition( - channel.address, - f"{channel._channel_definition.title}Message", - channel._channel_definition.parameters, - ).definition - - for message in channel._channel_definition.messages: - yield message.__name__, MessageDefinition( - message.__address__, - message.__name__, - {name for name in message.__parameters__}, - ).definition - - -def _generate_operations( - channels: Iterable[_Channel], -) -> Generator[tuple[str, dict[str, Any]], None, None]: - for channel in channels: - yield f"receive{channel._channel_definition.title}", { - "action": "receive", - "channel": {"$ref": f"#/channels/{channel._channel_definition.title}"}, - } - - for message in channel._channel_definition.messages: - yield f"send{message.__name__}", { - "action": "send", - "channel": {"$ref": f"#/channels/{message.__name__}"}, - } diff --git a/packages/asyncfast/src/asyncfast/_channel.py b/packages/asyncfast/src/asyncfast/_channel.py index f81e3ba..1897d2e 100644 --- a/packages/asyncfast/src/asyncfast/_channel.py +++ b/packages/asyncfast/src/asyncfast/_channel.py @@ -29,6 +29,7 @@ from amgi_types import AMGISendCallable from amgi_types import MessageScope from amgi_types import MessageSendEvent +from asyncfast._utils import get_address_parameters from asyncfast.bindings import Binding from pydantic import TypeAdapter from pydantic.fields import FieldInfo @@ -39,6 +40,7 @@ P = ParamSpec("P") T = TypeVar("T") M = TypeVar("M", bound=Mapping[str, Any]) +DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any]) def _next_or_stop(generator: Generator[T, None, Any]) -> T | StopIteration: @@ -361,6 +363,7 @@ def resolve( @dataclass(frozen=True) class Channel(CallableResolver, ABC): + address: str parameters: set[str] async def invoke( @@ -534,7 +537,8 @@ def resolvers_dependencies( return resolvers, dependencies -def channel(func: Callable[..., Any], address_parameters: set[str]) -> Channel: +def channel(func: Callable[..., Any], address: str) -> Channel: + address_parameters = get_address_parameters(address) resolvers, dependencies = resolvers_dependencies(func, address_parameters) payloads = sum( @@ -552,9 +556,13 @@ def channel(func: Callable[..., Any], address_parameters: set[str]) -> Channel: ) if inspect.iscoroutinefunction(func): - return AsyncChannel(func, resolvers, dependencies, address_parameters) + return AsyncChannel(func, resolvers, dependencies, address, address_parameters) if inspect.isasyncgenfunction(func): - return AsyncGeneratorChannel(func, resolvers, dependencies, address_parameters) + return AsyncGeneratorChannel( + func, resolvers, dependencies, address, address_parameters + ) if inspect.isgeneratorfunction(func): - return SyncGeneratorChannel(func, resolvers, dependencies, address_parameters) - return SyncChannel(func, resolvers, dependencies, address_parameters) + return SyncGeneratorChannel( + func, resolvers, dependencies, address, address_parameters + ) + return SyncChannel(func, resolvers, dependencies, address, address_parameters) diff --git a/packages/asyncfast/src/asyncfast/_message.py b/packages/asyncfast/src/asyncfast/_message.py index 50dfa7d..b60615c 100644 --- a/packages/asyncfast/src/asyncfast/_message.py +++ b/packages/asyncfast/src/asyncfast/_message.py @@ -11,10 +11,8 @@ from asyncfast._channel import Header from asyncfast._channel import Parameter from asyncfast._channel import Payload -from asyncfast._utils import _get_address_parameters +from asyncfast._utils import get_address_parameters from asyncfast.bindings import Binding -from pydantic import BaseModel -from pydantic import create_model from pydantic import TypeAdapter @@ -31,7 +29,6 @@ class Message(Mapping[str, Any]): __address__: ClassVar[str | None] = None __headers__: ClassVar[dict[str, tuple[str, _Field]]] - __headers_model__: ClassVar[type[BaseModel] | None] __parameters__: ClassVar[dict[str, TypeAdapter[Any]]] __payload__: ClassVar[tuple[str, _Field] | None] __bindings__: ClassVar[dict[str, tuple[str, str, _Field]]] @@ -140,39 +137,12 @@ def _get_bindings(self) -> dict[str, dict[str, Any]]: ) return bindings - @classmethod - def _headers_model(cls) -> type[BaseModel] | None: - if not hasattr(cls, "__headers_model__"): - if cls.__headers__: - cls.__headers_model__ = _create_headers_model( - f"{cls.__name__}Headers", cls.__headers__ - ) - else: - cls.__headers_model__ = None - return cls.__headers_model__ - - -def _generate_field_definitions( - headers: Mapping[str, tuple[str, _Field]], -) -> Iterator[tuple[str, Any]]: - for name, (alias, field) in headers.items(): - type_, annotation = get_args(field.type) - yield alias, (type_, annotation) - - -def _create_headers_model( - headers_name: str, headers: Mapping[str, tuple[str, _Field]] -) -> type[BaseModel]: - return create_model( - headers_name, __base__=BaseModel, **dict(_generate_field_definitions(headers)) - ) - def _generate_message_annotations( address: str | None, fields: dict[str, Any], ) -> Generator[tuple[str, type[Annotated[Any, Any]]], None, None]: - address_parameters = _get_address_parameters(address) + address_parameters = get_address_parameters(address) for name, field in fields.items(): if get_origin(field) is Annotated: yield name, field diff --git a/packages/asyncfast/src/asyncfast/_utils.py b/packages/asyncfast/src/asyncfast/_utils.py index 3de5110..6831237 100644 --- a/packages/asyncfast/src/asyncfast/_utils.py +++ b/packages/asyncfast/src/asyncfast/_utils.py @@ -6,7 +6,7 @@ _PARAMETER_PATTERN = re.compile(r"{(.*)}") -def _get_address_parameters(address: str | None) -> set[str]: +def get_address_parameters(address: str | None) -> set[str]: if address is None: return set() parameters = _PARAMETER_PATTERN.findall(address) @@ -18,7 +18,7 @@ def _get_address_parameters(address: str | None) -> set[str]: return set(parameters) -def _address_pattern(address: str) -> Pattern[str]: +def get_address_pattern(address: str) -> Pattern[str]: index = 0 address_regex = "^" for match in _PARAMETER_PATTERN.finditer(address): diff --git a/packages/asyncfast/tests_asyncfast/test_address_pattern.py b/packages/asyncfast/tests_asyncfast/test_address_pattern.py index 7cfe543..97fb724 100644 --- a/packages/asyncfast/tests_asyncfast/test_address_pattern.py +++ b/packages/asyncfast/tests_asyncfast/test_address_pattern.py @@ -1,7 +1,7 @@ import re import pytest -from asyncfast._utils import _address_pattern +from asyncfast._utils import get_address_pattern @pytest.mark.parametrize( @@ -13,4 +13,4 @@ ], ) def test_address_pattern(address: str, pattern: str) -> None: - assert _address_pattern(address) == re.compile(pattern) + assert get_address_pattern(address) == re.compile(pattern) diff --git a/packages/asyncfast/tests_asyncfast/test_asyncapi.py b/packages/asyncfast/tests_asyncfast/test_asyncapi.py index 57a50e7..76b2e32 100644 --- a/packages/asyncfast/tests_asyncfast/test_asyncapi.py +++ b/packages/asyncfast/tests_asyncfast/test_asyncapi.py @@ -35,16 +35,16 @@ async def on_hello(request_id: Annotated[int, Header()]) -> None: "components": { "messages": { "OnHelloMessage": { - "headers": {"$ref": "#/components/schemas/OnHelloHeaders"} + "headers": {"$ref": "#/components/schemas/OnHelloMessageHeaders"} } }, "schemas": { - "OnHelloHeaders": { + "OnHelloMessageHeaders": { "properties": { "request-id": {"title": "Request-Id", "type": "integer"} }, "required": ["request-id"], - "title": "OnHelloHeaders", + "title": "OnHelloMessageHeaders", "type": "object", } }, @@ -79,16 +79,16 @@ async def on_hello(request_id: Annotated[int, Header(alias="Request-Id")]) -> No "components": { "messages": { "OnHelloMessage": { - "headers": {"$ref": "#/components/schemas/OnHelloHeaders"} + "headers": {"$ref": "#/components/schemas/OnHelloMessageHeaders"} } }, "schemas": { - "OnHelloHeaders": { + "OnHelloMessageHeaders": { "properties": { "Request-Id": {"title": "Request-Id", "type": "integer"} }, "required": ["Request-Id"], - "title": "OnHelloHeaders", + "title": "OnHelloMessageHeaders", "type": "object", } }, @@ -123,16 +123,16 @@ def on_hello(request_id: Annotated[int, Header()]) -> None: "components": { "messages": { "OnHelloMessage": { - "headers": {"$ref": "#/components/schemas/OnHelloHeaders"} + "headers": {"$ref": "#/components/schemas/OnHelloMessageHeaders"} } }, "schemas": { - "OnHelloHeaders": { + "OnHelloMessageHeaders": { "properties": { "request-id": {"title": "Request-Id", "type": "integer"} }, "required": ["request-id"], - "title": "OnHelloHeaders", + "title": "OnHelloMessageHeaders", "type": "object", } }, @@ -169,11 +169,11 @@ async def on_hello( "components": { "messages": { "OnHelloMessage": { - "headers": {"$ref": "#/components/schemas/OnHelloHeaders"} + "headers": {"$ref": "#/components/schemas/OnHelloMessageHeaders"} } }, "schemas": { - "OnHelloHeaders": { + "OnHelloMessageHeaders": { "properties": { "request-id": { "description": "Id to correlate the request", @@ -182,7 +182,7 @@ async def on_hello( } }, "required": ["request-id"], - "title": "OnHelloHeaders", + "title": "OnHelloMessageHeaders", "type": "object", } }, @@ -939,12 +939,12 @@ async def notification_channel_handler( "messages": { "NotificationChannelHandlerMessage": { "headers": { - "$ref": "#/components/schemas/NotificationChannelHandlerHeaders" + "$ref": "#/components/schemas/NotificationChannelHandlerMessageHeaders" } } }, "schemas": { - "NotificationChannelHandlerHeaders": { + "NotificationChannelHandlerMessageHeaders": { "properties": { "request-id": { "default": 0, @@ -952,7 +952,7 @@ async def notification_channel_handler( "type": "integer", } }, - "title": "NotificationChannelHandlerHeaders", + "title": "NotificationChannelHandlerMessageHeaders", "type": "object", } }, @@ -1165,17 +1165,17 @@ async def on_hello(headers: Annotated[dict[str, int], Depends(dependency)]) -> N "components": { "messages": { "OnHelloMessage": { - "headers": {"$ref": "#/components/schemas/OnHelloHeaders"} + "headers": {"$ref": "#/components/schemas/OnHelloMessageHeaders"} } }, "schemas": { - "OnHelloHeaders": { + "OnHelloMessageHeaders": { "properties": { "header1": {"title": "Header1", "type": "integer"}, "header2": {"title": "Header2", "type": "integer"}, }, "required": ["header1", "header2"], - "title": "OnHelloHeaders", + "title": "OnHelloMessageHeaders", "type": "object", } }, diff --git a/packages/asyncfast/tests_asyncfast/test_channel.py b/packages/asyncfast/tests_asyncfast/test_channel.py index 8f2d1c2..bc3462f 100644 --- a/packages/asyncfast/tests_asyncfast/test_channel.py +++ b/packages/asyncfast/tests_asyncfast/test_channel.py @@ -21,7 +21,7 @@ async def test_payload_basic() -> None: def func(i: int) -> None: mock(i) - await channel(func, set()).invoke( + await channel(func, "channel").invoke( MessageReceive( { "type": "message", @@ -44,7 +44,7 @@ async def test_header_basic() -> None: def func(header: Annotated[str, Header()]) -> None: mock(header) - await channel(func, set()).invoke( + await channel(func, "channel").invoke( MessageReceive( { "type": "message", @@ -66,7 +66,7 @@ async def test_header_default() -> None: def func(header: Annotated[str, Header()] = "value") -> None: mock(header) - await channel(func, set()).invoke( + await channel(func, "channel").invoke( MessageReceive( { "type": "message", @@ -88,7 +88,7 @@ async def test_header_underscore_to_hyphen() -> None: def func(header_name: Annotated[str, Header()]) -> None: mock(header_name) - await channel(func, set()).invoke( + await channel(func, "channel").invoke( MessageReceive( { "type": "message", @@ -110,7 +110,7 @@ async def test_header_alias() -> None: def func(etag: Annotated[str, Header(alias="ETag")]) -> None: mock(etag) - await channel(func, set()).invoke( + await channel(func, "channel").invoke( MessageReceive( { "type": "message", @@ -132,7 +132,7 @@ async def test_address_parameter() -> None: def func(user: str) -> None: mock(user) - await channel(func, {"user"}).invoke( + await channel(func, "channel.{user}").invoke( MessageReceive( { "type": "message", @@ -154,7 +154,7 @@ async def test_binding() -> None: def func(key: Annotated[int, KafkaKey()]) -> None: mock(key) - await channel(func, set()).invoke( + await channel(func, "channel").invoke( MessageReceive( { "type": "message", @@ -177,7 +177,7 @@ async def test_binding_default() -> None: def func(key: Annotated[int, KafkaKey()] = 123) -> None: mock(key) - await channel(func, set()).invoke( + await channel(func, "channel").invoke( MessageReceive( { "type": "message", @@ -199,7 +199,7 @@ async def test_async_func() -> None: async def func(i: int) -> None: await mock(i) - await channel(func, set()).invoke( + await channel(func, "channel").invoke( MessageReceive( { "type": "message", @@ -226,7 +226,7 @@ async def func() -> AsyncGenerator[Mapping[str, Any], None]: "headers": [(b"Id", b"10")], } - await channel(func, set()).invoke( + await channel(func, "channel").invoke( MessageReceive( { "type": "message", @@ -260,7 +260,7 @@ def func() -> Generator[Mapping[str, Any], None, None]: "headers": [(b"Id", b"10")], } - await channel(func, set()).invoke( + await channel(func, "channel").invoke( MessageReceive( { "type": "message", @@ -296,7 +296,7 @@ async def func(message_sender: MessageSender[Mapping[str, Any]]) -> None: } ) - await channel(func, set()).invoke( + await channel(func, "channel").invoke( MessageReceive( { "type": "message", @@ -334,7 +334,7 @@ async def dependency( def func(headers: Annotated[dict[str, int], Depends(dependency)]) -> None: mock(headers) - await channel(func, set()).invoke( + await channel(func, "channel").invoke( MessageReceive( { "type": "message", @@ -364,7 +364,7 @@ def dependency( def func(headers: Annotated[dict[str, int], Depends(dependency)]) -> None: mock(headers) - await channel(func, set()).invoke( + await channel(func, "channel").invoke( MessageReceive( { "type": "message", @@ -393,7 +393,7 @@ def func( ) -> None: mock_func(dependency1, dependency2) - await channel(func, set()).invoke( + await channel(func, "channel").invoke( MessageReceive( { "type": "message", @@ -425,7 +425,7 @@ def func( ) -> None: mock_func(dependency1, dependency2) - await channel(func, set()).invoke( + await channel(func, "channel").invoke( MessageReceive( { "type": "message", @@ -465,7 +465,7 @@ async def dependency( def func(headers: Annotated[dict[str, int], Depends(dependency)]) -> None: mock(headers) - await channel(func, set()).invoke( + await channel(func, "channel").invoke( MessageReceive( { "type": "message", @@ -505,7 +505,7 @@ def dependency( def func(headers: Annotated[dict[str, int], Depends(dependency)]) -> None: mock(headers) - await channel(func, set()).invoke( + await channel(func, "channel").invoke( MessageReceive( { "type": "message", From f10fa0c54425bcd971aeb9d65befc7d56711d48e Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Fri, 20 Feb 2026 21:05:18 +0000 Subject: [PATCH 2/6] fix(asyncfast): handle untyped handler parameters safely --- packages/asyncfast/src/asyncfast/_channel.py | 3 +- .../tests_asyncfast/test_asyncapi.py | 31 ++++++++++++ .../asyncfast/tests_asyncfast/test_message.py | 49 +++++++++++++++++++ 3 files changed, 82 insertions(+), 1 deletion(-) diff --git a/packages/asyncfast/src/asyncfast/_channel.py b/packages/asyncfast/src/asyncfast/_channel.py index 1897d2e..6655d0b 100644 --- a/packages/asyncfast/src/asyncfast/_channel.py +++ b/packages/asyncfast/src/asyncfast/_channel.py @@ -519,7 +519,8 @@ def parameter_resolver( if get_origin(parameter.annotation) is MessageSender: return MessageSenderResolver(parameter.annotation) - return PayloadResolver(parameter.annotation) + type_ = object if parameter.empty == parameter.annotation else parameter.annotation + return PayloadResolver(type_) def resolvers_dependencies( diff --git a/packages/asyncfast/tests_asyncfast/test_asyncapi.py b/packages/asyncfast/tests_asyncfast/test_asyncapi.py index 76b2e32..c3e6ee7 100644 --- a/packages/asyncfast/tests_asyncfast/test_asyncapi.py +++ b/packages/asyncfast/tests_asyncfast/test_asyncapi.py @@ -1188,3 +1188,34 @@ async def on_hello(headers: Annotated[dict[str, int], Depends(dependency)]) -> N } }, } + + +async def test_untyped() -> None: + app = AsyncFast() + + @app.channel("topic.{name}") + async def topic_handler(payload, name): # type: ignore[no-untyped-def] + pass # pragma: no cover + + assert app.asyncapi() == { + "asyncapi": "3.0.0", + "channels": { + "TopicHandler": { + "address": "topic.{name}", + "messages": { + "TopicHandlerMessage": { + "$ref": "#/components/messages/TopicHandlerMessage" + } + }, + "parameters": {"name": {}}, + } + }, + "components": {"messages": {"TopicHandlerMessage": {"payload": {}}}}, + "info": {"title": "AsyncFast", "version": "0.1.0"}, + "operations": { + "receiveTopicHandler": { + "action": "receive", + "channel": {"$ref": "#/channels/TopicHandler"}, + } + }, + } diff --git a/packages/asyncfast/tests_asyncfast/test_message.py b/packages/asyncfast/tests_asyncfast/test_message.py index a47ac3c..4e1a8ce 100644 --- a/packages/asyncfast/tests_asyncfast/test_message.py +++ b/packages/asyncfast/tests_asyncfast/test_message.py @@ -831,3 +831,52 @@ async def topic_handler(id: int) -> None: AsyncMock(), AsyncMock(), ) + + +async def test_message_payload_untyped() -> None: + app = AsyncFast() + + test_mock = Mock() + + @app.channel("topic") + async def topic_handler(payload): # type: ignore[no-untyped-def] + test_mock(payload) + + message_scope: MessageScope = { + "type": "message", + "amgi": {"version": "2.0", "spec_version": "2.0"}, + "address": "topic", + "headers": [], + "payload": b'{"id":1}', + } + await app( + message_scope, + AsyncMock(), + AsyncMock(), + ) + + test_mock.assert_called_once_with({"id": 1}) + + +async def test_message_address_parameter_untyped() -> None: + app = AsyncFast() + + test_mock = Mock() + + @app.channel("topic.{name}") + async def topic_handler(name): # type: ignore[no-untyped-def] + test_mock(name) + + message_scope: MessageScope = { + "type": "message", + "amgi": {"version": "2.0", "spec_version": "2.0"}, + "address": "topic.name", + "headers": [], + } + await app( + message_scope, + AsyncMock(), + AsyncMock(), + ) + + test_mock.assert_called_once_with("name") From f44f0579d657820790bf7abd41def6a8951797b8 Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Fri, 20 Feb 2026 21:05:20 +0000 Subject: [PATCH 3/6] feat(asyncfast): add router-based dispatch --- packages/asyncfast/src/asyncfast/__init__.py | 2 +- packages/asyncfast/src/asyncfast/_asyncapi.py | 6 +- .../asyncfast/src/asyncfast/_asyncfast.py | 81 +---- packages/asyncfast/src/asyncfast/_channel.py | 87 ++++- .../asyncfast/tests_asyncfast/test_channel.py | 339 ++++++++---------- .../asyncfast/tests_asyncfast/test_message.py | 19 +- t.py | 0 7 files changed, 253 insertions(+), 281 deletions(-) create mode 100644 t.py diff --git a/packages/asyncfast/src/asyncfast/__init__.py b/packages/asyncfast/src/asyncfast/__init__.py index c41a60e..0ea5adf 100644 --- a/packages/asyncfast/src/asyncfast/__init__.py +++ b/packages/asyncfast/src/asyncfast/__init__.py @@ -1,5 +1,5 @@ from asyncfast._asyncfast import AsyncFast -from asyncfast._asyncfast import ChannelNotFoundError +from asyncfast._channel import ChannelNotFoundError from asyncfast._channel import Depends from asyncfast._channel import Header from asyncfast._channel import InvalidChannelDefinitionError diff --git a/packages/asyncfast/src/asyncfast/_asyncapi.py b/packages/asyncfast/src/asyncfast/_asyncapi.py index 0eac789..ea13441 100644 --- a/packages/asyncfast/src/asyncfast/_asyncapi.py +++ b/packages/asyncfast/src/asyncfast/_asyncapi.py @@ -19,6 +19,7 @@ from asyncfast._channel import MessageSenderResolver from asyncfast._channel import PayloadResolver from asyncfast._channel import Resolver +from asyncfast._channel import Router from asyncfast._message import Message from pydantic import BaseModel from pydantic import create_model @@ -320,8 +321,11 @@ def get_asyncapi( *, title: str, version: str, - channel_definitions: Sequence[ChannelDefinition], + router: Router, ) -> dict[str, Any]: + channel_definitions = tuple( + ChannelDefinition(channel) for channel in router.channels + ) schema_generator = GenerateJsonSchema(ref_template="#/components/schemas/{model}") field_mapping, definitions = schema_generator.generate_definitions( diff --git a/packages/asyncfast/src/asyncfast/_asyncfast.py b/packages/asyncfast/src/asyncfast/_asyncfast.py index 570b87c..55a73db 100644 --- a/packages/asyncfast/src/asyncfast/_asyncfast.py +++ b/packages/asyncfast/src/asyncfast/_asyncfast.py @@ -1,7 +1,6 @@ from collections.abc import Callable from contextlib import AbstractAsyncContextManager from functools import partial -from re import Pattern from typing import Any from typing import TypeVar @@ -9,27 +8,14 @@ from amgi_types import AMGISendCallable from amgi_types import LifespanShutdownCompleteEvent from amgi_types import LifespanStartupCompleteEvent -from amgi_types import MessageAckEvent -from amgi_types import MessageNackEvent -from amgi_types import MessageScope from amgi_types import Scope -from asyncfast._asyncapi import ChannelDefinition from asyncfast._asyncapi import get_asyncapi -from asyncfast._channel import Channel -from asyncfast._channel import channel as make_channel -from asyncfast._channel import MessageReceive -from asyncfast._utils import get_address_pattern +from asyncfast._channel import Router DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any]) Lifespan = Callable[["AsyncFast"], AbstractAsyncContextManager[None]] -class ChannelNotFoundError(LookupError): - def __init__(self, address: str) -> None: - super().__init__(f"Couldn't resolve address: {address}") - self.address = address - - class AsyncFast: def __init__( self, @@ -37,10 +23,10 @@ def __init__( version: str = "0.1.0", lifespan: Lifespan | None = None, ) -> None: - self._channels: list[_Channel] = [] self._title = title self._version = version self._lifespan_context = lifespan + self._router = Router() self._lifespan: AbstractAsyncContextManager[None] | None = None self._asyncapi_schema: dict[str, Any] | None = None @@ -58,15 +44,7 @@ def channel(self, address: str) -> Callable[[DecoratedCallable], DecoratedCallab def _add_channel( self, address: str, function: DecoratedCallable ) -> DecoratedCallable: - address_pattern = get_address_pattern(address) - - channel = _Channel( - address, - address_pattern, - make_channel(function, address), - ) - - self._channels.append(channel) + self._router.add_channel(address, function) return function async def __call__( @@ -92,62 +70,15 @@ async def __call__( await send(lifespan_shutdown_complete_event) return elif scope["type"] == "message": - address = scope["address"] - for channel in self._channels: - parameters = channel.match(address) - if parameters is not None: - await channel(scope, send, parameters) - return - raise ChannelNotFoundError(address) + await self._router(scope, receive, send) def asyncapi(self) -> dict[str, Any]: if not self._asyncapi_schema: - channel_definitions = tuple( - ChannelDefinition(channel._channel_invoker) - for channel in self._channels - ) + self._asyncapi_schema = get_asyncapi( title=self.title, version=self.version, - channel_definitions=channel_definitions, + router=self._router, ) return self._asyncapi_schema - - -class _Channel: - def __init__( - self, - address: str, - address_pattern: Pattern[str], - channel_invoker: Channel, - ) -> None: - self._address = address - self._address_pattern = address_pattern - self._channel_invoker = channel_invoker - - def match(self, address: str) -> dict[str, str] | None: - match = self._address_pattern.match(address) - if match: - return match.groupdict() - return None - - async def __call__( - self, - scope: MessageScope, - send: AMGISendCallable, - parameters: dict[str, str], - ) -> None: - try: - await self._channel_invoker.invoke(MessageReceive(scope, parameters), send) - - message_ack_event: MessageAckEvent = { - "type": "message.ack", - } - await send(message_ack_event) - except Exception as e: - message_nack_event: MessageNackEvent = { - "type": "message.nack", - "message": str(e), - } - await send(message_nack_event) diff --git a/packages/asyncfast/src/asyncfast/_channel.py b/packages/asyncfast/src/asyncfast/_channel.py index 6655d0b..2314d5f 100644 --- a/packages/asyncfast/src/asyncfast/_channel.py +++ b/packages/asyncfast/src/asyncfast/_channel.py @@ -18,6 +18,7 @@ from dataclasses import KW_ONLY from functools import cached_property from functools import wraps +from re import Pattern from typing import Annotated from typing import Any from typing import Generic @@ -26,10 +27,14 @@ from typing import ParamSpec from typing import TypeVar +from amgi_types import AMGIReceiveCallable from amgi_types import AMGISendCallable +from amgi_types import MessageAckEvent +from amgi_types import MessageNackEvent from amgi_types import MessageScope from amgi_types import MessageSendEvent from asyncfast._utils import get_address_parameters +from asyncfast._utils import get_address_pattern from asyncfast.bindings import Binding from pydantic import TypeAdapter from pydantic.fields import FieldInfo @@ -94,6 +99,16 @@ class InvalidChannelDefinitionError(ValueError): """ +class RouteInvariantError(RuntimeError): + """Raised when a selected route fails to match its address.""" + + +class ChannelNotFoundError(LookupError): + def __init__(self, address: str) -> None: + super().__init__(f"Couldn't resolve address: {address}") + self.address = address + + class Header(FieldInfo): pass @@ -364,15 +379,32 @@ def resolve( @dataclass(frozen=True) class Channel(CallableResolver, ABC): address: str + address_pattern: Pattern[str] parameters: set[str] - async def invoke( - self, message_receive: MessageReceive, send: AMGISendCallable + async def __call__( + self, + scope: MessageScope, + receive: AMGIReceiveCallable, + send: AMGISendCallable, + parameters: dict[str, str] | None = None, ) -> None: + parameters = self.match(scope["address"]) if parameters is None else parameters + if parameters is None: + raise RouteInvariantError( + f"Selected route did not match address {scope['address']!r}" + ) + message_receive = MessageReceive(scope, parameters) dependency_cache = DependencyCache(asyncio.get_event_loop()) async with AsyncExitStack() as async_exit_stack: await self.call(message_receive, send, dependency_cache, async_exit_stack) + def match(self, address: str) -> dict[str, str] | None: + match = self.address_pattern.match(address) + if match: + return match.groupdict() + return None + @dataclass(frozen=True) class SyncChannel(Channel): @@ -538,8 +570,9 @@ def resolvers_dependencies( return resolvers, dependencies -def channel(func: Callable[..., Any], address: str) -> Channel: +def get_channel(func: Callable[..., Any], address: str) -> Channel: address_parameters = get_address_parameters(address) + address_pattern = get_address_pattern(address) resolvers, dependencies = resolvers_dependencies(func, address_parameters) payloads = sum( @@ -557,13 +590,53 @@ def channel(func: Callable[..., Any], address: str) -> Channel: ) if inspect.iscoroutinefunction(func): - return AsyncChannel(func, resolvers, dependencies, address, address_parameters) + return AsyncChannel( + func, resolvers, dependencies, address, address_pattern, address_parameters + ) if inspect.isasyncgenfunction(func): return AsyncGeneratorChannel( - func, resolvers, dependencies, address, address_parameters + func, resolvers, dependencies, address, address_pattern, address_parameters ) if inspect.isgeneratorfunction(func): return SyncGeneratorChannel( - func, resolvers, dependencies, address, address_parameters + func, resolvers, dependencies, address, address_pattern, address_parameters ) - return SyncChannel(func, resolvers, dependencies, address, address_parameters) + return SyncChannel( + func, resolvers, dependencies, address, address_pattern, address_parameters + ) + + +class Router: + def __init__(self) -> None: + self.channels: list[Channel] = [] + + def add_channel(self, address: str, func: Callable[..., Any]) -> None: + self.channels.append(get_channel(func, address)) + + async def __call__( + self, scope: MessageScope, receive: AMGIReceiveCallable, send: AMGISendCallable + ) -> None: + try: + await self.call_channel(scope, receive, send) + + message_ack_event: MessageAckEvent = { + "type": "message.ack", + } + await send(message_ack_event) + except Exception as e: + message_nack_event: MessageNackEvent = { + "type": "message.nack", + "message": str(e), + } + await send(message_nack_event) + + async def call_channel( + self, scope: MessageScope, receive: AMGIReceiveCallable, send: AMGISendCallable + ) -> None: + address = scope["address"] + for channel in self.channels: + parameters = channel.match(address) + if parameters is not None: + await channel(scope, receive, send, parameters) + return + raise ChannelNotFoundError(address) diff --git a/packages/asyncfast/tests_asyncfast/test_channel.py b/packages/asyncfast/tests_asyncfast/test_channel.py index bc3462f..37cd93c 100644 --- a/packages/asyncfast/tests_asyncfast/test_channel.py +++ b/packages/asyncfast/tests_asyncfast/test_channel.py @@ -7,10 +7,9 @@ from unittest.mock import call from unittest.mock import Mock -from asyncfast._channel import channel from asyncfast._channel import Depends +from asyncfast._channel import get_channel from asyncfast._channel import Header -from asyncfast._channel import MessageReceive from asyncfast._channel import MessageSender from asyncfast.bindings import KafkaKey @@ -21,17 +20,15 @@ async def test_payload_basic() -> None: def func(i: int) -> None: mock(i) - await channel(func, "channel").invoke( - MessageReceive( - { - "type": "message", - "amgi": {"version": "2.0", "spec_version": "2.0"}, - "address": "channel", - "headers": [], - "payload": b"1", - }, - {}, - ), + await get_channel(func, "channel")( + { + "type": "message", + "amgi": {"version": "2.0", "spec_version": "2.0"}, + "address": "channel", + "headers": [], + "payload": b"1", + }, + Mock(), Mock(), ) @@ -44,16 +41,14 @@ async def test_header_basic() -> None: def func(header: Annotated[str, Header()]) -> None: mock(header) - await channel(func, "channel").invoke( - MessageReceive( - { - "type": "message", - "amgi": {"version": "2.0", "spec_version": "2.0"}, - "address": "channel", - "headers": [(b"header", b"value")], - }, - {}, - ), + await get_channel(func, "channel")( + { + "type": "message", + "amgi": {"version": "2.0", "spec_version": "2.0"}, + "address": "channel", + "headers": [(b"header", b"value")], + }, + Mock(), Mock(), ) @@ -66,16 +61,14 @@ async def test_header_default() -> None: def func(header: Annotated[str, Header()] = "value") -> None: mock(header) - await channel(func, "channel").invoke( - MessageReceive( - { - "type": "message", - "amgi": {"version": "2.0", "spec_version": "2.0"}, - "address": "channel", - "headers": [], - }, - {}, - ), + await get_channel(func, "channel")( + { + "type": "message", + "amgi": {"version": "2.0", "spec_version": "2.0"}, + "address": "channel", + "headers": [], + }, + Mock(), Mock(), ) @@ -88,16 +81,14 @@ async def test_header_underscore_to_hyphen() -> None: def func(header_name: Annotated[str, Header()]) -> None: mock(header_name) - await channel(func, "channel").invoke( - MessageReceive( - { - "type": "message", - "amgi": {"version": "2.0", "spec_version": "2.0"}, - "address": "channel", - "headers": [(b"header-name", b"value")], - }, - {}, - ), + await get_channel(func, "channel")( + { + "type": "message", + "amgi": {"version": "2.0", "spec_version": "2.0"}, + "address": "channel", + "headers": [(b"header-name", b"value")], + }, + Mock(), Mock(), ) @@ -110,16 +101,14 @@ async def test_header_alias() -> None: def func(etag: Annotated[str, Header(alias="ETag")]) -> None: mock(etag) - await channel(func, "channel").invoke( - MessageReceive( - { - "type": "message", - "amgi": {"version": "2.0", "spec_version": "2.0"}, - "address": "channel", - "headers": [(b"ETag", b"9e30981e-02d5-11f1-9648-e323315723e1")], - }, - {}, - ), + await get_channel(func, "channel")( + { + "type": "message", + "amgi": {"version": "2.0", "spec_version": "2.0"}, + "address": "channel", + "headers": [(b"ETag", b"9e30981e-02d5-11f1-9648-e323315723e1")], + }, + Mock(), Mock(), ) @@ -132,16 +121,14 @@ async def test_address_parameter() -> None: def func(user: str) -> None: mock(user) - await channel(func, "channel.{user}").invoke( - MessageReceive( - { - "type": "message", - "amgi": {"version": "2.0", "spec_version": "2.0"}, - "address": "channel.54a08cc6-02db-11f1-afbf-f3f4688d5de4", - "headers": [], - }, - {"user": "54a08cc6-02db-11f1-afbf-f3f4688d5de4"}, - ), + await get_channel(func, "channel.{user}")( + { + "type": "message", + "amgi": {"version": "2.0", "spec_version": "2.0"}, + "address": "channel.54a08cc6-02db-11f1-afbf-f3f4688d5de4", + "headers": [], + }, + Mock(), Mock(), ) @@ -154,17 +141,15 @@ async def test_binding() -> None: def func(key: Annotated[int, KafkaKey()]) -> None: mock(key) - await channel(func, "channel").invoke( - MessageReceive( - { - "type": "message", - "amgi": {"version": "2.0", "spec_version": "2.0"}, - "address": "channel", - "headers": [], - "bindings": {"kafka": {"key": b"123"}}, - }, - {}, - ), + await get_channel(func, "channel")( + { + "type": "message", + "amgi": {"version": "2.0", "spec_version": "2.0"}, + "address": "channel", + "headers": [], + "bindings": {"kafka": {"key": b"123"}}, + }, + Mock(), Mock(), ) @@ -177,16 +162,14 @@ async def test_binding_default() -> None: def func(key: Annotated[int, KafkaKey()] = 123) -> None: mock(key) - await channel(func, "channel").invoke( - MessageReceive( - { - "type": "message", - "amgi": {"version": "2.0", "spec_version": "2.0"}, - "address": "channel", - "headers": [], - }, - {}, - ), + await get_channel(func, "channel")( + { + "type": "message", + "amgi": {"version": "2.0", "spec_version": "2.0"}, + "address": "channel", + "headers": [], + }, + Mock(), Mock(), ) @@ -199,17 +182,15 @@ async def test_async_func() -> None: async def func(i: int) -> None: await mock(i) - await channel(func, "channel").invoke( - MessageReceive( - { - "type": "message", - "amgi": {"version": "2.0", "spec_version": "2.0"}, - "address": "channel", - "headers": [], - "payload": b"1", - }, - {}, - ), + await get_channel(func, "channel")( + { + "type": "message", + "amgi": {"version": "2.0", "spec_version": "2.0"}, + "address": "channel", + "headers": [], + "payload": b"1", + }, + Mock(), Mock(), ) @@ -226,17 +207,15 @@ async def func() -> AsyncGenerator[Mapping[str, Any], None]: "headers": [(b"Id", b"10")], } - await channel(func, "channel").invoke( - MessageReceive( - { - "type": "message", - "amgi": {"version": "2.0", "spec_version": "2.0"}, - "address": "channel", - "headers": [], - "payload": b"1", - }, - {}, - ), + await get_channel(func, "channel")( + { + "type": "message", + "amgi": {"version": "2.0", "spec_version": "2.0"}, + "address": "channel", + "headers": [], + "payload": b"1", + }, + Mock(), send_mock, ) @@ -260,17 +239,15 @@ def func() -> Generator[Mapping[str, Any], None, None]: "headers": [(b"Id", b"10")], } - await channel(func, "channel").invoke( - MessageReceive( - { - "type": "message", - "amgi": {"version": "2.0", "spec_version": "2.0"}, - "address": "channel", - "headers": [], - "payload": b"1", - }, - {}, - ), + await get_channel(func, "channel")( + { + "type": "message", + "amgi": {"version": "2.0", "spec_version": "2.0"}, + "address": "channel", + "headers": [], + "payload": b"1", + }, + Mock(), send_mock, ) @@ -296,17 +273,15 @@ async def func(message_sender: MessageSender[Mapping[str, Any]]) -> None: } ) - await channel(func, "channel").invoke( - MessageReceive( - { - "type": "message", - "amgi": {"version": "2.0", "spec_version": "2.0"}, - "address": "channel", - "headers": [], - "payload": b"1", - }, - {}, - ), + await get_channel(func, "channel")( + { + "type": "message", + "amgi": {"version": "2.0", "spec_version": "2.0"}, + "address": "channel", + "headers": [], + "payload": b"1", + }, + Mock(), send_mock, ) @@ -334,16 +309,14 @@ async def dependency( def func(headers: Annotated[dict[str, int], Depends(dependency)]) -> None: mock(headers) - await channel(func, "channel").invoke( - MessageReceive( - { - "type": "message", - "amgi": {"version": "2.0", "spec_version": "2.0"}, - "address": "channel", - "headers": [(b"header1", b"1"), (b"header2", b"2")], - }, - {}, - ), + await get_channel(func, "channel")( + { + "type": "message", + "amgi": {"version": "2.0", "spec_version": "2.0"}, + "address": "channel", + "headers": [(b"header1", b"1"), (b"header2", b"2")], + }, + Mock(), Mock(), ) @@ -364,16 +337,14 @@ def dependency( def func(headers: Annotated[dict[str, int], Depends(dependency)]) -> None: mock(headers) - await channel(func, "channel").invoke( - MessageReceive( - { - "type": "message", - "amgi": {"version": "2.0", "spec_version": "2.0"}, - "address": "channel", - "headers": [(b"header1", b"1"), (b"header2", b"2")], - }, - {}, - ), + await get_channel(func, "channel")( + { + "type": "message", + "amgi": {"version": "2.0", "spec_version": "2.0"}, + "address": "channel", + "headers": [(b"header1", b"1"), (b"header2", b"2")], + }, + Mock(), Mock(), ) @@ -393,16 +364,14 @@ def func( ) -> None: mock_func(dependency1, dependency2) - await channel(func, "channel").invoke( - MessageReceive( - { - "type": "message", - "amgi": {"version": "2.0", "spec_version": "2.0"}, - "address": "channel", - "headers": [], - }, - {}, - ), + await get_channel(func, "channel")( + { + "type": "message", + "amgi": {"version": "2.0", "spec_version": "2.0"}, + "address": "channel", + "headers": [], + }, + Mock(), Mock(), ) @@ -425,16 +394,14 @@ def func( ) -> None: mock_func(dependency1, dependency2) - await channel(func, "channel").invoke( - MessageReceive( - { - "type": "message", - "amgi": {"version": "2.0", "spec_version": "2.0"}, - "address": "channel", - "headers": [], - }, - {}, - ), + await get_channel(func, "channel")( + { + "type": "message", + "amgi": {"version": "2.0", "spec_version": "2.0"}, + "address": "channel", + "headers": [], + }, + Mock(), Mock(), ) @@ -465,16 +432,14 @@ async def dependency( def func(headers: Annotated[dict[str, int], Depends(dependency)]) -> None: mock(headers) - await channel(func, "channel").invoke( - MessageReceive( - { - "type": "message", - "amgi": {"version": "2.0", "spec_version": "2.0"}, - "address": "channel", - "headers": [(b"header1", b"1"), (b"header2", b"2")], - }, - {}, - ), + await get_channel(func, "channel")( + { + "type": "message", + "amgi": {"version": "2.0", "spec_version": "2.0"}, + "address": "channel", + "headers": [(b"header1", b"1"), (b"header2", b"2")], + }, + Mock(), Mock(), ) @@ -505,16 +470,14 @@ def dependency( def func(headers: Annotated[dict[str, int], Depends(dependency)]) -> None: mock(headers) - await channel(func, "channel").invoke( - MessageReceive( - { - "type": "message", - "amgi": {"version": "2.0", "spec_version": "2.0"}, - "address": "channel", - "headers": [(b"header1", b"1"), (b"header2", b"2")], - }, - {}, - ), + await get_channel(func, "channel")( + { + "type": "message", + "amgi": {"version": "2.0", "spec_version": "2.0"}, + "address": "channel", + "headers": [(b"header1", b"1"), (b"header2", b"2")], + }, + Mock(), Mock(), ) diff --git a/packages/asyncfast/tests_asyncfast/test_message.py b/packages/asyncfast/tests_asyncfast/test_message.py index 4e1a8ce..b7829f3 100644 --- a/packages/asyncfast/tests_asyncfast/test_message.py +++ b/packages/asyncfast/tests_asyncfast/test_message.py @@ -15,7 +15,6 @@ from amgi_types import AMGISendEvent from amgi_types import MessageScope from asyncfast import AsyncFast -from asyncfast import ChannelNotFoundError from asyncfast import Header from asyncfast import Message from asyncfast import MessageSender @@ -823,14 +822,16 @@ async def topic_handler(id: int) -> None: "address": "not_topic", "headers": [], } - with pytest.raises( - ChannelNotFoundError, match="Couldn't resolve address: not_topic" - ): - await app( - message_scope, - AsyncMock(), - AsyncMock(), - ) + send_mock = AsyncMock() + await app( + message_scope, + AsyncMock(), + send_mock, + ) + + send_mock.assert_awaited_once_with( + {"type": "message.nack", "message": "Couldn't resolve address: not_topic"} + ) async def test_message_payload_untyped() -> None: diff --git a/t.py b/t.py new file mode 100644 index 0000000..e69de29 From e8f4d3c36ee3a57e6eab5abf732ca9ef0653fe58 Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Fri, 20 Feb 2026 21:05:21 +0000 Subject: [PATCH 4/6] feat(asyncfast): add router-based dispatch --- t.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 t.py diff --git a/t.py b/t.py deleted file mode 100644 index e69de29..0000000 From 3bed4bbf3a59aa4dd40551e3b6dd7bcfb5053db7 Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Fri, 20 Feb 2026 21:05:22 +0000 Subject: [PATCH 5/6] feat(asyncfast): add middleware support --- packages/asyncfast/src/asyncfast/__init__.py | 4 +- .../asyncfast/src/asyncfast/_asyncfast.py | 50 +++++++++ .../tests_asyncfast/test_middleware.py | 101 ++++++++++++++++++ 3 files changed, 154 insertions(+), 1 deletion(-) create mode 100644 packages/asyncfast/tests_asyncfast/test_middleware.py diff --git a/packages/asyncfast/src/asyncfast/__init__.py b/packages/asyncfast/src/asyncfast/__init__.py index 0ea5adf..2cf0ead 100644 --- a/packages/asyncfast/src/asyncfast/__init__.py +++ b/packages/asyncfast/src/asyncfast/__init__.py @@ -1,4 +1,5 @@ from asyncfast._asyncfast import AsyncFast +from asyncfast._asyncfast import Middleware from asyncfast._channel import ChannelNotFoundError from asyncfast._channel import Depends from asyncfast._channel import Header @@ -10,12 +11,13 @@ __all__ = [ "AsyncFast", + "Middleware", "ChannelNotFoundError", - "Message", "Depends", "Header", "InvalidChannelDefinitionError", "MessageSender", "Parameter", "Payload", + "Message", ] diff --git a/packages/asyncfast/src/asyncfast/_asyncfast.py b/packages/asyncfast/src/asyncfast/_asyncfast.py index 55a73db..2ae1675 100644 --- a/packages/asyncfast/src/asyncfast/_asyncfast.py +++ b/packages/asyncfast/src/asyncfast/_asyncfast.py @@ -1,9 +1,14 @@ from collections.abc import Callable +from collections.abc import Iterator +from collections.abc import Sequence from contextlib import AbstractAsyncContextManager from functools import partial from typing import Any +from typing import ParamSpec +from typing import Protocol from typing import TypeVar +from amgi_types import AMGIApplication from amgi_types import AMGIReceiveCallable from amgi_types import AMGISendCallable from amgi_types import LifespanShutdownCompleteEvent @@ -12,23 +17,46 @@ from asyncfast._asyncapi import get_asyncapi from asyncfast._channel import Router +P = ParamSpec("P") DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any]) Lifespan = Callable[["AsyncFast"], AbstractAsyncContextManager[None]] +class _MiddlewareFactory(Protocol[P]): + def __call__( + self, app: AMGIApplication, /, *args: P.args, **kwargs: P.kwargs + ) -> AMGIApplication: ... # pragma: no cover + + +class Middleware: + def __init__( + self, cls: _MiddlewareFactory[P], *args: P.args, **kwargs: P.kwargs + ) -> None: + self.cls = cls + self.args = args + self.kwargs = kwargs + + def __iter__(self) -> Iterator[Any]: + as_tuple = (self.cls, self.args, self.kwargs) + return iter(as_tuple) + + class AsyncFast: def __init__( self, title: str = "AsyncFast", version: str = "0.1.0", lifespan: Lifespan | None = None, + middleware: Sequence[Middleware] | None = None, ) -> None: self._title = title self._version = version self._lifespan_context = lifespan + self._middleware = list(middleware) if middleware else [] self._router = Router() self._lifespan: AbstractAsyncContextManager[None] | None = None self._asyncapi_schema: dict[str, Any] | None = None + self._middleware_stack: AMGIApplication | None = None @property def title(self) -> str: @@ -49,6 +77,14 @@ def _add_channel( async def __call__( self, scope: Scope, receive: AMGIReceiveCallable, send: AMGISendCallable + ) -> None: + if self._middleware_stack is None: + self._middleware_stack = self.build_middleware_stack() + + await self._middleware_stack(scope, receive, send) + + async def _app( + self, scope: Scope, receive: AMGIReceiveCallable, send: AMGISendCallable ) -> None: if scope["type"] == "lifespan": while True: @@ -82,3 +118,17 @@ def asyncapi(self) -> dict[str, Any]: ) return self._asyncapi_schema + + def build_middleware_stack(self) -> AMGIApplication: + app = self._app + for cls, args, kwargs in self._middleware: + app = cls(app, *args, **kwargs) + return app + + def add_middleware( + self, + middleware_class: _MiddlewareFactory[P], + *args: P.args, + **kwargs: P.kwargs, + ) -> None: + self._middleware.append(Middleware(middleware_class, *args, **kwargs)) diff --git a/packages/asyncfast/tests_asyncfast/test_middleware.py b/packages/asyncfast/tests_asyncfast/test_middleware.py new file mode 100644 index 0000000..0276ea7 --- /dev/null +++ b/packages/asyncfast/tests_asyncfast/test_middleware.py @@ -0,0 +1,101 @@ +from unittest.mock import AsyncMock +from unittest.mock import call +from unittest.mock import Mock + +from amgi_types import AMGIApplication +from amgi_types import AMGIReceiveCallable +from amgi_types import AMGISendCallable +from amgi_types import MessageScope +from amgi_types import Scope +from asyncfast import AsyncFast +from asyncfast import Middleware + + +class RecordingMiddleware: + def __init__(self, app: AMGIApplication, mock: Mock) -> None: + self._app = app + self._mock = mock + + async def __call__( + self, scope: Scope, receive: AMGIReceiveCallable, send: AMGISendCallable + ) -> None: + self._mock.before() + await self._app(scope, receive, send) + self._mock.after() + + +async def test_middleware_with_init() -> None: + parent = Mock() + + app = AsyncFast(middleware=[Middleware(RecordingMiddleware, parent.recorder)]) + + @app.channel("channel") + def handler() -> None: + parent.handler() + + scope: MessageScope = { + "type": "message", + "amgi": {"version": "2.0", "spec_version": "2.0"}, + "address": "channel", + "headers": [], + } + await app(scope, AsyncMock(), AsyncMock()) + + assert parent.mock_calls == [ + call.recorder.before(), + call.handler(), + call.recorder.after(), + ] + + +async def test_middleware_with_add_middleware() -> None: + parent = Mock() + + app = AsyncFast() + app.add_middleware(RecordingMiddleware, parent.recorder) + + @app.channel("channel") + def handler() -> None: + parent.handler() + + scope: MessageScope = { + "type": "message", + "amgi": {"version": "2.0", "spec_version": "2.0"}, + "address": "channel", + "headers": [], + } + await app(scope, AsyncMock(), AsyncMock()) + + assert parent.mock_calls == [ + call.recorder.before(), + call.handler(), + call.recorder.after(), + ] + + +async def test_multiple_middleware_order() -> None: + parent = Mock() + + app = AsyncFast() + app.add_middleware(RecordingMiddleware, parent.first) + app.add_middleware(RecordingMiddleware, parent.second) + + @app.channel("channel") + def handler() -> None: + parent.handler() + + scope: MessageScope = { + "type": "message", + "amgi": {"version": "2.0", "spec_version": "2.0"}, + "address": "channel", + "headers": [], + } + await app(scope, AsyncMock(), AsyncMock()) + + assert parent.mock_calls == [ + call.second.before(), + call.first.before(), + call.handler(), + call.first.after(), + call.second.after(), + ] From 53edd713d4aef9e31e699b01ad723b3b65c2c9e2 Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Fri, 20 Feb 2026 22:15:50 +0000 Subject: [PATCH 6/6] docs(asyncfast): add middleware docs and example --- .../docs/examples/middleware_basic.py | 29 ++++++ packages/asyncfast/docs/index.rst | 1 + packages/asyncfast/docs/middleware.rst | 89 +++++++++++++++++++ 3 files changed, 119 insertions(+) create mode 100644 packages/asyncfast/docs/examples/middleware_basic.py create mode 100644 packages/asyncfast/docs/middleware.rst diff --git a/packages/asyncfast/docs/examples/middleware_basic.py b/packages/asyncfast/docs/examples/middleware_basic.py new file mode 100644 index 0000000..691097b --- /dev/null +++ b/packages/asyncfast/docs/examples/middleware_basic.py @@ -0,0 +1,29 @@ +from time import monotonic + +from amgi_types import AMGIApplication +from amgi_types import AMGIReceiveCallable +from amgi_types import AMGISendCallable +from amgi_types import Scope +from asyncfast import AsyncFast + + +class TimingMiddleware: + def __init__(self, app: AMGIApplication) -> None: + self._app = app + + async def __call__( + self, scope: Scope, receive: AMGIReceiveCallable, send: AMGISendCallable + ) -> None: + start = monotonic() + await self._app(scope, receive, send) + duration_ms = (monotonic() - start) * 1000 + print(f"{scope['type']} handled in {duration_ms:.2f}ms") + + +app = AsyncFast() +app.add_middleware(TimingMiddleware) + + +@app.channel("orders") +async def handle_order(order_id: int) -> None: + print(f"processing order {order_id}") diff --git a/packages/asyncfast/docs/index.rst b/packages/asyncfast/docs/index.rst index b520652..a563637 100644 --- a/packages/asyncfast/docs/index.rst +++ b/packages/asyncfast/docs/index.rst @@ -79,6 +79,7 @@ Taking ideas from: sending dependencies lifespan + middleware .. _amgi: https://amgi.readthedocs.io/en/latest/ diff --git a/packages/asyncfast/docs/middleware.rst b/packages/asyncfast/docs/middleware.rst new file mode 100644 index 0000000..184e34f --- /dev/null +++ b/packages/asyncfast/docs/middleware.rst @@ -0,0 +1,89 @@ +############ + Middleware +############ + +Middleware lets you wrap the AMGI application to run logic before, or after a message is handled. A middleware is a +callable that receives the downstream app, and is itself an AMGI application. + +Middleware can be used for cross-cutting concerns like logging, timing, tracing, metrics, or translating errors. It sees +the full AMGI ``scope``, and the ``receive``/``send`` callables for each event. + +********************** + What Middleware Runs +********************** + +AsyncFast builds a middleware stack that wraps the core app. Each middleware is called for every AMGI event handled by +the app: + +- ``message`` scopes for regular channel handling. +- ``lifespan`` scopes for startup, and shutdown. + +If you need one-time startup, or shutdown logic, prefer the lifespan API. Middleware is for per-event behavior. + +****************** + Basic Middleware +****************** + +Create a class with ``__init__`` to receive the downstream app, and ``__call__`` to handle each event: + +.. async-fast-example:: examples/middleware_basic.py + +The code before ``await self._app(...)`` runs before the handler (and its dependencies). The code after runs after the +handler finishes. This is the standard pattern for timing, logging, or error handling. + +If you add dependencies that use ``yield`` for cleanup, their teardown runs inside the downstream app, so it completes +before the code after ``await self._app(...)``. + +********************************* + Working With Scope And Messages +********************************* + +The middleware callable receives: + +- ``scope``: a dict describing the AMGI event (including ``type``, channel address, headers, and protocol info). +- ``receive``: an async callable that yields inbound events. +- ``send``: an async callable used to emit outbound events. + +Most middleware simply passes these through to the downstream app. If you need to inspect, or transform traffic, you can +wrap ``receive`` or ``send`` before passing them along. When you do, make sure you preserve the expected event flow, and +always ``await`` the downstream app exactly once. + +************************ + Registering Middleware +************************ + +You can register middleware when creating the app: + +.. code:: python + + from asyncfast import AsyncFast + from asyncfast import Middleware + + app = AsyncFast( + middleware=[Middleware(MyMiddleware, "arg1", option=True)], + ) + +Or add it later: + +.. code:: python + + app = AsyncFast() + app.add_middleware(MyMiddleware, "arg1", option=True) + +.. note:: + + The middleware stack is built on first use. Add middleware before the app starts handling messages. + +****************** + Middleware Order +****************** + +Middleware wraps the app in the order it is registered. The last middleware added runs first. + +For example, if you add ``FirstMiddleware``, and then ``SecondMiddleware``, the call order is: + +#. ``SecondMiddleware`` before +#. ``FirstMiddleware`` before +#. router +#. ``FirstMiddleware`` after +#. ``SecondMiddleware`` after