Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,19 @@
import signal
import sys
from asyncio import Task
from collections.abc import AsyncGenerator
from collections.abc import Awaitable
from collections.abc import Callable
from collections.abc import Iterable
from collections.abc import Sequence
from contextlib import asynccontextmanager
from contextlib import AsyncExitStack
from dataclasses import dataclass
from types import TracebackType
from typing import Any
from typing import AsyncContextManager
from typing import Literal
from typing import Protocol

from amgi_aiokafka import MessageSend as AioKafkaMessageSend
from amgi_common import Lifespan
Expand Down Expand Up @@ -71,6 +74,24 @@ class _KafkaEventSourceMapping(TypedDict):
records: dict[str, list[_KafkaRecord]]


class _InvocationHook(Protocol):
    """Structural type for a hook that wraps a single Lambda invocation.

    Implementations receive the raw Kafka event-source-mapping payload and
    the Lambda ``context`` object, and return an async context manager that
    is entered before the event is processed and exited afterwards.
    """

    def __call__(
        self,
        event: _KafkaEventSourceMapping,
        context: Any,
    ) -> AsyncContextManager[None]:
        """
        Wraps one Lambda invocation
        """


@asynccontextmanager
async def _noop_hook(
    event: _KafkaEventSourceMapping, context: Any
) -> AsyncGenerator[None, None]:
    """Default invocation hook: does nothing before or after the invocation."""
    yield


class _Send:
def __init__(self, record_nack: _RecordNack, message_send: _MessageSendT) -> None:
self._message_send = message_send
Expand Down Expand Up @@ -183,13 +204,16 @@ class KafkaEventSourceMappingHandler:
def __init__(
self,
app: AMGIApplication,
*,
lifespan: bool = True,
on_nack: Literal["log", "error"] = "log",
message_send: _MessageSendManagerT | None = None,
invocation_hook: _InvocationHook | None = None,
) -> None:
self._app = app
self._on_nack = on_nack
self._lifespan = lifespan
self._invocation_hook = invocation_hook or _noop_hook
self._loop = asyncio.new_event_loop()
self._message_send_manager = (
_MessageSender()
Expand All @@ -208,32 +232,33 @@ def __init__(
pass

def __call__(self, event: _KafkaEventSourceMapping, context: Any) -> None:
return self._loop.run_until_complete(self._call(event))

async def _call(self, event: _KafkaEventSourceMapping) -> None:
if not self._lifespan_context and self._lifespan:
self._lifespan_context = Lifespan(self._app, self._state)
await self._lifespan_context.__aenter__()
if self._message_sender is None:
self._message_sender = await self._message_send_manager.__aenter__()

record_nacks = await asyncio.gather(
*(
self._call_source_batch(
event["bootstrapServers"],
_partition_records_topic(records),
records,
self._message_sender,
return self._loop.run_until_complete(self._call(event, context))

async def _call(self, event: _KafkaEventSourceMapping, context: Any) -> None:
async with self._invocation_hook(event, context):
if not self._lifespan_context and self._lifespan:
self._lifespan_context = Lifespan(self._app, self._state)
await self._lifespan_context.__aenter__()
if self._message_sender is None:
self._message_sender = await self._message_send_manager.__aenter__()

record_nacks = await asyncio.gather(
*(
self._call_source_batch(
event["bootstrapServers"],
_partition_records_topic(records),
records,
self._message_sender,
)
for records in event["records"].values()
)
for records in event["records"].values()
)
)

all_nacks = tuple(itertools.chain.from_iterable(record_nacks))
if self._on_nack == "error" and all_nacks:
raise NackError(all_nacks)
for nack in all_nacks:
_logger.error(str(nack))
all_nacks = tuple(itertools.chain.from_iterable(record_nacks))
if self._on_nack == "error" and all_nacks:
raise NackError(all_nacks)
for nack in all_nacks:
_logger.error(str(nack))

async def _call_source_batch(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from uuid import uuid4

import pytest
from amgi_kafka_event_source_mapping import _KafkaEventSourceMapping
from amgi_kafka_event_source_mapping import KafkaEventSourceMappingHandler
from amgi_kafka_event_source_mapping import NackError
from amgi_types import AMGIReceiveCallable
Expand Down Expand Up @@ -57,6 +58,7 @@ async def test_kafka_event_source_mapping_handler_records() -> None:
]
},
},
Mock(),
)
)
async with app.call() as (scope, receive, send):
Expand Down Expand Up @@ -122,6 +124,7 @@ async def test_kafka_event_source_mapping_handler_error_nack() -> None:
]
},
},
Mock(),
)
)
async with app.call() as (scope, receive, send):
Expand Down Expand Up @@ -191,6 +194,7 @@ async def test_kafka_event_source_mapping_handler_log_nack(
]
},
},
Mock(),
)
)
with caplog.at_level(logging.ERROR, logger="amgi-kafka-event-source-mapping.error"):
Expand Down Expand Up @@ -235,6 +239,7 @@ async def test_lifespan() -> None:
"bootstrapServers": "b-2.demo-cluster-1.a1bcde.c1.kafka.us-east-1.amazonaws.com:9092,b-1.demo-cluster-1.a1bcde.c1.kafka.us-east-1.amazonaws.com:9092",
"records": {},
},
Mock(),
)
)
async with app.lifespan({"item": state_item}):
Expand Down Expand Up @@ -277,6 +282,7 @@ async def test_lifespan() -> None:
]
},
},
Mock(),
)
)
async with app.call() as (scope, receive, send):
Expand Down Expand Up @@ -385,6 +391,7 @@ async def test_kafka_event_source_mapping_handler_message_send() -> None:
]
},
},
Mock(),
)
)
async with app.call() as (scope, receive, send):
Expand Down Expand Up @@ -462,10 +469,73 @@ async def test_kafka_event_source_mapping_receive_not_callable() -> None:
]
},
},
Mock(),
)
)
async with app.call() as (scope, receive, send):
with pytest.raises(RuntimeError, match="Receive should not be called"):
await receive()

