diff --git a/packages/amgi-aiokafka/src/amgi_aiokafka/__init__.py b/packages/amgi-aiokafka/src/amgi_aiokafka/__init__.py index 789cfac..4d28443 100644 --- a/packages/amgi-aiokafka/src/amgi_aiokafka/__init__.py +++ b/packages/amgi-aiokafka/src/amgi_aiokafka/__init__.py @@ -1,12 +1,17 @@ import asyncio import logging import sys +from asyncio import Lock +from collections.abc import AsyncIterable +from collections.abc import AsyncIterator from collections.abc import Awaitable from collections.abc import Callable from types import TracebackType from typing import Any from typing import AsyncContextManager from typing import Literal +from typing import TYPE_CHECKING +from typing import TypeVar from aiokafka import AIOKafkaConsumer from aiokafka import AIOKafkaProducer @@ -21,6 +26,16 @@ from amgi_types import MessageScope from amgi_types import MessageSendEvent +if TYPE_CHECKING: + + class _ConsumerRebalanceListener: ... + +else: + from aiokafka import ConsumerRebalanceListener as _ConsumerRebalanceListener + + +T = TypeVar("T") + if sys.version_info >= (3, 11): from typing import Self else: @@ -131,6 +146,58 @@ async def __aexit__( await self._producer.stop() +class _HoldingLock: + def __init__(self, lock: Lock): + self._lock = lock + self._holding = False + + def release(self) -> None: + if self._holding: + self._lock.release() + self._holding = False + + async def acquire(self) -> None: + await self._lock.acquire() + self._holding = True + + +async def _iter_with_lock( + iterable: AsyncIterable[T], + lock: Lock, +) -> AsyncIterator[T]: + iterator = iterable.__aiter__() + holding_lock = _HoldingLock(lock) + try: + while True: + holding_lock.release() + await holding_lock.acquire() + try: + yield await iterator.__anext__() + except StopAsyncIteration: + break + finally: + holding_lock.release() + + +class _RebalanceListener(_ConsumerRebalanceListener): + def __init__(self, rebalance_lock: Lock) -> None: + self._rebalance_lock = rebalance_lock + + async def on_partitions_revoked( + self, + revoked: list[TopicPartition], + ) -> None: + async with self._rebalance_lock: + pass + + async def on_partitions_assigned( + self, + assigned: list[TopicPartition], + ) -> None: + async with self._rebalance_lock: + pass + + class Server: _consumer: AIOKafkaConsumer @@ -150,16 +217,20 @@ def __init__( self._auto_offset_reset = auto_offset_reset self._message_send = message_send or MessageSend(bootstrap_servers) self._ackable_consumer = self._group_id is not None + self._rebalance_lock = Lock() self._stoppable = Stoppable() async def serve(self) -> None: self._consumer = AIOKafkaConsumer( - *self._topics, bootstrap_servers=self._bootstrap_servers, group_id=self._group_id, enable_auto_commit=False, auto_offset_reset=self._auto_offset_reset, ) + self._consumer.subscribe( + topics=list(self._topics), + listener=_RebalanceListener(self._rebalance_lock), + ) async with self._consumer, self._message_send as message_send: async with Lifespan(self._app) as state: await self._main_loop(state, message_send) @@ -167,21 +238,19 @@ async def serve(self) -> None: async def _main_loop( self, state: dict[str, Any], message_send: _MessageSendT ) -> None: - async for messages in self._stoppable.call( - self._consumer.getmany, timeout_ms=1000 + async for messages in _iter_with_lock( + self._stoppable.call(self._consumer.getmany, timeout_ms=1000), + self._rebalance_lock, ): await asyncio.gather( *[ - self._handle_partition_records( - topic_partition, records, message_send, state - ) + self._handle_partition_records(records, message_send, state) for topic_partition, records in messages.items() ] ) async def _handle_partition_records( self, - topic_partition: TopicPartition, records: list[ConsumerRecord], message_send: _MessageSendT, state: dict[str, Any],