Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 76 additions & 7 deletions packages/amgi-aiokafka/src/amgi_aiokafka/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import asyncio
import logging
import sys
from asyncio import Lock
from collections.abc import AsyncIterable
from collections.abc import AsyncIterator
from collections.abc import Awaitable
from collections.abc import Callable
from types import TracebackType
from typing import Any
from typing import AsyncContextManager
from typing import Literal
from typing import TYPE_CHECKING
from typing import TypeVar

from aiokafka import AIOKafkaConsumer
from aiokafka import AIOKafkaProducer
Expand All @@ -21,6 +26,16 @@
from amgi_types import MessageScope
from amgi_types import MessageSendEvent

if TYPE_CHECKING:

class _ConsumerRebalanceListener: ...

else:
from aiokafka import ConsumerRebalanceListener as _ConsumerRebalanceListener


T = TypeVar("T")

if sys.version_info >= (3, 11):
from typing import Self
else:
Expand Down Expand Up @@ -131,6 +146,58 @@ async def __aexit__(
await self._producer.stop()


class _HoldingLock:
def __init__(self, lock: Lock):
self._lock = lock
self._holding = False

def release(self) -> None:
if self._holding:
self._lock.release()
self._holding = False

async def acquire(self) -> None:
await self._lock.acquire()
self._holding = True


async def _iter_with_lock(
iterable: AsyncIterable[T],
lock: Lock,
) -> AsyncIterator[T]:
iterator = iterable.__aiter__()
holding_lock = _HoldingLock(lock)
try:
while True:
holding_lock.release()
await holding_lock.acquire()
try:
yield await iterator.__anext__()
except StopAsyncIteration:
break
finally:
holding_lock.release()


class _RebalanceListener(_ConsumerRebalanceListener):
def __init__(self, rebalance_lock: Lock) -> None:
self._rebalance_lock = rebalance_lock

async def on_partitions_revoked(
self,
revoked: list[TopicPartition],
) -> None:
async with self._rebalance_lock:
pass

async def on_partitions_assigned(
self,
assigned: list[TopicPartition],
) -> None:
async with self._rebalance_lock:
pass


class Server:
_consumer: AIOKafkaConsumer

Expand All @@ -150,38 +217,40 @@ def __init__(
self._auto_offset_reset = auto_offset_reset
self._message_send = message_send or MessageSend(bootstrap_servers)
self._ackable_consumer = self._group_id is not None
self._rebalance_lock = Lock()
self._stoppable = Stoppable()

async def serve(self) -> None:
self._consumer = AIOKafkaConsumer(
*self._topics,
bootstrap_servers=self._bootstrap_servers,
group_id=self._group_id,
enable_auto_commit=False,
auto_offset_reset=self._auto_offset_reset,
)
self._consumer.subscribe(
topics=list(self._topics),
listener=_RebalanceListener(self._rebalance_lock),
)
async with self._consumer, self._message_send as message_send:
async with Lifespan(self._app) as state:
await self._main_loop(state, message_send)

async def _main_loop(
self, state: dict[str, Any], message_send: _MessageSendT
) -> None:
async for messages in self._stoppable.call(
self._consumer.getmany, timeout_ms=1000
async for messages in _iter_with_lock(
self._stoppable.call(self._consumer.getmany, timeout_ms=1000),
self._rebalance_lock,
):
await asyncio.gather(
*[
self._handle_partition_records(
topic_partition, records, message_send, state
)
self._handle_partition_records(records, message_send, state)
for topic_partition, records in messages.items()
]
)

async def _handle_partition_records(
self,
topic_partition: TopicPartition,
records: list[ConsumerRecord],
message_send: _MessageSendT,
state: dict[str, Any],
Expand Down
Loading