From e55f97e136439d345f55207c7117b90178ab8b01 Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Sun, 22 Feb 2026 20:19:39 +0000 Subject: [PATCH 1/5] perf(asyncfast): optimize resolver/binding storage --- packages/asyncfast/src/asyncfast/_asyncapi.py | 22 ++++--- packages/asyncfast/src/asyncfast/_channel.py | 32 +++++----- packages/asyncfast/src/asyncfast/_message.py | 60 +++++++++---------- 3 files changed, 55 insertions(+), 59 deletions(-) diff --git a/packages/asyncfast/src/asyncfast/_asyncapi.py b/packages/asyncfast/src/asyncfast/_asyncapi.py index ea13441..4042324 100644 --- a/packages/asyncfast/src/asyncfast/_asyncapi.py +++ b/packages/asyncfast/src/asyncfast/_asyncapi.py @@ -34,8 +34,9 @@ def generate_resolvers( callable_resolver: CallableResolver, ) -> Generator[Resolver[Any], None, None]: - yield from callable_resolver.resolvers.values() - for dependency in callable_resolver.dependencies.values(): + for _, resolver in callable_resolver.resolvers: + yield resolver + for _, dependency in callable_resolver.dependencies: yield from generate_resolvers(dependency) @@ -133,7 +134,7 @@ class ChannelDefinition: def bindings(self) -> Sequence[BindingResolver[Any]]: return [ resolver - for resolver in self.channel.resolvers.values() + for _, resolver in self.channel.resolvers if isinstance(resolver, BindingResolver) ] @@ -200,19 +201,16 @@ def send_message_definitions(self) -> Sequence[MessageDefinition]: MessageDefinition( message.__name__, message.__address__, - {name for name in message.__parameters__}, + {name for name, _ in message.__parameters__}, + [(alias, type_, ...) for _, alias, type_, _ in message.__headers__], [ - (alias, field.type, ...) 
- for name, (alias, field) in message.__headers__.items() - ], - [ - (protocol, field_name, field.type, field.type_adapter.core_schema) - for protocol, field_name, field in message.__bindings__.values() + (protocol, field_name, type_, type_adapter.core_schema) + for _, protocol, field_name, type_, type_adapter in message.__bindings__ ], ( ( - message.__payload__[1].type, - message.__payload__[1].type_adapter.core_schema, + message.__payload__[1], + message.__payload__[2].core_schema, ) if message.__payload__ else None diff --git a/packages/asyncfast/src/asyncfast/_channel.py b/packages/asyncfast/src/asyncfast/_channel.py index 2314d5f..eb43bda 100644 --- a/packages/asyncfast/src/asyncfast/_channel.py +++ b/packages/asyncfast/src/asyncfast/_channel.py @@ -11,6 +11,7 @@ from collections.abc import Callable from collections.abc import Generator from collections.abc import Mapping +from collections.abc import Sequence from contextlib import AbstractAsyncContextManager from contextlib import asynccontextmanager from contextlib import AsyncExitStack @@ -228,8 +229,8 @@ def resolve( @dataclass(frozen=True) class CallableResolver(ABC): func: Callable[..., Any] - resolvers: dict[str, Resolver[Any]] - dependencies: dict[str, DependencyResolver] + resolvers: Sequence[tuple[str, Resolver[Any]]] + dependencies: Sequence[tuple[str, DependencyResolver]] async def resolve( self, @@ -240,18 +241,19 @@ async def resolve( ) -> dict[str, Any]: resolver_result = { name: resolver.resolve(message_receive, send) - for name, resolver in self.resolvers.items() + for name, resolver in self.resolvers } + dependency_names = (name for name, _ in self.dependencies) dependency_result = dict( zip( - self.dependencies.keys(), + dependency_names, await asyncio.gather( *( dependency_cache.resolve( dependency, message_receive, send, async_exit_stack ) - for dependency in self.dependencies.values() + for _, dependency in self.dependencies ) ), ) @@ -557,17 +559,19 @@ def parameter_resolver( def 
resolvers_dependencies( func: Callable[..., Any], address_parameters: set[str] -) -> tuple[dict[str, Resolver[Any]], dict[str, DependencyResolver]]: +) -> tuple[ + Sequence[tuple[str, Resolver[Any]]], Sequence[tuple[str, DependencyResolver]] +]: signature = inspect.signature(func) - resolvers = {} - dependencies = {} + resolvers = [] + dependencies = [] for name, parameter in signature.parameters.items(): resolver = parameter_resolver(name, parameter, address_parameters) if isinstance(resolver, Resolver): - resolvers[name] = resolver + resolvers.append((name, resolver)) else: - dependencies[name] = resolver - return resolvers, dependencies + dependencies.append((name, resolver)) + return tuple(resolvers), tuple(dependencies) def get_channel(func: Callable[..., Any], address: str) -> Channel: @@ -575,14 +579,12 @@ def get_channel(func: Callable[..., Any], address: str) -> Channel: address_pattern = get_address_pattern(address) resolvers, dependencies = resolvers_dependencies(func, address_parameters) - payloads = sum( - isinstance(resolver, PayloadResolver) for resolver in resolvers.values() - ) + payloads = sum(isinstance(resolver, PayloadResolver) for _, resolver in resolvers) if payloads > 1: raise InvalidChannelDefinitionError("Channel must have no more than 1 payload") message_senders = sum( - isinstance(resolver, MessageSenderResolver) for resolver in resolvers.values() + isinstance(resolver, MessageSenderResolver) for _, resolver in resolvers ) if message_senders > 1: raise InvalidChannelDefinitionError( diff --git a/packages/asyncfast/src/asyncfast/_message.py b/packages/asyncfast/src/asyncfast/_message.py index b60615c..e8c043f 100644 --- a/packages/asyncfast/src/asyncfast/_message.py +++ b/packages/asyncfast/src/asyncfast/_message.py @@ -2,6 +2,7 @@ from collections.abc import Iterable from collections.abc import Iterator from collections.abc import Mapping +from collections.abc import Sequence from typing import Annotated from typing import Any from 
typing import ClassVar @@ -16,58 +17,53 @@ from pydantic import TypeAdapter -class _Field: - def __init__(self, type_: type): - self.type = type_ - self.type_adapter = TypeAdapter[Any](type_) - - def __hash__(self) -> int: - return hash(self.type) - - class Message(Mapping[str, Any]): __address__: ClassVar[str | None] = None - __headers__: ClassVar[dict[str, tuple[str, _Field]]] - __parameters__: ClassVar[dict[str, TypeAdapter[Any]]] - __payload__: ClassVar[tuple[str, _Field] | None] - __bindings__: ClassVar[dict[str, tuple[str, str, _Field]]] + __headers__: ClassVar[Sequence[tuple[str, str, type[Any], TypeAdapter[Any]]]] + __parameters__: ClassVar[Sequence[tuple[str, TypeAdapter[Any]]]] + __payload__: ClassVar[tuple[str, type[Any], TypeAdapter[Any]] | None] + __bindings__: ClassVar[Sequence[tuple[str, str, str, type[Any], TypeAdapter[Any]]]] def __init_subclass__(cls, address: str | None = None, **kwargs: Any) -> None: cls.__address__ = address annotations = list(_generate_message_annotations(address, cls.__annotations__)) - headers = { - name: ( + headers = tuple( + ( + name, ( get_args(annotated)[1].alias if get_args(annotated)[1].alias else name.replace("_", "-") ), - _Field(annotated), + annotated, + TypeAdapter(annotated), ) for name, annotated in annotations if isinstance(get_args(annotated)[1], Header) - } + ) - parameters = { - name: TypeAdapter(annotated) + parameters = tuple( + (name, TypeAdapter(annotated)) for name, annotated in annotations if isinstance(get_args(annotated)[1], Parameter) - } + ) - bindings = { - name: ( + bindings = tuple( + ( + name, get_args(annotated)[1].__protocol__, get_args(annotated)[1].__field_name__, - _Field(annotated), + annotated, + TypeAdapter(annotated), ) for name, annotated in annotations if isinstance(get_args(annotated)[1], Binding) - } + ) payloads = [ - (name, _Field(annotated)) + (name, annotated, TypeAdapter(annotated)) for name, annotated in annotations if isinstance(get_args(annotated)[1], Payload) ] @@ -87,8 
+83,8 @@ def __getitem__(self, key: str, /) -> Any: elif key == "headers": return self._get_headers() elif key == "payload" and self.__payload__: - name, field = self.__payload__ - return field.type_adapter.dump_json(getattr(self, name)) + name, _, type_adapter = self.__payload__ + return type_adapter.dump_json(getattr(self, name)) elif key == "bindings" and self.__bindings__: return self._get_bindings() raise KeyError(key) @@ -110,14 +106,14 @@ def _get_address(self) -> str | None: return None parameters = { name: type_adapter.dump_python(getattr(self, name)) - for name, type_adapter in self.__parameters__.items() + for name, type_adapter in self.__parameters__ } return self.__address__.format(**parameters) def _generate_headers(self) -> Iterable[tuple[str, bytes]]: - for name, (alias, field) in self.__headers__.items(): - yield alias, self._get_value(name, field.type_adapter) + for name, alias, _, type_adapter in self.__headers__: + yield alias, self._get_value(name, type_adapter) def _get_headers(self) -> Iterable[tuple[bytes, bytes]]: return [(name.encode(), value) for name, value in self._generate_headers()] @@ -131,9 +127,9 @@ def _get_value(self, name: str, type_adapter: TypeAdapter[Any]) -> bytes: def _get_bindings(self) -> dict[str, dict[str, Any]]: bindings: dict[str, dict[str, Any]] = {} - for name, (protocol, field_name, field) in self.__bindings__.items(): + for name, protocol, field_name, _, type_adapter in self.__bindings__: bindings.setdefault(protocol, {})[field_name] = self._get_value( - name, field.type_adapter + name, type_adapter ) return bindings From 209f6619e0ced96a01bdc3e4870511cc5bfb9b5e Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Mon, 23 Feb 2026 20:02:57 +0000 Subject: [PATCH 2/5] feat(amgi-kafka-event-source-mapping): add invocation hook --- .../__init__.py | 70 +++++++++++++------ ...test_kafka_event_source_mapping_handler.py | 70 +++++++++++++++++++ ...vent_source_mapping_handler_integration.py | 1 + 3 files changed, 118 
insertions(+), 23 deletions(-) diff --git a/packages/amgi-kafka-event-source-mapping/src/amgi_kafka_event_source_mapping/__init__.py b/packages/amgi-kafka-event-source-mapping/src/amgi_kafka_event_source_mapping/__init__.py index 2d616f1..59c926b 100644 --- a/packages/amgi-kafka-event-source-mapping/src/amgi_kafka_event_source_mapping/__init__.py +++ b/packages/amgi-kafka-event-source-mapping/src/amgi_kafka_event_source_mapping/__init__.py @@ -5,16 +5,19 @@ import signal import sys from asyncio import Task +from collections.abc import AsyncGenerator from collections.abc import Awaitable from collections.abc import Callable from collections.abc import Iterable from collections.abc import Sequence +from contextlib import asynccontextmanager from contextlib import AsyncExitStack from dataclasses import dataclass from types import TracebackType from typing import Any from typing import AsyncContextManager from typing import Literal +from typing import Protocol from amgi_aiokafka import MessageSend as AioKafkaMessageSend from amgi_common import Lifespan @@ -71,6 +74,24 @@ class _KafkaEventSourceMapping(TypedDict): records: dict[str, list[_KafkaRecord]] +class _InvocationHook(Protocol): + def __call__( + self, + event: _KafkaEventSourceMapping, + context: Any, + ) -> AsyncContextManager[None]: + """ + Wraps one Lambda invocation + """ + + +@asynccontextmanager +async def _noop_hook( + event: _KafkaEventSourceMapping, context: Any +) -> AsyncGenerator[None, None]: + yield + + class _Send: def __init__(self, record_nack: _RecordNack, message_send: _MessageSendT) -> None: self._message_send = message_send @@ -186,10 +207,12 @@ def __init__( lifespan: bool = True, on_nack: Literal["log", "error"] = "log", message_send: _MessageSendManagerT | None = None, + invocation_hook: _InvocationHook | None = None, ) -> None: self._app = app self._on_nack = on_nack self._lifespan = lifespan + self._invocation_hook = invocation_hook or _noop_hook self._loop = asyncio.new_event_loop() 
self._message_send_manager = ( _MessageSender() @@ -208,32 +231,33 @@ def __init__( pass def __call__(self, event: _KafkaEventSourceMapping, context: Any) -> None: - return self._loop.run_until_complete(self._call(event)) - - async def _call(self, event: _KafkaEventSourceMapping) -> None: - if not self._lifespan_context and self._lifespan: - self._lifespan_context = Lifespan(self._app, self._state) - await self._lifespan_context.__aenter__() - if self._message_sender is None: - self._message_sender = await self._message_send_manager.__aenter__() - - record_nacks = await asyncio.gather( - *( - self._call_source_batch( - event["bootstrapServers"], - _partition_records_topic(records), - records, - self._message_sender, + return self._loop.run_until_complete(self._call(event, context)) + + async def _call(self, event: _KafkaEventSourceMapping, context: Any) -> None: + async with self._invocation_hook(event, context): + if not self._lifespan_context and self._lifespan: + self._lifespan_context = Lifespan(self._app, self._state) + await self._lifespan_context.__aenter__() + if self._message_sender is None: + self._message_sender = await self._message_send_manager.__aenter__() + + record_nacks = await asyncio.gather( + *( + self._call_source_batch( + event["bootstrapServers"], + _partition_records_topic(records), + records, + self._message_sender, + ) + for records in event["records"].values() ) - for records in event["records"].values() ) - ) - all_nacks = tuple(itertools.chain.from_iterable(record_nacks)) - if self._on_nack == "error" and all_nacks: - raise NackError(all_nacks) - for nack in all_nacks: - _logger.error(str(nack)) + all_nacks = tuple(itertools.chain.from_iterable(record_nacks)) + if self._on_nack == "error" and all_nacks: + raise NackError(all_nacks) + for nack in all_nacks: + _logger.error(str(nack)) async def _call_source_batch( self, diff --git 
a/packages/amgi-kafka-event-source-mapping/tests_amgi_kafka_event_source_mapping/test_kafka_event_source_mapping_handler.py b/packages/amgi-kafka-event-source-mapping/tests_amgi_kafka_event_source_mapping/test_kafka_event_source_mapping_handler.py index 5f147da..7042907 100644 --- a/packages/amgi-kafka-event-source-mapping/tests_amgi_kafka_event_source_mapping/test_kafka_event_source_mapping_handler.py +++ b/packages/amgi-kafka-event-source-mapping/tests_amgi_kafka_event_source_mapping/test_kafka_event_source_mapping_handler.py @@ -6,6 +6,7 @@ from uuid import uuid4 import pytest +from amgi_kafka_event_source_mapping import _KafkaEventSourceMapping from amgi_kafka_event_source_mapping import KafkaEventSourceMappingHandler from amgi_kafka_event_source_mapping import NackError from amgi_types import AMGIReceiveCallable @@ -57,6 +58,7 @@ async def test_kafka_event_source_mapping_handler_records() -> None: ] }, }, + Mock(), ) ) async with app.call() as (scope, receive, send): @@ -122,6 +124,7 @@ async def test_kafka_event_source_mapping_handler_error_nack() -> None: ] }, }, + Mock(), ) ) async with app.call() as (scope, receive, send): @@ -191,6 +194,7 @@ async def test_kafka_event_source_mapping_handler_log_nack( ] }, }, + Mock(), ) ) with caplog.at_level(logging.ERROR, logger="amgi-kafka-event-source-mapping.error"): @@ -235,6 +239,7 @@ async def test_lifespan() -> None: "bootstrapServers": "b-2.demo-cluster-1.a1bcde.c1.kafka.us-east-1.amazonaws.com:9092,b-1.demo-cluster-1.a1bcde.c1.kafka.us-east-1.amazonaws.com:9092", "records": {}, }, + Mock(), ) ) async with app.lifespan({"item": state_item}): @@ -277,6 +282,7 @@ async def test_lifespan() -> None: ] }, }, + Mock(), ) ) async with app.call() as (scope, receive, send): @@ -385,6 +391,7 @@ async def test_kafka_event_source_mapping_handler_message_send() -> None: ] }, }, + Mock(), ) ) async with app.call() as (scope, receive, send): @@ -462,6 +469,7 @@ async def test_kafka_event_source_mapping_receive_not_callable() 
-> None: ] }, }, + Mock(), ) ) async with app.call() as (scope, receive, send): @@ -469,3 +477,65 @@ async def test_kafka_event_source_mapping_receive_not_callable() -> None: await receive() await call_task + + +async def test_kafka_event_source_mapping_handler_invocation_hook() -> None: + app = MockApp() + mock_invocation_hook = Mock(return_value=AsyncMock()) + kafka_event_source_mapping_handler = KafkaEventSourceMappingHandler( + app, + lifespan=False, + message_send=AsyncMock(), + invocation_hook=mock_invocation_hook, + ) + + mock_context = Mock() + event: _KafkaEventSourceMapping = { + "eventSource": "aws:kafka", + "eventSourceArn": "arn:aws:kafka:us-east-1:123456789012:cluster/vpc-2priv-2pub/751d2973-a626-431c-9d4e-d7975eb44dd7-2", + "bootstrapServers": "b-2.demo-cluster-1.a1bcde.c1.kafka.us-east-1.amazonaws.com:9092,b-1.demo-cluster-1.a1bcde.c1.kafka.us-east-1.amazonaws.com:9092", + "records": { + "mytopic-0": [ + { + "topic": "mytopic", + "partition": 0, + "offset": 15, + "timestamp": 1545084650987, + "timestampType": "CREATE_TIME", + "key": "a2V5", + "value": "SGVsbG8sIHRoaXMgaXMgYSB0ZXN0Lg==", + "headers": [ + { + "headerKey": [ + 104, + 101, + 97, + 100, + 101, + 114, + 86, + 97, + 108, + 117, + 101, + ] + } + ], + }, + ] + }, + } + call_task = asyncio.get_running_loop().create_task( + kafka_event_source_mapping_handler._call( + event, + mock_context, + ) + ) + async with app.call(): + mock_invocation_hook.assert_called_once_with(event, mock_context) + mock_invocation_hook.return_value.__aenter__.assert_awaited_once() + mock_invocation_hook.return_value.__aexit__.assert_not_awaited() + + await call_task + + mock_invocation_hook.return_value.__aexit__.assert_awaited_once() diff --git a/packages/amgi-kafka-event-source-mapping/tests_amgi_kafka_event_source_mapping/test_kafka_event_source_mapping_handler_integration.py b/packages/amgi-kafka-event-source-mapping/tests_amgi_kafka_event_source_mapping/test_kafka_event_source_mapping_handler_integration.py index 
78ea71b..590d435 100644 --- a/packages/amgi-kafka-event-source-mapping/tests_amgi_kafka_event_source_mapping/test_kafka_event_source_mapping_handler_integration.py +++ b/packages/amgi-kafka-event-source-mapping/tests_amgi_kafka_event_source_mapping/test_kafka_event_source_mapping_handler_integration.py @@ -76,6 +76,7 @@ async def test_kafka_event_source_mapping_handler_message_send( ] }, }, + Mock(), ) ) async with AIOKafkaConsumer( From 81882d5dd0a2f3bd4d16a418950985357d7f7065 Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Mon, 23 Feb 2026 20:03:26 +0000 Subject: [PATCH 3/5] feat(amgi-sqs-event-source-mapping): add invocation hook --- .../amgi_sqs_event_source_mapping/__init__.py | 89 +++++++++++++------ .../test_sqs_event_source_mapping_handler.py | 74 ++++++++++++--- ...vent_source_mapping_handler_integration.py | 3 + 3 files changed, 127 insertions(+), 39 deletions(-) diff --git a/packages/amgi-sqs-event-source-mapping/src/amgi_sqs_event_source_mapping/__init__.py b/packages/amgi-sqs-event-source-mapping/src/amgi_sqs_event_source_mapping/__init__.py index bf6f6f6..b818dde 100644 --- a/packages/amgi-sqs-event-source-mapping/src/amgi_sqs_event_source_mapping/__init__.py +++ b/packages/amgi-sqs-event-source-mapping/src/amgi_sqs_event_source_mapping/__init__.py @@ -7,16 +7,19 @@ import sys import warnings from collections import defaultdict +from collections.abc import AsyncGenerator from collections.abc import Awaitable from collections.abc import Callable from collections.abc import Generator from collections.abc import Iterable from collections.abc import Sequence +from contextlib import asynccontextmanager from functools import cached_property from types import TracebackType from typing import Any from typing import AsyncContextManager from typing import Literal +from typing import Protocol from typing import TypedDict import boto3 @@ -65,6 +68,24 @@ class _SqsEventSourceMapping(TypedDict): Records: list[_Record] +class _InvocationHook(Protocol): + def 
__call__( + self, + event: _SqsEventSourceMapping, + context: Any, + ) -> AsyncContextManager[None]: + """ + Wraps one Lambda invocation + """ + + +@asynccontextmanager +async def _noop_hook( + event: _SqsEventSourceMapping, context: Any +) -> AsyncGenerator[None, None]: + yield + + class _ItemIdentifier(TypedDict): itemIdentifier: str @@ -257,6 +278,7 @@ def __init__( aws_secret_access_key: str | None = None, lifespan: bool = True, message_send: _MessageSendManagerT | None = None, + invocation_hook: _InvocationHook | None = None, ) -> None: self._app = app self._region_name = region_name @@ -270,6 +292,7 @@ def __init__( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, ) + self._invocation_hook = invocation_hook or _noop_hook self._loop = asyncio.new_event_loop() self._lifespan = lifespan self._lifespan_context: Lifespan | None = None @@ -283,37 +306,45 @@ def __init__( def __call__( self, event: _SqsEventSourceMapping, context: Any ) -> _BatchItemFailures: - return self._loop.run_until_complete(self._call(event)) - - async def _call(self, event: _SqsEventSourceMapping) -> _BatchItemFailures: - if not self._lifespan_context and self._lifespan: - self._lifespan_context = Lifespan(self._app, self._state) - await self._lifespan_context.__aenter__() - if self._message_send is None: - self._message_send = await self._message_send_context.__aenter__() - event_source_arn_records = defaultdict(list) - corrupted_message_ids = [] - for record in event["Records"]: - if hashlib.md5(record["body"].encode()).hexdigest() == record["md5OfBody"]: - event_source_arn_records[record["eventSourceARN"]].append(record) - else: - corrupted_message_ids.append(record["messageId"]) - - unacked_message_ids = await asyncio.gather( - *( - self._call_source_batch(event_source_arn, records, self._message_send) - for event_source_arn, records in event_source_arn_records.items() - ) - ) + return self._loop.run_until_complete(self._call(event, context)) - return { - 
"batchItemFailures": [ - {"itemIdentifier": message_id} - for message_id in itertools.chain( - *unacked_message_ids, corrupted_message_ids + async def _call( + self, event: _SqsEventSourceMapping, context: Any + ) -> _BatchItemFailures: + async with self._invocation_hook(event, context): + if not self._lifespan_context and self._lifespan: + self._lifespan_context = Lifespan(self._app, self._state) + await self._lifespan_context.__aenter__() + if self._message_send is None: + self._message_send = await self._message_send_context.__aenter__() + event_source_arn_records = defaultdict(list) + corrupted_message_ids = [] + for record in event["Records"]: + if ( + hashlib.md5(record["body"].encode()).hexdigest() + == record["md5OfBody"] + ): + event_source_arn_records[record["eventSourceARN"]].append(record) + else: + corrupted_message_ids.append(record["messageId"]) + + unacked_message_ids = await asyncio.gather( + *( + self._call_source_batch( + event_source_arn, records, self._message_send + ) + for event_source_arn, records in event_source_arn_records.items() ) - ] - } + ) + + return { + "batchItemFailures": [ + {"itemIdentifier": message_id} + for message_id in itertools.chain( + *unacked_message_ids, corrupted_message_ids + ) + ] + } async def _call_source_batch( self, diff --git a/packages/amgi-sqs-event-source-mapping/tests_amgi_sqs_event_source_mapping/test_sqs_event_source_mapping_handler.py b/packages/amgi-sqs-event-source-mapping/tests_amgi_sqs_event_source_mapping/test_sqs_event_source_mapping_handler.py index c808de3..706b234 100644 --- a/packages/amgi-sqs-event-source-mapping/tests_amgi_sqs_event_source_mapping/test_sqs_event_source_mapping_handler.py +++ b/packages/amgi-sqs-event-source-mapping/tests_amgi_sqs_event_source_mapping/test_sqs_event_source_mapping_handler.py @@ -11,6 +11,7 @@ import amgi_sqs_event_source_mapping import boto3 import pytest +from amgi_sqs_event_source_mapping import _SqsEventSourceMapping from amgi_sqs_event_source_mapping import 
SqsEventSourceMappingHandler from amgi_types import AMGIReceiveCallable from amgi_types import AMGISendCallable @@ -33,11 +34,7 @@ async def app_sqs_handler() -> ( loop = asyncio.get_event_loop() - call_task = loop.create_task( - sqs_handler._call( - {"Records": []}, - ) - ) + call_task = loop.create_task(sqs_handler._call({"Records": []}, Mock())) async with app.lifespan(): yield app, sqs_handler shutdown_task = loop.create_task(sqs_handler._shutdown()) @@ -90,6 +87,7 @@ async def test_sqs_handler_records( } ] }, + Mock(), ) ) async with app.call() as (scope, receive, send): @@ -143,6 +141,7 @@ async def test_sqs_handler_record_nack( } ] }, + Mock(), ) ) async with app.call() as (scope, receive, send): @@ -202,6 +201,7 @@ async def test_sqs_handler_record_unacked( } ] }, + Mock(), ) ) async with app.call() as (scope, receive, send): @@ -253,6 +253,7 @@ async def test_sqs_handler_record_message_attribute_binary_value( } ] }, + Mock(), ) ) async with app.call() as (scope, receive, send): @@ -307,6 +308,7 @@ async def test_sqs_handler_record_corrupted( } ] }, + Mock(), ) ) @@ -325,11 +327,7 @@ async def test_lifespan() -> None: loop = asyncio.get_event_loop() state_item = uuid4() - lifespan_task = loop.create_task( - sqs_handler._call( - {"Records": []}, - ) - ) + lifespan_task = loop.create_task(sqs_handler._call({"Records": []}, Mock())) async with app.lifespan({"item": state_item}): await lifespan_task @@ -362,6 +360,7 @@ async def test_lifespan() -> None: } ] }, + Mock(), ) ) async with app.call() as (scope, receive, send): @@ -496,6 +495,7 @@ async def test_sqs_handler_records_receive_not_callable( } ] }, + Mock(), ) ) async with app.call() as (scope, receive, send): @@ -503,3 +503,57 @@ async def test_sqs_handler_records_receive_not_callable( await receive() await call_task + + +async def test_sqs_event_source_mapping_handler_invocation_hook() -> None: + app = MockApp() + mock_invocation_hook = Mock(return_value=AsyncMock()) + sqs_event_source_mapping_handler = 
SqsEventSourceMappingHandler( + app, + lifespan=False, + message_send=AsyncMock(), + invocation_hook=mock_invocation_hook, + ) + + mock_context = Mock() + event: _SqsEventSourceMapping = { + "Records": [ + { + "messageId": "059f36b4-87a3-44ab-83d2-661975830a7d", + "receiptHandle": "AQEBwJnKyrHigUMZj6rYigCgxlaS3SLy0a...", + "body": "Test message.", + "attributes": { + "ApproximateReceiveCount": "1", + "SentTimestamp": "1545082649183", + "SenderId": "AIDAIENQZJOLO23YVJ4VO", + "ApproximateFirstReceiveTimestamp": "1545082649185", + }, + "messageAttributes": { + "myAttribute": { + "stringValue": "myValue", + "stringListValues": [], + "binaryListValues": [], + "dataType": "String", + } + }, + "md5OfBody": "e4e68fb7bd0e697a0ae8f1bb342846b3", + "eventSource": "aws:sqs", + "eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:my-queue", + "awsRegion": "us-east-2", + } + ] + } + call_task = asyncio.get_running_loop().create_task( + sqs_event_source_mapping_handler._call( + event, + mock_context, + ) + ) + async with app.call(): + mock_invocation_hook.assert_called_once_with(event, mock_context) + mock_invocation_hook.return_value.__aenter__.assert_awaited_once() + mock_invocation_hook.return_value.__aexit__.assert_not_awaited() + + await call_task + + mock_invocation_hook.return_value.__aexit__.assert_awaited_once() diff --git a/packages/amgi-sqs-event-source-mapping/tests_amgi_sqs_event_source_mapping/test_sqs_event_source_mapping_handler_integration.py b/packages/amgi-sqs-event-source-mapping/tests_amgi_sqs_event_source_mapping/test_sqs_event_source_mapping_handler_integration.py index 6401b5c..e4f690f 100644 --- a/packages/amgi-sqs-event-source-mapping/tests_amgi_sqs_event_source_mapping/test_sqs_event_source_mapping_handler_integration.py +++ b/packages/amgi-sqs-event-source-mapping/tests_amgi_sqs_event_source_mapping/test_sqs_event_source_mapping_handler_integration.py @@ -1,5 +1,6 @@ import asyncio from typing import Generator +from unittest.mock import Mock from uuid 
import uuid4 import pytest @@ -65,6 +66,7 @@ async def test_sqs_handler_record_send( } ] }, + Mock(), ) ) async with app.call() as (scope, receive, send): @@ -140,6 +142,7 @@ async def test_sqs_handler_record_send_invalid_message( } ] }, + Mock(), ) ) async with app.call() as (scope, receive, send): From acf2919234d47039eafb0ea3c07ab6112957f421 Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Wed, 25 Feb 2026 22:23:18 +0000 Subject: [PATCH 4/5] feat(amgi-kafka-event-source-mapping): make kafka event source mapping handler init options keyword-only BREAKING CHANGE: `KafkaEventSourceMappingHandler.__init__` options after `app` (`lifespan`, `on_nack`, `message_send`, `invocation_hook`) must now be passed as keyword arguments. --- .../src/amgi_kafka_event_source_mapping/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/amgi-kafka-event-source-mapping/src/amgi_kafka_event_source_mapping/__init__.py b/packages/amgi-kafka-event-source-mapping/src/amgi_kafka_event_source_mapping/__init__.py index 59c926b..31c58c0 100644 --- a/packages/amgi-kafka-event-source-mapping/src/amgi_kafka_event_source_mapping/__init__.py +++ b/packages/amgi-kafka-event-source-mapping/src/amgi_kafka_event_source_mapping/__init__.py @@ -204,6 +204,7 @@ class KafkaEventSourceMappingHandler: def __init__( self, app: AMGIApplication, + *, lifespan: bool = True, on_nack: Literal["log", "error"] = "log", message_send: _MessageSendManagerT | None = None, From ef7c776b08781b5b693a7b7805cf6ed2c87f3bf2 Mon Sep 17 00:00:00 2001 From: "jack.burridge" Date: Wed, 25 Feb 2026 22:21:43 +0000 Subject: [PATCH 5/5] feat(amgi-sqs-event-source-mapping): make sqs event source mapping handler init options keyword-only BREAKING CHANGE: `SqsEventSourceMappingHandler.__init__` options after `app` (`region_name`, `endpoint_url`, `aws_access_key_id`, `aws_secret_access_key`, `lifespan`, `message_send`, `invocation_hook`) must now be passed as keyword arguments. --- .../src/amgi_sqs_event_source_mapping/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/amgi-sqs-event-source-mapping/src/amgi_sqs_event_source_mapping/__init__.py b/packages/amgi-sqs-event-source-mapping/src/amgi_sqs_event_source_mapping/__init__.py index b818dde..7949e21 100644 --- a/packages/amgi-sqs-event-source-mapping/src/amgi_sqs_event_source_mapping/__init__.py +++
b/packages/amgi-sqs-event-source-mapping/src/amgi_sqs_event_source_mapping/__init__.py @@ -272,6 +272,7 @@ class SqsEventSourceMappingHandler: def __init__( self, app: AMGIApplication, + *, region_name: str | None = None, endpoint_url: str | None = None, aws_access_key_id: str | None = None,