From 5535b040c1710a5783e386ec0ff23155281c397b Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Sat, 28 Mar 2026 10:58:31 +0000 Subject: [PATCH 1/2] refactor(asyncfast): add routing functionality to its own routing class --- packages/asyncfast/src/asyncfast/__init__.py | 4 +- packages/asyncfast/src/asyncfast/_asyncapi.py | 4 +- .../asyncfast/src/asyncfast/_asyncfast.py | 4 +- packages/asyncfast/src/asyncfast/_channel.py | 57 +++++-------------- packages/asyncfast/src/asyncfast/_utils.py | 35 ++++++++++-- .../asyncfast/tests_asyncfast/test_channel.py | 18 ++++++ 6 files changed, 67 insertions(+), 55 deletions(-) diff --git a/packages/asyncfast/src/asyncfast/__init__.py b/packages/asyncfast/src/asyncfast/__init__.py index 2cf0ead..7ad442c 100644 --- a/packages/asyncfast/src/asyncfast/__init__.py +++ b/packages/asyncfast/src/asyncfast/__init__.py @@ -1,6 +1,5 @@ from asyncfast._asyncfast import AsyncFast from asyncfast._asyncfast import Middleware -from asyncfast._channel import ChannelNotFoundError from asyncfast._channel import Depends from asyncfast._channel import Header from asyncfast._channel import InvalidChannelDefinitionError @@ -8,11 +7,11 @@ from asyncfast._channel import Parameter from asyncfast._channel import Payload from asyncfast._message import Message +from asyncfast._utils import ChannelNotFoundError __all__ = [ "AsyncFast", "Middleware", - "ChannelNotFoundError", "Depends", "Header", "InvalidChannelDefinitionError", @@ -20,4 +19,5 @@ "Parameter", "Payload", "Message", + "ChannelNotFoundError", ] diff --git a/packages/asyncfast/src/asyncfast/_asyncapi.py b/packages/asyncfast/src/asyncfast/_asyncapi.py index 4042324..4119430 100644 --- a/packages/asyncfast/src/asyncfast/_asyncapi.py +++ b/packages/asyncfast/src/asyncfast/_asyncapi.py @@ -15,11 +15,11 @@ from asyncfast._channel import BindingResolver from asyncfast._channel import CallableResolver from asyncfast._channel import Channel +from asyncfast._channel import ChannelRouter from asyncfast._channel import HeaderResolver from asyncfast._channel import MessageSenderResolver from asyncfast._channel import PayloadResolver from asyncfast._channel import Resolver -from asyncfast._channel import Router from asyncfast._message import Message from pydantic import BaseModel from pydantic import create_model @@ -319,7 +319,7 @@ def get_asyncapi( *, title: str, version: str, - router: Router, + router: ChannelRouter, ) -> dict[str, Any]: channel_definitions = tuple( ChannelDefinition(channel) for channel in router.channels diff --git a/packages/asyncfast/src/asyncfast/_asyncfast.py b/packages/asyncfast/src/asyncfast/_asyncfast.py index 3ac2695..2b3aa2c 100644 --- a/packages/asyncfast/src/asyncfast/_asyncfast.py +++ b/packages/asyncfast/src/asyncfast/_asyncfast.py @@ -15,7 +15,7 @@ from amgi_types import LifespanStartupCompleteEvent from amgi_types import Scope from asyncfast._asyncapi import get_asyncapi -from asyncfast._channel import Router +from asyncfast._channel import ChannelRouter from asyncfast.middleware.errors import ServerErrorMiddleware P = ParamSpec("P") @@ -54,7 +54,7 @@ def __init__( self._version = version self._lifespan_context = lifespan self._middleware = list(middleware) if middleware else [] - self._router = Router() + self._router = ChannelRouter() self._lifespan: AbstractAsyncContextManager[None] | None = None self._asyncapi_schema: dict[str, Any] | None = None self._middleware_stack: AMGIApplication | None = None diff --git a/packages/asyncfast/src/asyncfast/_channel.py b/packages/asyncfast/src/asyncfast/_channel.py index eb3008d..c75a382 100644 --- a/packages/asyncfast/src/asyncfast/_channel.py +++ b/packages/asyncfast/src/asyncfast/_channel.py @@ -19,7 +19,6 @@ from dataclasses import KW_ONLY from functools import cached_property from functools import wraps -from re import Pattern from typing import Annotated from typing import Any from typing import Generic @@ -33,7 +32,7 @@ from amgi_types import MessageScope from amgi_types import MessageSendEvent from asyncfast._utils import get_address_parameters -from asyncfast._utils import get_address_pattern +from asyncfast._utils import Router from asyncfast.bindings import Binding from pydantic import TypeAdapter from pydantic.fields import FieldInfo @@ -98,16 +97,6 @@ class InvalidChannelDefinitionError(ValueError): """ -class RouteInvariantError(RuntimeError): - """Raised when a selected route fails to match its address.""" - - -class ChannelNotFoundError(LookupError): - def __init__(self, address: str) -> None: - super().__init__(f"Couldn't resolve address: {address}") - self.address = address - - class Header(FieldInfo): pass @@ -379,7 +368,6 @@ def resolve( @dataclass(frozen=True) class Channel(CallableResolver, ABC): address: str - address_pattern: Pattern[str] parameters: set[str] async def __call__( @@ -387,24 +375,13 @@ async def __call__( scope: MessageScope, receive: AMGIReceiveCallable, send: AMGISendCallable, - parameters: dict[str, str] | None = None, + parameters: dict[str, str], ) -> None: - parameters = self.match(scope["address"]) if parameters is None else parameters - if parameters is None: - raise RouteInvariantError( - f"Selected route did not match address {scope['address']!r}" - ) message_receive = MessageReceive(scope, parameters) dependency_cache = DependencyCache(asyncio.get_event_loop()) async with AsyncExitStack() as async_exit_stack: await self.call(message_receive, send, dependency_cache, async_exit_stack) - def match(self, address: str) -> dict[str, str] | None: - match = self.address_pattern.match(address) - if match: - return match.groupdict() - return None - @dataclass(frozen=True) class SyncChannel(Channel): @@ -574,7 +551,6 @@ def resolvers_dependencies( def get_channel(func: Callable[..., Any], address: str) -> Channel: address_parameters = get_address_parameters(address) - address_pattern = get_address_pattern(address) resolvers, dependencies = resolvers_dependencies(func, address_parameters) payloads = sum(isinstance(resolver, PayloadResolver) for _, resolver in resolvers) @@ -590,36 +566,29 @@ def get_channel(func: Callable[..., Any], address: str) -> Channel: ) if inspect.iscoroutinefunction(func): - return AsyncChannel( - func, resolvers, dependencies, address, address_pattern, address_parameters - ) + return AsyncChannel(func, resolvers, dependencies, address, address_parameters) if inspect.isasyncgenfunction(func): return AsyncGeneratorChannel( - func, resolvers, dependencies, address, address_pattern, address_parameters + func, resolvers, dependencies, address, address_parameters ) if inspect.isgeneratorfunction(func): return SyncGeneratorChannel( - func, resolvers, dependencies, address, address_pattern, address_parameters + func, resolvers, dependencies, address, address_parameters ) - return SyncChannel( - func, resolvers, dependencies, address, address_pattern, address_parameters - ) + return SyncChannel(func, resolvers, dependencies, address, address_parameters) -class Router: - def __init__(self) -> None: - self.channels: list[Channel] = [] +class ChannelRouter(Router[Channel]): + @property + def channels(self) -> Sequence[Channel]: + return [channel for _, channel in self.routes] def add_channel(self, address: str, func: Callable[..., Any]) -> None: - self.channels.append(get_channel(func, address)) + self.add_route(address, get_channel(func, address)) async def __call__( self, scope: MessageScope, receive: AMGIReceiveCallable, send: AMGISendCallable ) -> None: address = scope["address"] - for channel in self.channels: - parameters = channel.match(address) - if parameters is not None: - await channel(scope, receive, send, parameters) - return - raise ChannelNotFoundError(address) + parameters, channel = self.get(address) + await channel(scope, receive, send, parameters) diff --git a/packages/asyncfast/src/asyncfast/_utils.py b/packages/asyncfast/src/asyncfast/_utils.py index 6831237..4f04e66 100644 --- a/packages/asyncfast/src/asyncfast/_utils.py +++ b/packages/asyncfast/src/asyncfast/_utils.py @@ -1,17 +1,21 @@ import re from collections import Counter from re import Pattern +from typing import Generic +from typing import TypeVar -_FIELD_PATTERN = re.compile(r"^[A-Za-z0-9_\-]+$") -_PARAMETER_PATTERN = re.compile(r"{(.*)}") +T = TypeVar("T") + +FIELD_PATTERN = re.compile(r"^[A-Za-z0-9_\-]+$") +PARAMETER_PATTERN = re.compile(r"{(.*)}") def get_address_parameters(address: str | None) -> set[str]: if address is None: return set() - parameters = _PARAMETER_PATTERN.findall(address) + parameters = PARAMETER_PATTERN.findall(address) for parameter in parameters: - assert _FIELD_PATTERN.match(parameter), f"Parameter '{parameter}' is not valid" + assert FIELD_PATTERN.match(parameter), f"Parameter '{parameter}' is not valid" duplicates = {item for item, count in Counter(parameters).items() if count > 1} assert len(duplicates) == 0, f"Address contains duplicate parameters: {duplicates}" @@ -21,7 +25,7 @@ def get_address_parameters(address: str | None) -> set[str]: def get_address_pattern(address: str) -> Pattern[str]: index = 0 address_regex = "^" - for match in _PARAMETER_PATTERN.finditer(address): + for match in PARAMETER_PATTERN.finditer(address): (name,) = match.groups() address_regex += re.escape(address[index : match.start()]) address_regex += f"(?P<{name}>.*)" @@ -30,3 +34,24 @@ def get_address_pattern(address: str) -> Pattern[str]: address_regex += re.escape(address[index:]) + "$" return re.compile(address_regex) + + +class ChannelNotFoundError(LookupError): + def __init__(self, address: str) -> None: + super().__init__(f"Couldn't resolve address: {address}") + self.address = address + + +class Router(Generic[T]): + def __init__(self) -> None: + self.routes: list[tuple[Pattern[str], T]] = [] + + def add_route(self, address: str, route: T) -> None: + self.routes.append((get_address_pattern(address), route)) + + def get(self, address: str) -> tuple[dict[str, str], T]: + for address_pattern, route in self.routes: + parameters = address_pattern.match(address) + if parameters is not None: + return parameters.groupdict(), route + raise ChannelNotFoundError(address) diff --git a/packages/asyncfast/tests_asyncfast/test_channel.py b/packages/asyncfast/tests_asyncfast/test_channel.py index 37cd93c..baff6d9 100644 --- a/packages/asyncfast/tests_asyncfast/test_channel.py +++ b/packages/asyncfast/tests_asyncfast/test_channel.py @@ -30,6 +30,7 @@ def func(i: int) -> None: }, Mock(), Mock(), + {}, ) mock.assert_called_with(1) @@ -50,6 +51,7 @@ def func(header: Annotated[str, Header()]) -> None: }, Mock(), Mock(), + {}, ) mock.assert_called_with("value") @@ -70,6 +72,7 @@ def func(header: Annotated[str, Header()] = "value") -> None: }, Mock(), Mock(), + {}, ) mock.assert_called_with("value") @@ -90,6 +93,7 @@ def func(header_name: Annotated[str, Header()]) -> None: }, Mock(), Mock(), + {}, ) mock.assert_called_with("value") @@ -110,6 +114,7 @@ def func(etag: Annotated[str, Header(alias="ETag")]) -> None: }, Mock(), Mock(), + {}, ) mock.assert_called_with("9e30981e-02d5-11f1-9648-e323315723e1") @@ -130,6 +135,7 @@ def func(user: str) -> None: }, Mock(), Mock(), + {"user": "54a08cc6-02db-11f1-afbf-f3f4688d5de4"}, ) mock.assert_called_with("54a08cc6-02db-11f1-afbf-f3f4688d5de4") @@ -151,6 +157,7 @@ def func(key: Annotated[int, KafkaKey()]) -> None: }, Mock(), Mock(), + {}, ) mock.assert_called_with(123) @@ -171,6 +178,7 @@ def func(key: Annotated[int, KafkaKey()] = 123) -> None: }, Mock(), Mock(), + {}, ) mock.assert_called_with(123) @@ -192,6 +200,7 @@ async def func(i: int) -> None: }, Mock(), Mock(), + {}, ) mock.assert_awaited_once_with(1) @@ -217,6 +226,7 @@ async def func() -> AsyncGenerator[Mapping[str, Any], None]: }, Mock(), send_mock, + {}, ) send_mock.assert_awaited_once_with( @@ -249,6 +259,7 @@ def func() -> Generator[Mapping[str, Any], None, None]: }, Mock(), send_mock, + {}, ) send_mock.assert_awaited_once_with( @@ -283,6 +294,7 @@ async def func(message_sender: MessageSender[Mapping[str, Any]]) -> None: }, Mock(), send_mock, + {}, ) send_mock.assert_awaited_once_with( @@ -318,6 +330,7 @@ def func(headers: Annotated[dict[str, int], Depends(dependency)]) -> None: }, Mock(), Mock(), + {}, ) mock.assert_called_with({"header1": 1, "header2": 2}) @@ -346,6 +359,7 @@ def func(headers: Annotated[dict[str, int], Depends(dependency)]) -> None: }, Mock(), Mock(), + {}, ) mock.assert_called_with({"header1": 1, "header2": 2}) @@ -373,6 +387,7 @@ def func( }, Mock(), Mock(), + {}, ) mock_func.assert_called_with( @@ -403,6 +418,7 @@ def func( }, Mock(), Mock(), + {}, ) mock_func.assert_called_with( @@ -441,6 +457,7 @@ def func(headers: Annotated[dict[str, int], Depends(dependency)]) -> None: }, Mock(), Mock(), + {}, ) assert parent.mock_calls == [ @@ -479,6 +496,7 @@ def func(headers: Annotated[dict[str, int], Depends(dependency)]) -> None: }, Mock(), Mock(), + {}, ) assert parent.mock_calls == [ From af457ca127f8a79f41a82a06e71cc9f999c93e9d Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Thu, 16 Apr 2026 19:03:25 +0100 Subject: [PATCH 2/2] feat(asyncfast): add message send router --- docs-requirements.txt | 4 +- .../docs/examples/message_send_router.py | 69 +++++++++++++++++++ packages/asyncfast/docs/index.rst | 1 + .../asyncfast/docs/message_send_router.rst | 64 +++++++++++++++++ packages/asyncfast/src/asyncfast/_utils.py | 15 +++- .../asyncfast/src/asyncfast/message_send.py | 62 +++++++++++++++++ .../tests_asyncfast/test_message_send.py | 59 ++++++++++++++++ 7 files changed, 272 insertions(+), 2 deletions(-) create mode 100644 packages/asyncfast/docs/examples/message_send_router.py create mode 100644 packages/asyncfast/docs/message_send_router.rst create mode 100644 packages/asyncfast/src/asyncfast/message_send.py create mode 100644 packages/asyncfast/tests_asyncfast/test_message_send.py diff --git a/docs-requirements.txt b/docs-requirements.txt index 92b2b41..2c5f99d 100644 --- a/docs-requirements.txt +++ b/docs-requirements.txt @@ -5,4 +5,6 @@ furo>=2025.7.19 sphinx-copybutton>=0.5.2 redis>=7.0.1 packages/asyncfast -packages/amgi-types \ No newline at end of file +packages/amgi-types +packages/amgi-aiobotocore +packages/amgi-aiokafka \ No newline at end of file diff --git a/packages/asyncfast/docs/examples/message_send_router.py b/packages/asyncfast/docs/examples/message_send_router.py new file mode 100644 index 0000000..bd7369c --- /dev/null +++ b/packages/asyncfast/docs/examples/message_send_router.py @@ -0,0 +1,69 @@ +import os +from dataclasses import dataclass +from typing import Annotated + +from amgi_aiobotocore.sqs import MessageSend as SQSMessageSend +from amgi_aiokafka import MessageSend as KafkaMessageSend +from amgi_aiokafka import run +from asyncfast import AsyncFast +from asyncfast import Header +from asyncfast import Message +from asyncfast import MessageSender +from asyncfast.message_send import MessageSendRouter + +BOOTSTRAP_SERVERS = os.getenv("BOOTSTRAP_SERVERS", "localhost:9092") + +app = AsyncFast() + + +@dataclass +class Item: + sku_id: str + amount: int + + +@dataclass +class Order: + items: list[Item] + status: str + + +@dataclass +class EmailProcessingOrder(Message, address="order-email-processing-order"): + order_id: Annotated[str, Header()] + order: Order + + +@dataclass +class CancelShipping(Message, address="cancel-shipping"): + order_id: Annotated[str, Header()] + + +@app.channel("orders") +async def handle_order( + order: Order, + order_id: Annotated[str, Header()], + message_sender: MessageSender[EmailProcessingOrder | CancelShipping], +) -> None: + if order.status == "processing": + await message_sender.send(EmailProcessingOrder(order_id=order_id, order=order)) + + if order.status == "cancelled": + await message_sender.send(CancelShipping(order_id=order_id)) + + +message_send_router = MessageSendRouter() + +message_send_router.add_route( + "order-email-processing-order", + KafkaMessageSend(bootstrap_servers=BOOTSTRAP_SERVERS), +) +message_send_router.add_route("cancel-shipping", SQSMessageSend()) + +if __name__ == "__main__": + run( + app, + "orders", + bootstrap_servers=BOOTSTRAP_SERVERS, + message_send=message_send_router, + ) diff --git a/packages/asyncfast/docs/index.rst b/packages/asyncfast/docs/index.rst index a563637..77b8604 100644 --- a/packages/asyncfast/docs/index.rst +++ b/packages/asyncfast/docs/index.rst @@ -77,6 +77,7 @@ Taking ideas from: receiving sending + message_send_router dependencies lifespan middleware diff --git a/packages/asyncfast/docs/message_send_router.rst b/packages/asyncfast/docs/message_send_router.rst new file mode 100644 index 0000000..2062ea8 --- /dev/null +++ b/packages/asyncfast/docs/message_send_router.rst @@ -0,0 +1,64 @@ +##################### + Message Send Router +##################### + +``MessageSendRouter`` helps AMGI servers route ``message.send`` events to different backends based on the event address. +It manages setup and teardown of the underlying send callables using async context managers. This is useful when you +need to send different messages to different brokers, or you want to keep a single server running while routing +outbound traffic by address. + +************* + Basic Usage +************* + +The router is an async context manager. Create it, register routes, and then enter it when the server starts so the +message senders can establish connections: + +.. async-fast-example:: examples/message_send_router.py + +In the example above, one address is routed to Kafka, and another to SQS. The app itself stays the same; routing is +configured at the server boundary. + +*********************** + Integrating With Run +*********************** + +AMGI servers expect a ``send`` callable. ``MessageSendRouter`` provides that callable when you enter it, so pass the +router instance to the server and let it manage resource lifetimes: + +.. code:: python + + from amgi_aiokafka import run + from asyncfast.message_send import MessageSendRouter + + message_send_router = MessageSendRouter() + # add routes... + + run( + app, + "orders", + message_send=message_send_router, + ) + +When the server starts, it enters the router and uses the callable it yields. When the server shuts down, it exits the +router and closes all message senders cleanly. + +******************* + Address Patterns +******************* + +Routes use the same pattern syntax as channel parameters, so ``priority.{id}`` will match addresses like +``priority.123``. Register routes before entering the router so they are included in its setup. + +If you need a catch-all pattern, register a default route instead of a broad pattern; this keeps route matching +explicit and easier to reason about. + +**************** + Default Route +**************** + +If you pass ``default=``, the router will use that send callable when no route matches. Without a default, you should +ensure every outgoing address has a route registered, otherwise the send will fail at runtime. + +The default sender should be an async context manager just like the routed senders. This allows you to share connection +pools or client lifecycles with explicit cleanup on shutdown. diff --git a/packages/asyncfast/src/asyncfast/_utils.py b/packages/asyncfast/src/asyncfast/_utils.py index 4f04e66..4b9ef24 100644 --- a/packages/asyncfast/src/asyncfast/_utils.py +++ b/packages/asyncfast/src/asyncfast/_utils.py @@ -1,6 +1,7 @@ import re from collections import Counter from re import Pattern +from typing import Final from typing import Generic from typing import TypeVar @@ -42,16 +43,28 @@ def __init__(self, address: str) -> None: self.address = address +class MissingType: + pass + + class Router(Generic[T]): + MISSING: Final = MissingType() + def __init__(self) -> None: self.routes: list[tuple[Pattern[str], T]] = [] def add_route(self, address: str, route: T) -> None: self.routes.append((get_address_pattern(address), route)) - def get(self, address: str) -> tuple[dict[str, str], T]: + def get( + self, address: str, default: T | MissingType = MISSING + ) -> tuple[dict[str, str], T]: for address_pattern, route in self.routes: parameters = address_pattern.match(address) if parameters is not None: return parameters.groupdict(), route + + if not isinstance(default, MissingType): + return {}, default + raise ChannelNotFoundError(address) diff --git a/packages/asyncfast/src/asyncfast/message_send.py b/packages/asyncfast/src/asyncfast/message_send.py new file mode 100644 index 0000000..84b4e03 --- /dev/null +++ b/packages/asyncfast/src/asyncfast/message_send.py @@ -0,0 +1,62 @@ +import sys +from collections.abc import Awaitable +from collections.abc import Callable +from contextlib import AsyncExitStack +from types import TracebackType +from typing import AsyncContextManager + +from amgi_types import MessageSendEvent +from asyncfast._utils import Router + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + + +_MessageSendT = Callable[[MessageSendEvent], Awaitable[None]] +_MessageSendManagerT = AsyncContextManager[_MessageSendT] + + +class MessageSendRouter: + def __init__(self, *, default: _MessageSendManagerT | None = None) -> None: + self._default_manager = default + self._route_managers: list[tuple[str, _MessageSendManagerT]] = [] + + self._default_send: _MessageSendT | None = None + self._route_sends = Router[_MessageSendT]() + + self._exit_stack = AsyncExitStack() + + async def __aenter__(self) -> Self: + if self._default_manager is not None: + self._default_send = await self._exit_stack.enter_async_context( + self._default_manager, + ) + + for address, message_send_manager in self._route_managers: + message_send = await self._exit_stack.enter_async_context( + message_send_manager + ) + self._route_sends.add_route(address, message_send) + + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + await self._exit_stack.aclose() + + def add_route(self, address: str, message_send: _MessageSendManagerT) -> None: + self._route_managers.append((address, message_send)) + + async def __call__(self, event: MessageSendEvent) -> None: + _, message_send = self._route_sends.get( + event["address"], + default=self._default_send or Router.MISSING, + ) + + await message_send(event) diff --git a/packages/asyncfast/tests_asyncfast/test_message_send.py b/packages/asyncfast/tests_asyncfast/test_message_send.py new file mode 100644 index 0000000..c15fd66 --- /dev/null +++ b/packages/asyncfast/tests_asyncfast/test_message_send.py @@ -0,0 +1,59 @@ +from unittest.mock import AsyncMock + +from asyncfast.message_send import MessageSendRouter + + +async def test_message_send_router_default() -> None: + mock_default_message_send_manager = AsyncMock() + mock_default_message_send = ( + mock_default_message_send_manager.__aenter__.return_value + ) + + message_send_router = MessageSendRouter(default=mock_default_message_send_manager) + + async with message_send_router as message_send: + await message_send( + { + "type": "message.send", + "address": "address", + "headers": [], + "payload": b"test", + } + ) + + mock_default_message_send.assert_awaited_once_with( + { + "type": "message.send", + "address": "address", + "headers": [], + "payload": b"test", + } + ) + + +async def test_message_send_router() -> None: + mock_route_message_send_manager = AsyncMock() + mock_route_message_send = mock_route_message_send_manager.__aenter__.return_value + + message_send_router = MessageSendRouter() + + message_send_router.add_route("channel.{id}", mock_route_message_send_manager) + + async with message_send_router as message_send: + await message_send( + { + "type": "message.send", + "address": "channel.de320b42-2a98-11f1-badd-db379e7beed5", + "headers": [], + "payload": b"test", + } + ) + + mock_route_message_send.assert_awaited_once_with( + { + "type": "message.send", + "address": "channel.de320b42-2a98-11f1-badd-db379e7beed5", + "headers": [], + "payload": b"test", + } + )