diff --git a/packages/asyncfast/src/asyncfast/_asyncfast.py b/packages/asyncfast/src/asyncfast/_asyncfast.py index 2ae1675..3ac2695 100644 --- a/packages/asyncfast/src/asyncfast/_asyncfast.py +++ b/packages/asyncfast/src/asyncfast/_asyncfast.py @@ -16,6 +16,7 @@ from amgi_types import Scope from asyncfast._asyncapi import get_asyncapi from asyncfast._channel import Router +from asyncfast.middleware.errors import ServerErrorMiddleware P = ParamSpec("P") DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any]) @@ -123,7 +124,7 @@ def build_middleware_stack(self) -> AMGIApplication: app = self._app for cls, args, kwargs in self._middleware: app = cls(app, *args, **kwargs) - return app + return ServerErrorMiddleware(app) def add_middleware( self, diff --git a/packages/asyncfast/src/asyncfast/_channel.py b/packages/asyncfast/src/asyncfast/_channel.py index eb43bda..eb3008d 100644 --- a/packages/asyncfast/src/asyncfast/_channel.py +++ b/packages/asyncfast/src/asyncfast/_channel.py @@ -30,8 +30,6 @@ from amgi_types import AMGIReceiveCallable from amgi_types import AMGISendCallable -from amgi_types import MessageAckEvent -from amgi_types import MessageNackEvent from amgi_types import MessageScope from amgi_types import MessageSendEvent from asyncfast._utils import get_address_parameters @@ -617,23 +615,6 @@ def add_channel(self, address: str, func: Callable[..., Any]) -> None: async def __call__( self, scope: MessageScope, receive: AMGIReceiveCallable, send: AMGISendCallable - ) -> None: - try: - await self.call_channel(scope, receive, send) - - message_ack_event: MessageAckEvent = { - "type": "message.ack", - } - await send(message_ack_event) - except Exception as e: - message_nack_event: MessageNackEvent = { - "type": "message.nack", - "message": str(e), - } - await send(message_nack_event) - - async def call_channel( - self, scope: MessageScope, receive: AMGIReceiveCallable, send: AMGISendCallable ) -> None: address = scope["address"] for channel in self.channels: diff --git a/packages/asyncfast/src/asyncfast/middleware/__init__.py b/packages/asyncfast/src/asyncfast/middleware/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/packages/asyncfast/src/asyncfast/middleware/errors.py b/packages/asyncfast/src/asyncfast/middleware/errors.py new file mode 100644 index 0000000..4f96356 --- /dev/null +++ b/packages/asyncfast/src/asyncfast/middleware/errors.py @@ -0,0 +1,32 @@ +from amgi_types import AMGIApplication +from amgi_types import AMGIReceiveCallable +from amgi_types import AMGISendCallable +from amgi_types import MessageAckEvent +from amgi_types import MessageNackEvent +from amgi_types import Scope + + +class ServerErrorMiddleware: + def __init__(self, app: AMGIApplication) -> None: + self.app = app + + async def __call__( + self, scope: Scope, receive: AMGIReceiveCallable, send: AMGISendCallable + ) -> None: + if scope["type"] != "message": + await self.app(scope, receive, send) + return + + try: + await self.app(scope, receive, send) + + message_ack_event: MessageAckEvent = { + "type": "message.ack", + } + await send(message_ack_event) + except Exception as e: + message_nack_event: MessageNackEvent = { + "type": "message.nack", + "message": str(e), + } + await send(message_nack_event)