await call_task


async def test_kafka_event_source_mapping_handler_invocation_hook() -> None:
    """The handler enters the supplied hook before processing and exits after."""
    app = MockApp()
    hook = Mock(return_value=AsyncMock())
    handler = KafkaEventSourceMappingHandler(
        app,
        lifespan=False,
        message_send=AsyncMock(),
        invocation_hook=hook,
    )

    context = Mock()
    event: _KafkaEventSourceMapping = {
        "eventSource": "aws:kafka",
        "eventSourceArn": "arn:aws:kafka:us-east-1:123456789012:cluster/vpc-2priv-2pub/751d2973-a626-431c-9d4e-d7975eb44dd7-2",
        "bootstrapServers": "b-2.demo-cluster-1.a1bcde.c1.kafka.us-east-1.amazonaws.com:9092,b-1.demo-cluster-1.a1bcde.c1.kafka.us-east-1.amazonaws.com:9092",
        "records": {
            "mytopic-0": [
                {
                    "topic": "mytopic",
                    "partition": 0,
                    "offset": 15,
                    "timestamp": 1545084650987,
                    "timestampType": "CREATE_TIME",
                    "key": "a2V5",
                    "value": "SGVsbG8sIHRoaXMgaXMgYSB0ZXN0Lg==",
                    # "headerValue" as a list of byte values
                    "headers": [{"headerKey": list(b"headerValue")}],
                },
            ]
        },
    }
    task = asyncio.get_running_loop().create_task(handler._call(event, context))
    async with app.call():
        # While the app is still handling the record, the hook has been
        # called and entered but not yet exited.
        hook.assert_called_once_with(event, context)
        hook.return_value.__aenter__.assert_awaited_once()
        hook.return_value.__aexit__.assert_not_awaited()

    await task

    # The hook is exited only once the invocation has completed.
    hook.return_value.__aexit__.assert_awaited_once()
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ async def test_kafka_event_source_mapping_handler_message_send(
]
},
},
Mock(),
)
)
async with AIOKafkaConsumer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,19 @@
import sys
import warnings
from collections import defaultdict
from collections.abc import AsyncGenerator
from collections.abc import Awaitable
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterable
from collections.abc import Sequence
from contextlib import asynccontextmanager
from functools import cached_property
from types import TracebackType
from typing import Any
from typing import AsyncContextManager
from typing import Literal
from typing import Protocol
from typing import TypedDict

