diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 2abe8ed..6dcf7dd 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -29,7 +29,6 @@ jobs: - name: Check commit messages (Commitizen) if: github.event_name == 'pull_request' run: | - set -euo pipefail range="${{ github.event.pull_request.base.sha }}..HEAD" uv run cz check --rev-range "${range}" - name: Run pre-commit diff --git a/packages/amgi-paho-mqtt/tests_amgi_paho_mqtt/test_mqtt_message_integration.py b/packages/amgi-paho-mqtt/tests_amgi_paho_mqtt/test_mqtt_message_integration.py index 953a866..dbac489 100644 --- a/packages/amgi-paho-mqtt/tests_amgi_paho_mqtt/test_mqtt_message_integration.py +++ b/packages/amgi-paho-mqtt/tests_amgi_paho_mqtt/test_mqtt_message_integration.py @@ -25,7 +25,9 @@ def topic() -> str: @pytest.fixture(scope="module") async def mosquitto_container() -> AsyncGenerator[MosquittoContainer, None]: - mosquitto_container = MosquittoContainer().with_volume_mapping( + mosquitto_container = MosquittoContainer( + image="eclipse-mosquitto:2.0.22" + ).with_volume_mapping( Path(__file__).parent / "mqtt.acl", "/mosquitto/config/mqtt.acl", ) diff --git a/packages/amgi-sqs-event-source-mapping/README.md b/packages/amgi-sqs-event-source-mapping/README.md index a6471a5..465703d 100644 --- a/packages/amgi-sqs-event-source-mapping/README.md +++ b/packages/amgi-sqs-event-source-mapping/README.md @@ -16,7 +16,7 @@ This example uses [AsyncFast](https://pypi.org/project/asyncfast/): ```python from dataclasses import dataclass -from amgi_sqs_event_source_mapping import SqsHandler +from amgi_sqs_event_source_mapping import SqsEventSourceMappingHandler from asyncfast import AsyncFast app = AsyncFast() @@ -33,7 +33,7 @@ async def order_queue(order: Order) -> None: ... -handler = SqsHandler(app) +handler = SqsEventSourceMappingHandler(app) ``` ## Contact diff --git a/packages/amgi-sqs-event-source-mapping/src/amgi_sqs_event_source_mapping/__init__.py b/packages/amgi-sqs-event-source-mapping/src/amgi_sqs_event_source_mapping/__init__.py index e4fc146..d9fd2cc 100644 --- a/packages/amgi-sqs-event-source-mapping/src/amgi_sqs_event_source_mapping/__init__.py +++ b/packages/amgi-sqs-event-source-mapping/src/amgi_sqs_event_source_mapping/__init__.py @@ -5,13 +5,18 @@ import re import signal import sys +import warnings from collections import defaultdict from collections import deque +from collections.abc import Awaitable +from collections.abc import Callable from collections.abc import Generator from collections.abc import Iterable from collections.abc import Sequence from functools import cached_property +from types import TracebackType from typing import Any +from typing import AsyncContextManager from typing import Literal from typing import TypedDict @@ -23,11 +28,18 @@ from amgi_types import AMGISendEvent from amgi_types import MessageReceiveEvent from amgi_types import MessageScope +from amgi_types import MessageSendEvent if sys.version_info >= (3, 11): from typing import NotRequired + from typing import Self else: from typing_extensions import NotRequired + from typing_extensions import Self + + +_MessageSendT = Callable[[MessageSendEvent], Awaitable[None]] +_MessageSendManagerT = AsyncContextManager[_MessageSendT] class _AttributeValue(TypedDict): @@ -185,28 +197,72 @@ async def send_message( await self._operation_batcher.enqueue((queue_url, payload, headers)) -class _Send: +class MessageSend: def __init__( self, - queue_url_cache: _QueueUrlCache, - send_batcher: _SendBatcher, - message_ids: set[str], + region_name: str | None = None, + endpoint_url: str | None = None, + aws_access_key_id: str | None = None, + aws_secret_access_key: str | None = None, + ) -> None: + self._region_name = region_name + self._endpoint_url = endpoint_url + self._aws_access_key_id = aws_access_key_id + self._aws_secret_access_key = aws_secret_access_key + self._client_instantiated = False + + async def __aenter__(self) -> Self: + return self + + async def __call__(self, event: MessageSendEvent) -> None: + queue_url = await self._queue_url_cache.get_queue_url(event["address"]) + await self._send_batcher.send_message( + queue_url, event["payload"], event["headers"] + ) + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: + if self._client_instantiated: + self._client.close() + + @cached_property + def _client(self) -> Any: + client = boto3.client( + "sqs", + region_name=self._region_name, + endpoint_url=self._endpoint_url, + aws_access_key_id=self._aws_access_key_id, + aws_secret_access_key=self._aws_secret_access_key, + ) + self._client_instantiated = True + return client + + @cached_property + def _queue_url_cache(self) -> _QueueUrlCache: + return _QueueUrlCache(self._client) + + @cached_property + def _send_batcher(self) -> _SendBatcher: + return _SendBatcher(self._client) + + +class _Send: + def __init__(self, message_ids: set[str], message_send: _MessageSendT) -> None: self.message_ids = message_ids - self._queue_url_cache = queue_url_cache - self._send_batcher = send_batcher + self._message_send = message_send async def __call__(self, event: AMGISendEvent) -> None: if event["type"] == "message.ack": self.message_ids.discard(event["id"]) if event["type"] == "message.send": - queue_url = await self._queue_url_cache.get_queue_url(event["address"]) - await self._send_batcher.send_message( - queue_url, event["payload"], event["headers"] - ) + await self._message_send(event) -class SqsHandler: +class SqsEventSourceMappingHandler: def __init__( self, app: AMGIApplication, @@ -215,43 +271,30 @@ def __init__( aws_access_key_id: str | None = None, aws_secret_access_key: str | None = None, lifespan: bool = True, + message_send: _MessageSendManagerT | None = None, ) -> None: self._app = app self._region_name = region_name self._endpoint_url = endpoint_url self._aws_access_key_id = aws_access_key_id self._aws_secret_access_key = aws_secret_access_key + self._message_send: _MessageSendT | None = None + self._message_send_context = message_send or MessageSend( + region_name=region_name, + endpoint_url=endpoint_url, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + ) self._loop = asyncio.new_event_loop() self._lifespan = lifespan self._lifespan_context: Lifespan | None = None self._state: dict[str, Any] = {} - self._client_instantiated = False try: self._loop.add_signal_handler(signal.SIGTERM, self._sigterm_handler) except NotImplementedError: # Windows / non-main thread: no signal handlers via asyncio pass - @cached_property - def _client(self) -> Any: - client = boto3.client( - "sqs", - region_name=self._region_name, - endpoint_url=self._endpoint_url, - aws_access_key_id=self._aws_access_key_id, - aws_secret_access_key=self._aws_secret_access_key, - ) - self._client_instantiated = True - return client - - @cached_property - def _queue_url_cache(self) -> _QueueUrlCache: - return _QueueUrlCache(self._client) - - @cached_property - def _send_batcher(self) -> _SendBatcher: - return _SendBatcher(self._client) - def __call__( self, event: _SqsEventSourceMapping, context: Any ) -> _BatchItemFailures: @@ -261,6 +304,8 @@ async def _call(self, event: _SqsEventSourceMapping) -> _BatchItemFailures: if not self._lifespan_context and self._lifespan: self._lifespan_context = Lifespan(self._app, self._state) await self._lifespan_context.__aenter__() + if self._message_send is None: + self._message_send = await self._message_send_context.__aenter__() event_source_arn_records = defaultdict(list) corrupted_message_ids = [] for record in event["Records"]: @@ -271,7 +316,7 @@ async def _call(self, event: _SqsEventSourceMapping) -> _BatchItemFailures: unacked_message_ids = await asyncio.gather( *( - self._call_source_batch(event_source_arn, records) + self._call_source_batch(event_source_arn, records, self._message_send) for event_source_arn, records in event_source_arn_records.items() ) ) @@ -286,7 +331,10 @@ async def _call(self, event: _SqsEventSourceMapping) -> _BatchItemFailures: } async def _call_source_batch( - self, event_source_arn: str, records: Iterable[_Record] + self, + event_source_arn: str, + records: Iterable[_Record], + message_send: _MessageSendT, ) -> Iterable[str]: event_source_arn_match = EVENT_SOURCE_ARN_PATTERN.match(event_source_arn) message_ids = {record["messageId"] for record in records} @@ -300,7 +348,7 @@ async def _call_source_batch( "extensions": {"message.ack.out_of_order": {}}, } - records_send = _Send(self._queue_url_cache, self._send_batcher, message_ids) + records_send = _Send(message_ids, message_send) await self._app(scope, _Receive(records), records_send) return records_send.message_ids @@ -308,7 +356,19 @@ def _sigterm_handler(self) -> None: self._loop.run_until_complete(self._shutdown()) async def _shutdown(self) -> None: - if self._client_instantiated: - self._client.close() if self._lifespan_context: await self._lifespan_context.__aexit__(None, None, None) + if self._message_send is not None: + await self._message_send_context.__aexit__(None, None, None) + + +def __getattr__(name: str) -> object: + if name == "SqsHandler": + warnings.warn( + "SqsHandler is deprecated; use SqsEventSourceMappingHandler instead.", + DeprecationWarning, + stacklevel=2, + ) + globals()[name] = SqsEventSourceMappingHandler + return SqsEventSourceMappingHandler + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/packages/amgi-sqs-event-source-mapping/tests_amgi_sqs_event_source_mapping/test_sqs_handler.py b/packages/amgi-sqs-event-source-mapping/tests_amgi_sqs_event_source_mapping/test_sqs_event_source_mapping_handler.py similarity index 91% rename from packages/amgi-sqs-event-source-mapping/tests_amgi_sqs_event_source_mapping/test_sqs_handler.py rename to packages/amgi-sqs-event-source-mapping/tests_amgi_sqs_event_source_mapping/test_sqs_event_source_mapping_handler.py index 958c899..4a2f562 100644 --- a/packages/amgi-sqs-event-source-mapping/tests_amgi_sqs_event_source_mapping/test_sqs_handler.py +++ b/packages/amgi-sqs-event-source-mapping/tests_amgi_sqs_event_source_mapping/test_sqs_event_source_mapping_handler.py @@ -8,9 +8,10 @@ from unittest.mock import patch from uuid import uuid4 +import amgi_sqs_event_source_mapping import boto3 import pytest -from amgi_sqs_event_source_mapping import SqsHandler +from amgi_sqs_event_source_mapping import SqsEventSourceMappingHandler from amgi_types import AMGIReceiveCallable from amgi_types import AMGISendCallable from amgi_types import Scope @@ -24,9 +25,11 @@ def mock_sqs_client() -> Generator[None, None, None]: @pytest.fixture -async def app_sqs_handler() -> AsyncGenerator[tuple[MockApp, SqsHandler], None]: +async def app_sqs_handler() -> ( + AsyncGenerator[tuple[MockApp, SqsEventSourceMappingHandler], None] +): app = MockApp() - sqs_handler = SqsHandler(app) + sqs_handler = SqsEventSourceMappingHandler(app) loop = asyncio.get_event_loop() @@ -44,18 +47,22 @@ async def app_sqs_handler() -> AsyncGenerator[tuple[MockApp, SqsHandler], None]: @pytest.fixture -def app(app_sqs_handler: tuple[MockApp, SqsHandler]) -> MockApp: +def app(app_sqs_handler: tuple[MockApp, SqsEventSourceMappingHandler]) -> MockApp: return app_sqs_handler[0] @pytest.fixture -def sqs_handler(app_sqs_handler: tuple[MockApp, SqsHandler]) -> SqsHandler: +def sqs_event_source_mapping_handler( + app_sqs_handler: tuple[MockApp, SqsEventSourceMappingHandler], +) -> SqsEventSourceMappingHandler: return app_sqs_handler[1] -async def test_sqs_handler_records(app: MockApp, sqs_handler: SqsHandler) -> None: +async def test_sqs_handler_records( + app: MockApp, sqs_event_source_mapping_handler: SqsEventSourceMappingHandler +) -> None: call_task = asyncio.get_running_loop().create_task( - sqs_handler._call( + sqs_event_source_mapping_handler._call( { "Records": [ { @@ -141,9 +148,11 @@ async def test_sqs_handler_records(app: MockApp, sqs_handler: SqsHandler) -> Non assert batch_item_failures == {"batchItemFailures": []} -async def test_sqs_handler_record_nack(app: MockApp, sqs_handler: SqsHandler) -> None: +async def test_sqs_handler_record_nack( + app: MockApp, sqs_event_source_mapping_handler: SqsEventSourceMappingHandler +) -> None: call_task = asyncio.get_running_loop().create_task( - sqs_handler._call( + sqs_event_source_mapping_handler._call( { "Records": [ { @@ -206,10 +215,10 @@ async def test_sqs_handler_record_nack(app: MockApp, sqs_handler: SqsHandler) -> async def test_sqs_handler_record_unacked( - app: MockApp, sqs_handler: SqsHandler + app: MockApp, sqs_event_source_mapping_handler: SqsEventSourceMappingHandler ) -> None: call_task = asyncio.get_running_loop().create_task( - sqs_handler._call( + sqs_event_source_mapping_handler._call( { "Records": [ { @@ -266,10 +275,10 @@ async def test_sqs_handler_record_unacked( async def test_sqs_handler_record_message_attribute_binary_value( app: MockApp, - sqs_handler: SqsHandler, + sqs_event_source_mapping_handler: SqsEventSourceMappingHandler, ) -> None: call_task = asyncio.get_running_loop().create_task( - sqs_handler._call( + sqs_event_source_mapping_handler._call( { "Records": [ { @@ -325,10 +334,10 @@ async def test_sqs_handler_record_message_attribute_binary_value( async def test_sqs_handler_record_corrupted( - app: MockApp, sqs_handler: SqsHandler + app: MockApp, sqs_event_source_mapping_handler: SqsEventSourceMappingHandler ) -> None: call_task = asyncio.get_running_loop().create_task( - sqs_handler._call( + sqs_event_source_mapping_handler._call( { "Records": [ { @@ -369,7 +378,7 @@ async def test_sqs_handler_record_corrupted( async def test_lifespan() -> None: app = MockApp() - sqs_handler = SqsHandler(app) + sqs_handler = SqsEventSourceMappingHandler(app) loop = asyncio.get_event_loop() state_item = uuid4() @@ -455,7 +464,7 @@ async def _app( queue.put(e) raise - sqs_handler = SqsHandler(_app) + sqs_handler = SqsEventSourceMappingHandler(_app) sqs_handler({"Records": []}, Mock()) @@ -467,7 +476,7 @@ async def _app( def test_sqs_handler_app_not_called_if_invalid_arn() -> None: mock_app = AsyncMock() - sqs_handler = SqsHandler(mock_app, lifespan=False) + sqs_handler = SqsEventSourceMappingHandler(mock_app, lifespan=False) sqs_handler( { "Records": [ @@ -500,3 +509,13 @@ def test_sqs_handler_app_not_called_if_invalid_arn() -> None: ) mock_app.assert_not_awaited() + + +def test_sqs_handler_attribute_is_deprecated() -> None: + with pytest.warns(DeprecationWarning, match="SqsEventSourceMappingHandler"): + amgi_sqs_event_source_mapping.SqsHandler + + +def test_unknown_attribute_raises_attribute_error() -> None: + with pytest.raises(AttributeError): + amgi_sqs_event_source_mapping.unknown diff --git a/packages/amgi-sqs-event-source-mapping/tests_amgi_sqs_event_source_mapping/test_sqs_handler_integration.py b/packages/amgi-sqs-event-source-mapping/tests_amgi_sqs_event_source_mapping/test_sqs_event_source_mapping_handler_integration.py similarity index 77% rename from packages/amgi-sqs-event-source-mapping/tests_amgi_sqs_event_source_mapping/test_sqs_handler_integration.py rename to packages/amgi-sqs-event-source-mapping/tests_amgi_sqs_event_source_mapping/test_sqs_event_source_mapping_handler_integration.py index eb9d944..32325f1 100644 --- a/packages/amgi-sqs-event-source-mapping/tests_amgi_sqs_event_source_mapping/test_sqs_handler_integration.py +++ b/packages/amgi-sqs-event-source-mapping/tests_amgi_sqs_event_source_mapping/test_sqs_event_source_mapping_handler_integration.py @@ -3,8 +3,9 @@ from uuid import uuid4 import pytest +from amgi_sqs_event_source_mapping import MessageSend from amgi_sqs_event_source_mapping import SqsBatchFailureError -from amgi_sqs_event_source_mapping import SqsHandler +from amgi_sqs_event_source_mapping import SqsEventSourceMappingHandler from test_utils import MockApp from testcontainers.localstack import LocalStackContainer @@ -22,7 +23,7 @@ async def test_sqs_handler_record_send( localstack_container: LocalStackContainer, ) -> None: app = MockApp() - sqs_handler = SqsHandler( + sqs_event_source_mapping_handler = SqsEventSourceMappingHandler( app, region_name=localstack_container.region_name, endpoint_url=localstack_container.get_url(), @@ -36,7 +37,7 @@ async def test_sqs_handler_record_send( send_queue_url = sqs_client.create_queue(QueueName=send_queue_name)["QueueUrl"] call_task = asyncio.get_running_loop().create_task( - sqs_handler._call( + sqs_event_source_mapping_handler._call( { "Records": [ { @@ -95,7 +96,7 @@ async def test_sqs_handler_record_send_invalid_message( localstack_container: LocalStackContainer, ) -> None: app = MockApp() - sqs_handler = SqsHandler( + sqs_event_source_mapping_handler = SqsEventSourceMappingHandler( app, region_name=localstack_container.region_name, endpoint_url=localstack_container.get_url(), @@ -111,7 +112,7 @@ async def test_sqs_handler_record_send_invalid_message( ) call_task = asyncio.get_running_loop().create_task( - sqs_handler._call( + sqs_event_source_mapping_handler._call( { "Records": [ { @@ -153,3 +154,37 @@ async def test_sqs_handler_record_send_invalid_message( ) await call_task + + +@pytest.mark.integration +async def test_message_send(localstack_container: LocalStackContainer) -> None: + sqs_client = localstack_container.get_client("sqs") + + send_queue_name = f"send-{uuid4()}" + send_queue_url = sqs_client.create_queue(QueueName=send_queue_name)["QueueUrl"] + + async with MessageSend( + region_name=localstack_container.region_name, + endpoint_url=localstack_container.get_url(), + aws_access_key_id="testcontainers-localstack", + aws_secret_access_key="testcontainers-localstack", + ) as message_send: + await message_send( + { + "type": "message.send", + "address": send_queue_name, + "payload": b"test", + "headers": [(b"test", b"test")], + } + ) + + messages_response = sqs_client.receive_message( + QueueUrl=send_queue_url, MessageAttributeNames=["All"] + ) + assert "Messages" in messages_response + assert len(messages_response["Messages"]) == 1 + message = messages_response["Messages"][0] + assert message["Body"] == "test" + assert message["MessageAttributes"] == { + "test": {"StringValue": "test", "DataType": "StringValue"} + }