Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ jobs:
- name: Check commit messages (Commitizen)
if: github.event_name == 'pull_request'
run: |
set -euo pipefail
range="${{ github.event.pull_request.base.sha }}..HEAD"
uv run cz check --rev-range "${range}"
- name: Run pre-commit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ def topic() -> str:

@pytest.fixture(scope="module")
async def mosquitto_container() -> AsyncGenerator[MosquittoContainer, None]:
mosquitto_container = MosquittoContainer().with_volume_mapping(
mosquitto_container = MosquittoContainer(
image="eclipse-mosquitto:2.0.22"
).with_volume_mapping(
Path(__file__).parent / "mqtt.acl",
"/mosquitto/config/mqtt.acl",
)
Expand Down
4 changes: 2 additions & 2 deletions packages/amgi-sqs-event-source-mapping/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ This example uses [AsyncFast](https://pypi.org/project/asyncfast/):
```python
from dataclasses import dataclass

from amgi_sqs_event_source_mapping import SqsHandler
from amgi_sqs_event_source_mapping import SqsEventSourceMappingHandler
from asyncfast import AsyncFast

app = AsyncFast()
Expand All @@ -33,7 +33,7 @@ async def order_queue(order: Order) -> None:
...


handler = SqsHandler(app)
handler = SqsEventSourceMappingHandler(app)
```

## Contact
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@
import re
import signal
import sys
import warnings
from collections import defaultdict
from collections import deque
from collections.abc import Awaitable
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterable
from collections.abc import Sequence
from functools import cached_property
from types import TracebackType
from typing import Any
from typing import AsyncContextManager
from typing import Literal
from typing import TypedDict

Expand All @@ -23,11 +28,18 @@
from amgi_types import AMGISendEvent
from amgi_types import MessageReceiveEvent
from amgi_types import MessageScope
from amgi_types import MessageSendEvent

if sys.version_info >= (3, 11):
from typing import NotRequired
from typing import Self
else:
from typing_extensions import NotRequired
from typing_extensions import Self


_MessageSendT = Callable[[MessageSendEvent], Awaitable[None]]
_MessageSendManagerT = AsyncContextManager[_MessageSendT]


class _AttributeValue(TypedDict):
Expand Down Expand Up @@ -185,28 +197,72 @@ async def send_message(
await self._operation_batcher.enqueue((queue_url, payload, headers))


class _Send:
class MessageSend:
def __init__(
self,
queue_url_cache: _QueueUrlCache,
send_batcher: _SendBatcher,
message_ids: set[str],
region_name: str | None = None,
endpoint_url: str | None = None,
aws_access_key_id: str | None = None,
aws_secret_access_key: str | None = None,
) -> None:
self._region_name = region_name
self._endpoint_url = endpoint_url
self._aws_access_key_id = aws_access_key_id
self._aws_secret_access_key = aws_secret_access_key
self._client_instantiated = False

async def __aenter__(self) -> Self:
return self

async def __call__(self, event: MessageSendEvent) -> None:
queue_url = await self._queue_url_cache.get_queue_url(event["address"])
await self._send_batcher.send_message(
queue_url, event["payload"], event["headers"]
)

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if self._client_instantiated:
self._client.close()

@cached_property
def _client(self) -> Any:
client = boto3.client(
"sqs",
region_name=self._region_name,
endpoint_url=self._endpoint_url,
aws_access_key_id=self._aws_access_key_id,
aws_secret_access_key=self._aws_secret_access_key,
)
self._client_instantiated = True
return client

@cached_property
def _queue_url_cache(self) -> _QueueUrlCache:
return _QueueUrlCache(self._client)

@cached_property
def _send_batcher(self) -> _SendBatcher:
return _SendBatcher(self._client)


class _Send:
def __init__(self, message_ids: set[str], message_send: _MessageSendT) -> None:
self.message_ids = message_ids
self._queue_url_cache = queue_url_cache
self._send_batcher = send_batcher
self._message_send = message_send

async def __call__(self, event: AMGISendEvent) -> None:
if event["type"] == "message.ack":
self.message_ids.discard(event["id"])
if event["type"] == "message.send":
queue_url = await self._queue_url_cache.get_queue_url(event["address"])
await self._send_batcher.send_message(
queue_url, event["payload"], event["headers"]
)
await self._message_send(event)


class SqsHandler:
class SqsEventSourceMappingHandler:
def __init__(
self,
app: AMGIApplication,
Expand All @@ -215,43 +271,30 @@ def __init__(
aws_access_key_id: str | None = None,
aws_secret_access_key: str | None = None,
lifespan: bool = True,
message_send: _MessageSendManagerT | None = None,
) -> None:
self._app = app
self._region_name = region_name
self._endpoint_url = endpoint_url
self._aws_access_key_id = aws_access_key_id
self._aws_secret_access_key = aws_secret_access_key
self._message_send: _MessageSendT | None = None
self._message_send_context = message_send or MessageSend(
region_name=region_name,
endpoint_url=endpoint_url,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)
self._loop = asyncio.new_event_loop()
self._lifespan = lifespan
self._lifespan_context: Lifespan | None = None
self._state: dict[str, Any] = {}
self._client_instantiated = False
try:
self._loop.add_signal_handler(signal.SIGTERM, self._sigterm_handler)
except NotImplementedError:
# Windows / non-main thread: no signal handlers via asyncio
pass

@cached_property
def _client(self) -> Any:
client = boto3.client(
"sqs",
region_name=self._region_name,
endpoint_url=self._endpoint_url,
aws_access_key_id=self._aws_access_key_id,
aws_secret_access_key=self._aws_secret_access_key,
)
self._client_instantiated = True
return client

@cached_property
def _queue_url_cache(self) -> _QueueUrlCache:
return _QueueUrlCache(self._client)

@cached_property
def _send_batcher(self) -> _SendBatcher:
return _SendBatcher(self._client)

def __call__(
self, event: _SqsEventSourceMapping, context: Any
) -> _BatchItemFailures:
Expand All @@ -261,6 +304,8 @@ async def _call(self, event: _SqsEventSourceMapping) -> _BatchItemFailures:
if not self._lifespan_context and self._lifespan:
self._lifespan_context = Lifespan(self._app, self._state)
await self._lifespan_context.__aenter__()
if self._message_send is None:
self._message_send = await self._message_send_context.__aenter__()
event_source_arn_records = defaultdict(list)
corrupted_message_ids = []
for record in event["Records"]:
Expand All @@ -271,7 +316,7 @@ async def _call(self, event: _SqsEventSourceMapping) -> _BatchItemFailures:

unacked_message_ids = await asyncio.gather(
*(
self._call_source_batch(event_source_arn, records)
self._call_source_batch(event_source_arn, records, self._message_send)
for event_source_arn, records in event_source_arn_records.items()
)
)
Expand All @@ -286,7 +331,10 @@ async def _call(self, event: _SqsEventSourceMapping) -> _BatchItemFailures:
}

async def _call_source_batch(
self, event_source_arn: str, records: Iterable[_Record]
self,
event_source_arn: str,
records: Iterable[_Record],
message_send: _MessageSendT,
) -> Iterable[str]:
event_source_arn_match = EVENT_SOURCE_ARN_PATTERN.match(event_source_arn)
message_ids = {record["messageId"] for record in records}
Expand All @@ -300,15 +348,27 @@ async def _call_source_batch(
"extensions": {"message.ack.out_of_order": {}},
}

records_send = _Send(self._queue_url_cache, self._send_batcher, message_ids)
records_send = _Send(message_ids, message_send)
await self._app(scope, _Receive(records), records_send)
return records_send.message_ids

def _sigterm_handler(self) -> None:
self._loop.run_until_complete(self._shutdown())

async def _shutdown(self) -> None:
if self._client_instantiated:
self._client.close()
if self._lifespan_context:
await self._lifespan_context.__aexit__(None, None, None)
if self._message_send is not None:
await self._message_send_context.__aexit__(None, None, None)


def __getattr__(name: str) -> object:
if name == "SqsHandler":
warnings.warn(
"SqsHandler is deprecated; use SqsEventSourceMappingHandler instead.",
DeprecationWarning,
stacklevel=2,
)
globals()[name] = SqsEventSourceMappingHandler
return SqsEventSourceMappingHandler
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
Loading