From 6e2a4a3b0ecfa5a7ede9c7a18fd1b36b6fe4e5ee Mon Sep 17 00:00:00 2001 From: Vlada Dusek Date: Tue, 24 Feb 2026 12:00:04 +0100 Subject: [PATCH] refactor: Adopt asyncio.TaskGroup for structured concurrency Replace manual task management (create_task + cancel + gather) with asyncio.TaskGroup in batch_add_requests and update docs example. Closes #598 Co-Authored-By: Claude Opus 4.6 --- docs/03_examples/code/02_tasks_async.py | 7 ++-- .../_resource_clients/request_queue.py | 39 ++++++++++--------- 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/docs/03_examples/code/02_tasks_async.py b/docs/03_examples/code/02_tasks_async.py index f7a4c271..19e38304 100644 --- a/docs/03_examples/code/02_tasks_async.py +++ b/docs/03_examples/code/02_tasks_async.py @@ -30,9 +30,10 @@ async def main() -> None: print('Task clients created:', apify_task_clients) # Execute Apify tasks - task_run_results = await asyncio.gather( - *[client.call() for client in apify_task_clients] - ) + async with asyncio.TaskGroup() as tg: + tasks = [tg.create_task(client.call()) for client in apify_task_clients] + + task_run_results = [task.result() for task in tasks] # Filter out None results (tasks that failed to return a run) successful_runs = [run for run in task_run_results if run is not None] diff --git a/src/apify_client/_resource_clients/request_queue.py b/src/apify_client/_resource_clients/request_queue.py index 907177d7..2a7ac991 100644 --- a/src/apify_client/_resource_clients/request_queue.py +++ b/src/apify_client/_resource_clients/request_queue.py @@ -798,7 +798,6 @@ async def batch_add_requests( if min_delay_between_unprocessed_requests_retries: logger.warning('`min_delay_between_unprocessed_requests_retries` is deprecated and not used anymore.') - tasks = set[asyncio.Task]() asyncio_queue: asyncio.Queue[Iterable[dict]] = asyncio.Queue() request_params = self._build_params(clientKey=self.client_key, forefront=forefront) @@ -815,29 +814,31 @@ async def batch_add_requests( for batch in batches: await asyncio_queue.put(batch) - # Start a required number of worker tasks to process the batches. - for i in range(max_parallel): - coro = self._batch_add_requests_worker( - asyncio_queue, - request_params, - ) - task = asyncio.create_task(coro, name=f'batch_add_requests_worker_{i}') - tasks.add(task) - - # Wait for all batches to be processed. - await asyncio_queue.join() - - # Send cancellation signals to all worker tasks and wait for them to finish. - for task in tasks: - task.cancel() - - results: list[BatchAddResponse] = await asyncio.gather(*tasks) + # Use TaskGroup for structured concurrency — automatic cleanup and error propagation. + try: + async with asyncio.TaskGroup() as tg: + workers = [ + tg.create_task( + self._batch_add_requests_worker(asyncio_queue, request_params), + name=f'batch_add_requests_worker_{i}', + ) + for i in range(max_parallel) + ] + + # Wait for all batches to be processed, then cancel idle workers. + await asyncio_queue.join() + for worker in workers: + worker.cancel() + except ExceptionGroup as eg: + # Re-raise the first worker exception directly to maintain backward-compatible error types. + raise eg.exceptions[0] from None # Combine the results from all workers and return them. processed_requests = list[AddedRequest]() unprocessed_requests = list[RequestDraft]() - for result in results: + for worker in workers: + result = worker.result() processed_requests.extend(result.data.processed_requests) unprocessed_requests.extend(result.data.unprocessed_requests)