Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions docs/03_examples/code/02_tasks_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ async def main() -> None:
print('Task clients created:', apify_task_clients)

# Execute Apify tasks
task_run_results = await asyncio.gather(
*[client.call() for client in apify_task_clients]
)
async with asyncio.TaskGroup() as tg:
tasks = [tg.create_task(client.call()) for client in apify_task_clients]

task_run_results = [task.result() for task in tasks]

# Filter out None results (tasks that failed to return a run)
successful_runs = [run for run in task_run_results if run is not None]
Expand Down
39 changes: 20 additions & 19 deletions src/apify_client/_resource_clients/request_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,7 +798,6 @@ async def batch_add_requests(
if min_delay_between_unprocessed_requests_retries:
logger.warning('`min_delay_between_unprocessed_requests_retries` is deprecated and not used anymore.')

tasks = set[asyncio.Task]()
asyncio_queue: asyncio.Queue[Iterable[dict]] = asyncio.Queue()
request_params = self._build_params(clientKey=self.client_key, forefront=forefront)

Expand All @@ -815,29 +814,31 @@ async def batch_add_requests(
for batch in batches:
await asyncio_queue.put(batch)

# Start a required number of worker tasks to process the batches.
for i in range(max_parallel):
coro = self._batch_add_requests_worker(
asyncio_queue,
request_params,
)
task = asyncio.create_task(coro, name=f'batch_add_requests_worker_{i}')
tasks.add(task)

# Wait for all batches to be processed.
await asyncio_queue.join()

# Send cancellation signals to all worker tasks and wait for them to finish.
for task in tasks:
task.cancel()

results: list[BatchAddResponse] = await asyncio.gather(*tasks)
# Use TaskGroup for structured concurrency — automatic cleanup and error propagation.
try:
async with asyncio.TaskGroup() as tg:
workers = [
tg.create_task(
self._batch_add_requests_worker(asyncio_queue, request_params),
name=f'batch_add_requests_worker_{i}',
)
for i in range(max_parallel)
]

# Wait for all batches to be processed, then cancel idle workers.
await asyncio_queue.join()
for worker in workers:
worker.cancel()
except ExceptionGroup as eg:
# Re-raise the first worker exception directly to maintain backward-compatible error types.
raise eg.exceptions[0] from None

# Combine the results from all workers and return them.
processed_requests = list[AddedRequest]()
unprocessed_requests = list[RequestDraft]()

for result in results:
for worker in workers:
result = worker.result()
processed_requests.extend(result.data.processed_requests)
unprocessed_requests.extend(result.data.unprocessed_requests)

Expand Down