import boto3
Expand Down Expand Up @@ -65,6 +68,24 @@ class _SqsEventSourceMapping(TypedDict):
Records: list[_Record]


class _InvocationHook(Protocol):
    """Structural type for a hook that wraps a single Lambda invocation.

    Implementations receive the raw SQS event-source-mapping payload and
    the Lambda ``context`` object, and return an async context manager that
    is entered before the event is processed and exited afterwards.
    """

    def __call__(
        self,
        event: _SqsEventSourceMapping,
        context: Any,
    ) -> AsyncContextManager[None]:
        """
        Wraps one Lambda invocation
        """


@asynccontextmanager
async def _noop_hook(
    event: _SqsEventSourceMapping, context: Any
) -> AsyncGenerator[None, None]:
    """Default invocation hook: does nothing before or after the invocation."""
    yield


class _ItemIdentifier(TypedDict):
itemIdentifier: str

Expand Down Expand Up @@ -251,12 +272,14 @@ class SqsEventSourceMappingHandler:
def __init__(
self,
app: AMGIApplication,
*,
region_name: str | None = None,
endpoint_url: str | None = None,
aws_access_key_id: str | None = None,
aws_secret_access_key: str | None = None,
lifespan: bool = True,
message_send: _MessageSendManagerT | None = None,
invocation_hook: _InvocationHook | None = None,
) -> None:
self._app = app
self._region_name = region_name
Expand All @@ -270,6 +293,7 @@ def __init__(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)
self._invocation_hook = invocation_hook or _noop_hook
self._loop = asyncio.new_event_loop()
self._lifespan = lifespan
self._lifespan_context: Lifespan | None = None
Expand All @@ -283,37 +307,45 @@ def __init__(
def __call__(
self, event: _SqsEventSourceMapping, context: Any
) -> _BatchItemFailures:
return self._loop.run_until_complete(self._call(event))

async def _call(self, event: _SqsEventSourceMapping) -> _BatchItemFailures:
if not self._lifespan_context and self._lifespan:
self._lifespan_context = Lifespan(self._app, self._state)
await self._lifespan_context.__aenter__()
if self._message_send is None:
self._message_send = await self._message_send_context.__aenter__()
event_source_arn_records = defaultdict(list)
corrupted_message_ids = []
for record in event["Records"]:
if hashlib.md5(record["body"].encode()).hexdigest() == record["md5OfBody"]:
event_source_arn_records[record["eventSourceARN"]].append(record)
else:
corrupted_message_ids.append(record["messageId"])

unacked_message_ids = await asyncio.gather(
*(
self._call_source_batch(event_source_arn, records, self._message_send)
for event_source_arn, records in event_source_arn_records.items()
)
)
return self._loop.run_until_complete(self._call(event, context))

return {
"batchItemFailures": [
{"itemIdentifier": message_id}
for message_id in itertools.chain(
*unacked_message_ids, corrupted_message_ids
async def _call(
self, event: _SqsEventSourceMapping, context: Any
) -> _BatchItemFailures:
async with self._invocation_hook(event, context):
if not self._lifespan_context and self._lifespan:
self._lifespan_context = Lifespan(self._app, self._state)
await self._lifespan_context.__aenter__()
if self._message_send is None:
self._message_send = await self._message_send_context.__aenter__()
event_source_arn_records = defaultdict(list)
corrupted_message_ids = []
for record in event["Records"]:
if (
hashlib.md5(record["body"].encode()).hexdigest()
== record["md5OfBody"]
):
event_source_arn_records[record["eventSourceARN"]].append(record)
else:
corrupted_message_ids.append(record["messageId"])

unacked_message_ids = await asyncio.gather(
*(
self._call_source_batch(
event_source_arn, records, self._message_send
)
for event_source_arn, records in event_source_arn_records.items()
)
]
}
)

return {
"batchItemFailures": [
{"itemIdentifier": message_id}
for message_id in itertools.chain(
*unacked_message_ids, corrupted_message_ids
)
]
}

async def _call_source_batch(
self,
Expand Down
Loading