From 91f7da83287b6eccd5e128433442204c22d01a5e Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 20 Feb 2026 11:24:15 -0700 Subject: [PATCH 1/4] Initial implementation --- CHANGELOG.md | 13 + durabletask/client.py | 343 +++++++++++++----- durabletask/internal/client_helpers.py | 199 ++++++++++ durabletask/internal/grpc_interceptor.py | 85 ++++- durabletask/internal/shared.py | 40 ++ pyproject.toml | 1 + requirements.txt | 1 + tests/durabletask/test_client.py | 92 ++++- .../test_orchestration_async_e2e.py | 180 +++++++++ 9 files changed, 838 insertions(+), 116 deletions(-) create mode 100644 durabletask/internal/client_helpers.py create mode 100644 tests/durabletask/test_orchestration_async_e2e.py diff --git a/CHANGELOG.md b/CHANGELOG.md index f5f1f8d5..7b755818 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,19 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## v1.4.0 + +ADDED + +- Added `AsyncTaskHubGrpcClient` for asyncio-based applications using `grpc.aio` +- Added `DefaultAsyncClientInterceptorImpl` for async gRPC metadata interceptors +- Added `get_async_grpc_channel` helper for creating async gRPC channels + +CHANGED + +- Refactored `TaskHubGrpcClient` to share request-building and validation logic + with `AsyncTaskHubGrpcClient` via module-level helper functions + ## v1.3.0 ADDED diff --git a/durabletask/client.py b/durabletask/client.py index 2fbd1d22..3e581ea6 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -2,13 +2,13 @@ # Licensed under the MIT License. 
import logging -import uuid from dataclasses import dataclass -from datetime import datetime, timezone +from datetime import datetime from enum import Enum from typing import Any, List, Optional, Sequence, TypeVar, Union import grpc +import grpc.aio from durabletask.entities import EntityInstanceId from durabletask.entities.entity_metadata import EntityMetadata @@ -17,7 +17,19 @@ import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared from durabletask import task -from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl +from durabletask.internal.client_helpers import ( + build_query_entities_req, + build_query_instances_req, + build_purge_by_filter_req, + build_raise_event_req, + build_schedule_new_orchestration_req, + build_signal_entity_req, + build_terminate_req, + check_continuation_token, + log_completion_state, + prepare_async_interceptors, + prepare_sync_interceptors, +) TInput = TypeVar('TInput') TOutput = TypeVar('TOutput') @@ -140,16 +152,7 @@ def __init__(self, *, interceptors: Optional[Sequence[shared.ClientInterceptor]] = None, default_version: Optional[str] = None): - # If the caller provided metadata, we need to create a new interceptor for it and - # add it to the list of interceptors. 
- if interceptors is not None: - interceptors = list(interceptors) - if metadata is not None: - interceptors.append(DefaultClientInterceptorImpl(metadata)) - elif metadata is not None: - interceptors = [DefaultClientInterceptorImpl(metadata)] - else: - interceptors = None + interceptors = prepare_sync_interceptors(metadata, interceptors) channel = shared.get_grpc_channel( host_address=host_address, @@ -168,19 +171,12 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu tags: Optional[dict[str, str]] = None, version: Optional[str] = None) -> str: - name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator) + req = build_schedule_new_orchestration_req( + orchestrator, input=input, instance_id=instance_id, start_at=start_at, + reuse_id_policy=reuse_id_policy, tags=tags, + version=version if version else self.default_version) - req = pb.CreateInstanceRequest( - name=name, - instanceId=instance_id if instance_id else uuid.uuid4().hex, - input=helpers.get_string_value(shared.to_json(input) if input is not None else None), - scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None, - version=helpers.get_string_value(version if version else self.default_version), - orchestrationIdReusePolicy=reuse_id_policy, - tags=tags - ) - - self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.") + self._logger.info(f"Starting new '{req.name}' instance with ID = '{req.instanceId}'.") res: pb.CreateInstanceResponse = self._stub.StartInstance(req) return res.instanceId @@ -201,24 +197,10 @@ def get_all_orchestration_states(self, states = [] while True: - req = pb.QueryInstancesRequest( - query=pb.InstanceQuery( - runtimeStatus=[status.value for status in orchestration_query.runtime_status] if orchestration_query.runtime_status else None, - createdTimeFrom=helpers.new_timestamp(orchestration_query.created_time_from) if orchestration_query.created_time_from else None, - 
createdTimeTo=helpers.new_timestamp(orchestration_query.created_time_to) if orchestration_query.created_time_to else None, - maxInstanceCount=orchestration_query.max_instance_count, - fetchInputsAndOutputs=orchestration_query.fetch_inputs_and_outputs, - continuationToken=_continuation_token - ) - ) + req = build_query_instances_req(orchestration_query, _continuation_token) resp: pb.QueryInstancesResponse = self._stub.QueryInstances(req) states += [parse_orchestration_state(res) for res in resp.orchestrationState] - # Check the value for continuationToken - none or "0" indicates that there are no more results. - if resp.continuationToken and resp.continuationToken.value and resp.continuationToken.value != "0": - self._logger.info(f"Received continuation token with value {resp.continuationToken.value}, fetching next list of instances...") - if _continuation_token and _continuation_token.value and _continuation_token.value == resp.continuationToken.value: - self._logger.warning(f"Received the same continuation token value {resp.continuationToken.value} again, stopping to avoid infinite loop.") - break + if check_continuation_token(resp.continuationToken, _continuation_token, self._logger): _continuation_token = resp.continuationToken else: break @@ -248,32 +230,17 @@ def wait_for_orchestration_completion(self, instance_id: str, *, self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.") res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion(req, timeout=timeout) state = new_orchestration_state(req.instanceId, res) - if not state: - return None - - if state.runtime_status == OrchestrationStatus.FAILED and state.failure_details is not None: - details = state.failure_details - self._logger.info(f"Instance '{instance_id}' failed: [{details.error_type}] {details.message}") - elif state.runtime_status == OrchestrationStatus.TERMINATED: - self._logger.info(f"Instance '{instance_id}' was terminated.") - elif state.runtime_status == 
OrchestrationStatus.COMPLETED: - self._logger.info(f"Instance '{instance_id}' completed.") - + log_completion_state(self._logger, instance_id, state) return state except grpc.RpcError as rpc_error: if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore - # Replace gRPC error with the built-in TimeoutError raise TimeoutError("Timed-out waiting for the orchestration to complete") else: raise def raise_orchestration_event(self, instance_id: str, event_name: str, *, data: Optional[Any] = None): - req = pb.RaiseEventRequest( - instanceId=instance_id, - name=event_name, - input=helpers.get_string_value(shared.to_json(data) if data is not None else None) - ) + req = build_raise_event_req(instance_id, event_name, data) self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.") self._stub.RaiseEvent(req) @@ -281,10 +248,7 @@ def raise_orchestration_event(self, instance_id: str, event_name: str, *, def terminate_orchestration(self, instance_id: str, *, output: Optional[Any] = None, recursive: bool = True): - req = pb.TerminateRequest( - instanceId=instance_id, - output=helpers.get_string_value(shared.to_json(output) if output is not None else None), - recursive=recursive) + req = build_terminate_req(instance_id, output, recursive) self._logger.info(f"Terminating instance '{instance_id}'.") self._stub.TerminateInstance(req) @@ -315,29 +279,15 @@ def purge_orchestrations_by(self, f"created_time_to={created_time_to}, " f"runtime_status={[str(status) for status in runtime_status] if runtime_status else None}, " f"recursive={recursive}") - resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(pb.PurgeInstancesRequest( - purgeInstanceFilter=pb.PurgeInstanceFilter( - createdTimeFrom=helpers.new_timestamp(created_time_from) if created_time_from else None, - createdTimeTo=helpers.new_timestamp(created_time_to) if created_time_to else None, - runtimeStatus=[status.value for status in runtime_status] if runtime_status else None - ), - 
recursive=recursive - )) + req = build_purge_by_filter_req(created_time_from, created_time_to, runtime_status, recursive) + resp: pb.PurgeInstancesResponse = self._stub.PurgeInstances(req) return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value) def signal_entity(self, entity_instance_id: EntityInstanceId, operation_name: str, input: Optional[Any] = None) -> None: - req = pb.SignalEntityRequest( - instanceId=str(entity_instance_id), - name=operation_name, - input=helpers.get_string_value(shared.to_json(input) if input is not None else None), - requestId=str(uuid.uuid4()), - scheduledTime=None, - parentTraceContext=None, - requestTime=helpers.new_timestamp(datetime.now(timezone.utc)) - ) + req = build_signal_entity_req(entity_instance_id, operation_name, input) self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.") self._stub.SignalEntity(req, None) # TODO: Cancellation timeout? @@ -364,24 +314,10 @@ def get_all_entities(self, entities = [] while True: - query_request = pb.QueryEntitiesRequest( - query=pb.EntityQuery( - instanceIdStartsWith=helpers.get_string_value(entity_query.instance_id_starts_with), - lastModifiedFrom=helpers.new_timestamp(entity_query.last_modified_from) if entity_query.last_modified_from else None, - lastModifiedTo=helpers.new_timestamp(entity_query.last_modified_to) if entity_query.last_modified_to else None, - includeState=entity_query.include_state, - includeTransient=entity_query.include_transient, - pageSize=helpers.get_int_value(entity_query.page_size), - continuationToken=_continuation_token - ) - ) + query_request = build_query_entities_req(entity_query, _continuation_token) resp: pb.QueryEntitiesResponse = self._stub.QueryEntities(query_request) entities += [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState) for entity in resp.entities] - if resp.continuationToken and resp.continuationToken.value and resp.continuationToken.value != "0": - 
self._logger.info(f"Received continuation token with value {resp.continuationToken.value}, fetching next page of entities...") - if _continuation_token and _continuation_token.value and _continuation_token.value == resp.continuationToken.value: - self._logger.warning(f"Received the same continuation token value {resp.continuationToken.value} again, stopping to avoid infinite loop.") - break + if check_continuation_token(resp.continuationToken, _continuation_token, self._logger): _continuation_token = resp.continuationToken else: break @@ -407,11 +343,218 @@ def clean_entity_storage(self, empty_entities_removed += resp.emptyEntitiesRemoved orphaned_locks_released += resp.orphanedLocksReleased - if resp.continuationToken and resp.continuationToken.value and resp.continuationToken.value != "0": - self._logger.info(f"Received continuation token with value {resp.continuationToken.value}, cleaning next page...") - if _continuation_token and _continuation_token.value and _continuation_token.value == resp.continuationToken.value: - self._logger.warning(f"Received the same continuation token value {resp.continuationToken.value} again, stopping to avoid infinite loop.") - break + if check_continuation_token(resp.continuationToken, _continuation_token, self._logger): + _continuation_token = resp.continuationToken + else: + break + + return CleanEntityStorageResult(empty_entities_removed, orphaned_locks_released) + + +class AsyncTaskHubGrpcClient: + """Async version of TaskHubGrpcClient using grpc.aio for asyncio-based applications.""" + + def __init__(self, *, + host_address: Optional[str] = None, + metadata: Optional[list[tuple[str, str]]] = None, + log_handler: Optional[logging.Handler] = None, + log_formatter: Optional[logging.Formatter] = None, + secure_channel: bool = False, + interceptors: Optional[Sequence[shared.AsyncClientInterceptor]] = None, + default_version: Optional[str] = None): + + interceptors = prepare_async_interceptors(metadata, interceptors) + + channel = 
shared.get_async_grpc_channel( + host_address=host_address, + secure_channel=secure_channel, + interceptors=interceptors + ) + self._stub = stubs.TaskHubSidecarServiceStub(channel) + self._logger = shared.get_logger("client", log_handler, log_formatter) + self.default_version = default_version + + async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *, + input: Optional[TInput] = None, + instance_id: Optional[str] = None, + start_at: Optional[datetime] = None, + reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None, + tags: Optional[dict[str, str]] = None, + version: Optional[str] = None) -> str: + + req = build_schedule_new_orchestration_req( + orchestrator, input=input, instance_id=instance_id, start_at=start_at, + reuse_id_policy=reuse_id_policy, tags=tags, + version=version if version else self.default_version) + + self._logger.info(f"Starting new '{req.name}' instance with ID = '{req.instanceId}'.") + res: pb.CreateInstanceResponse = await self._stub.StartInstance(req) + return res.instanceId + + async def get_orchestration_state(self, instance_id: str, *, + fetch_payloads: bool = True) -> Optional[OrchestrationState]: + req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) + res: pb.GetInstanceResponse = await self._stub.GetInstance(req) + return new_orchestration_state(req.instanceId, res) + + async def get_all_orchestration_states(self, + orchestration_query: Optional[OrchestrationQuery] = None + ) -> List[OrchestrationState]: + if orchestration_query is None: + orchestration_query = OrchestrationQuery() + _continuation_token = None + + self._logger.info(f"Querying orchestration instances with query: {orchestration_query}") + + states = [] + + while True: + req = build_query_instances_req(orchestration_query, _continuation_token) + resp: pb.QueryInstancesResponse = await self._stub.QueryInstances(req) + states += [parse_orchestration_state(res) for res in 
resp.orchestrationState] + if check_continuation_token(resp.continuationToken, _continuation_token, self._logger): + _continuation_token = resp.continuationToken + else: + break + + return states + + async def wait_for_orchestration_start(self, instance_id: str, *, + fetch_payloads: bool = False, + timeout: int = 60) -> Optional[OrchestrationState]: + req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) + try: + self._logger.info(f"Waiting up to {timeout}s for instance '{instance_id}' to start.") + res: pb.GetInstanceResponse = await self._stub.WaitForInstanceStart(req, timeout=timeout) + return new_orchestration_state(req.instanceId, res) + except grpc.aio.AioRpcError as rpc_error: + if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: + raise TimeoutError("Timed-out waiting for the orchestration to start") + else: + raise + + async def wait_for_orchestration_completion(self, instance_id: str, *, + fetch_payloads: bool = True, + timeout: int = 60) -> Optional[OrchestrationState]: + req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) + try: + self._logger.info(f"Waiting {timeout}s for instance '{instance_id}' to complete.") + res: pb.GetInstanceResponse = await self._stub.WaitForInstanceCompletion(req, timeout=timeout) + state = new_orchestration_state(req.instanceId, res) + log_completion_state(self._logger, instance_id, state) + return state + except grpc.aio.AioRpcError as rpc_error: + if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: + raise TimeoutError("Timed-out waiting for the orchestration to complete") + else: + raise + + async def raise_orchestration_event(self, instance_id: str, event_name: str, *, + data: Optional[Any] = None): + req = build_raise_event_req(instance_id, event_name, data) + + self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.") + await self._stub.RaiseEvent(req) + + async def terminate_orchestration(self, instance_id: str, *, 
+ output: Optional[Any] = None, + recursive: bool = True): + req = build_terminate_req(instance_id, output, recursive) + + self._logger.info(f"Terminating instance '{instance_id}'.") + await self._stub.TerminateInstance(req) + + async def suspend_orchestration(self, instance_id: str): + req = pb.SuspendRequest(instanceId=instance_id) + self._logger.info(f"Suspending instance '{instance_id}'.") + await self._stub.SuspendInstance(req) + + async def resume_orchestration(self, instance_id: str): + req = pb.ResumeRequest(instanceId=instance_id) + self._logger.info(f"Resuming instance '{instance_id}'.") + await self._stub.ResumeInstance(req) + + async def purge_orchestration(self, instance_id: str, recursive: bool = True) -> PurgeInstancesResult: + req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive) + self._logger.info(f"Purging instance '{instance_id}'.") + resp: pb.PurgeInstancesResponse = await self._stub.PurgeInstances(req) + return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value) + + async def purge_orchestrations_by(self, + created_time_from: Optional[datetime] = None, + created_time_to: Optional[datetime] = None, + runtime_status: Optional[List[OrchestrationStatus]] = None, + recursive: bool = False) -> PurgeInstancesResult: + self._logger.info("Purging orchestrations by filter: " + f"created_time_from={created_time_from}, " + f"created_time_to={created_time_to}, " + f"runtime_status={[str(status) for status in runtime_status] if runtime_status else None}, " + f"recursive={recursive}") + req = build_purge_by_filter_req(created_time_from, created_time_to, runtime_status, recursive) + resp: pb.PurgeInstancesResponse = await self._stub.PurgeInstances(req) + return PurgeInstancesResult(resp.deletedInstanceCount, resp.isComplete.value) + + async def signal_entity(self, + entity_instance_id: EntityInstanceId, + operation_name: str, + input: Optional[Any] = None) -> None: + req = build_signal_entity_req(entity_instance_id, 
operation_name, input) + self._logger.info(f"Signaling entity '{entity_instance_id}' operation '{operation_name}'.") + await self._stub.SignalEntity(req, None) + + async def get_entity(self, + entity_instance_id: EntityInstanceId, + include_state: bool = True + ) -> Optional[EntityMetadata]: + req = pb.GetEntityRequest(instanceId=str(entity_instance_id), includeState=include_state) + self._logger.info(f"Getting entity '{entity_instance_id}'.") + res: pb.GetEntityResponse = await self._stub.GetEntity(req) + if not res.exists: + return None + + return EntityMetadata.from_entity_metadata(res.entity, include_state) + + async def get_all_entities(self, + entity_query: Optional[EntityQuery] = None) -> List[EntityMetadata]: + if entity_query is None: + entity_query = EntityQuery() + _continuation_token = None + + self._logger.info(f"Retrieving entities by filter: {entity_query}") + + entities = [] + + while True: + query_request = build_query_entities_req(entity_query, _continuation_token) + resp: pb.QueryEntitiesResponse = await self._stub.QueryEntities(query_request) + entities += [EntityMetadata.from_entity_metadata(entity, query_request.query.includeState) for entity in resp.entities] + if check_continuation_token(resp.continuationToken, _continuation_token, self._logger): + _continuation_token = resp.continuationToken + else: + break + return entities + + async def clean_entity_storage(self, + remove_empty_entities: bool = True, + release_orphaned_locks: bool = True + ) -> CleanEntityStorageResult: + self._logger.info("Cleaning entity storage") + + empty_entities_removed = 0 + orphaned_locks_released = 0 + _continuation_token = None + + while True: + req = pb.CleanEntityStorageRequest( + removeEmptyEntities=remove_empty_entities, + releaseOrphanedLocks=release_orphaned_locks, + continuationToken=_continuation_token + ) + resp: pb.CleanEntityStorageResponse = await self._stub.CleanEntityStorage(req) + empty_entities_removed += resp.emptyEntitiesRemoved + 
orphaned_locks_released += resp.orphanedLocksReleased + + if check_continuation_token(resp.continuationToken, _continuation_token, self._logger): _continuation_token = resp.continuationToken else: break diff --git a/durabletask/internal/client_helpers.py b/durabletask/internal/client_helpers.py new file mode 100644 index 00000000..cfeef5a7 --- /dev/null +++ b/durabletask/internal/client_helpers.py @@ -0,0 +1,199 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +import logging +import uuid +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any, List, Optional, Sequence, TypeVar, Union + +import durabletask.internal.helpers as helpers +import durabletask.internal.orchestrator_service_pb2 as pb +import durabletask.internal.shared as shared +from durabletask import task +from durabletask.internal.grpc_interceptor import ( + DefaultAsyncClientInterceptorImpl, + DefaultClientInterceptorImpl, +) + +if TYPE_CHECKING: + from durabletask.client import ( + EntityQuery, + OrchestrationQuery, + OrchestrationState, + OrchestrationStatus, + ) + from durabletask.entities import EntityInstanceId + +TInput = TypeVar('TInput') +TOutput = TypeVar('TOutput') + + +def prepare_sync_interceptors( + metadata: Optional[list[tuple[str, str]]], + interceptors: Optional[Sequence[shared.ClientInterceptor]] +) -> Optional[list[shared.ClientInterceptor]]: + """Prepare the list of sync gRPC interceptors, adding a metadata interceptor if needed.""" + result: Optional[list[shared.ClientInterceptor]] = None + if interceptors is not None: + result = list(interceptors) + if metadata is not None: + result.append(DefaultClientInterceptorImpl(metadata)) + elif metadata is not None: + result = [DefaultClientInterceptorImpl(metadata)] + return result + + +def prepare_async_interceptors( + metadata: Optional[list[tuple[str, str]]], + interceptors: Optional[Sequence[shared.AsyncClientInterceptor]] +) -> 
Optional[list[shared.AsyncClientInterceptor]]: + """Prepare the list of async gRPC interceptors, adding a metadata interceptor if needed.""" + result: Optional[list[shared.AsyncClientInterceptor]] = None + if interceptors is not None: + result = list(interceptors) + if metadata is not None: + result.append(DefaultAsyncClientInterceptorImpl(metadata)) + elif metadata is not None: + result = [DefaultAsyncClientInterceptorImpl(metadata)] + return result + + +def build_schedule_new_orchestration_req( + orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *, + input: Optional[TInput], + instance_id: Optional[str], + start_at: Optional[datetime], + reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy], + tags: Optional[dict[str, str]], + version: Optional[str]) -> pb.CreateInstanceRequest: + """Build a CreateInstanceRequest for scheduling a new orchestration.""" + name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator) + return pb.CreateInstanceRequest( + name=name, + instanceId=instance_id if instance_id else uuid.uuid4().hex, + input=helpers.get_string_value(shared.to_json(input) if input is not None else None), + scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None, + version=helpers.get_string_value(version), + orchestrationIdReusePolicy=reuse_id_policy, + tags=tags + ) + + +def build_query_instances_req( + orchestration_query: OrchestrationQuery, + continuation_token) -> pb.QueryInstancesRequest: + """Build a QueryInstancesRequest from an OrchestrationQuery.""" + return pb.QueryInstancesRequest( + query=pb.InstanceQuery( + runtimeStatus=[status.value for status in orchestration_query.runtime_status] if orchestration_query.runtime_status else None, + createdTimeFrom=helpers.new_timestamp(orchestration_query.created_time_from) if orchestration_query.created_time_from else None, + createdTimeTo=helpers.new_timestamp(orchestration_query.created_time_to) if orchestration_query.created_time_to else None, + 
maxInstanceCount=orchestration_query.max_instance_count, + fetchInputsAndOutputs=orchestration_query.fetch_inputs_and_outputs, + continuationToken=continuation_token + ) + ) + + +def build_purge_by_filter_req( + created_time_from: Optional[datetime], + created_time_to: Optional[datetime], + runtime_status: Optional[List[OrchestrationStatus]], + recursive: bool) -> pb.PurgeInstancesRequest: + """Build a PurgeInstancesRequest for purging orchestrations by filter.""" + return pb.PurgeInstancesRequest( + purgeInstanceFilter=pb.PurgeInstanceFilter( + createdTimeFrom=helpers.new_timestamp(created_time_from) if created_time_from else None, + createdTimeTo=helpers.new_timestamp(created_time_to) if created_time_to else None, + runtimeStatus=[status.value for status in runtime_status] if runtime_status else None + ), + recursive=recursive + ) + + +def build_query_entities_req( + entity_query: EntityQuery, + continuation_token) -> pb.QueryEntitiesRequest: + """Build a QueryEntitiesRequest from an EntityQuery.""" + return pb.QueryEntitiesRequest( + query=pb.EntityQuery( + instanceIdStartsWith=helpers.get_string_value(entity_query.instance_id_starts_with), + lastModifiedFrom=helpers.new_timestamp(entity_query.last_modified_from) if entity_query.last_modified_from else None, + lastModifiedTo=helpers.new_timestamp(entity_query.last_modified_to) if entity_query.last_modified_to else None, + includeState=entity_query.include_state, + includeTransient=entity_query.include_transient, + pageSize=helpers.get_int_value(entity_query.page_size), + continuationToken=continuation_token + ) + ) + + +def check_continuation_token(resp_token, prev_token, logger: logging.Logger) -> bool: + """Check if a continuation token indicates more pages. 
Returns True to continue, False to stop.""" + if resp_token and resp_token.value and resp_token.value != "0": + logger.info(f"Received continuation token with value {resp_token.value}, fetching next page...") + if prev_token and prev_token.value and prev_token.value == resp_token.value: + logger.warning(f"Received the same continuation token value {resp_token.value} again, stopping to avoid infinite loop.") + return False + return True + return False + + +def log_completion_state( + logger: logging.Logger, + instance_id: str, + state: Optional[OrchestrationState]): + """Log the final state of a completed orchestration.""" + if not state: + return + # Compare against proto constants to avoid circular imports with client.py + status_val = state.runtime_status.value + if status_val == pb.ORCHESTRATION_STATUS_FAILED and state.failure_details is not None: + details = state.failure_details + logger.info(f"Instance '{instance_id}' failed: [{details.error_type}] {details.message}") + elif status_val == pb.ORCHESTRATION_STATUS_TERMINATED: + logger.info(f"Instance '{instance_id}' was terminated.") + elif status_val == pb.ORCHESTRATION_STATUS_COMPLETED: + logger.info(f"Instance '{instance_id}' completed.") + + +def build_raise_event_req( + instance_id: str, + event_name: str, + data: Optional[Any] = None) -> pb.RaiseEventRequest: + """Build a RaiseEventRequest for raising an orchestration event.""" + return pb.RaiseEventRequest( + instanceId=instance_id, + name=event_name, + input=helpers.get_string_value(shared.to_json(data) if data is not None else None) + ) + + +def build_terminate_req( + instance_id: str, + output: Optional[Any] = None, + recursive: bool = True) -> pb.TerminateRequest: + """Build a TerminateRequest for terminating an orchestration.""" + return pb.TerminateRequest( + instanceId=instance_id, + output=helpers.get_string_value(shared.to_json(output) if output is not None else None), + recursive=recursive + ) + + +def build_signal_entity_req( + 
entity_instance_id: EntityInstanceId, + operation_name: str, + input: Optional[Any] = None) -> pb.SignalEntityRequest: + """Build a SignalEntityRequest for signaling an entity.""" + return pb.SignalEntityRequest( + instanceId=str(entity_instance_id), + name=operation_name, + input=helpers.get_string_value(shared.to_json(input) if input is not None else None), + requestId=str(uuid.uuid4()), + scheduledTime=None, + parentTraceContext=None, + requestTime=helpers.new_timestamp(datetime.now(timezone.utc)) + ) diff --git a/durabletask/internal/grpc_interceptor.py b/durabletask/internal/grpc_interceptor.py index 69db3c55..232c31c6 100644 --- a/durabletask/internal/grpc_interceptor.py +++ b/durabletask/internal/grpc_interceptor.py @@ -4,6 +4,7 @@ from collections import namedtuple import grpc +import grpc.aio class _ClientCallDetails( @@ -18,6 +19,32 @@ class _ClientCallDetails( pass +class _AsyncClientCallDetails( + namedtuple( + '_AsyncClientCallDetails', + ['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready']), + grpc.aio.ClientCallDetails): + """This is an implementation of the aio ClientCallDetails interface needed for async interceptors. + This class takes five named values and inherits the ClientCallDetails from grpc.aio package. + This class encloses the values that describe a RPC to be invoked. + """ + pass + + +def _apply_metadata(client_call_details, metadata): + """Shared logic for applying metadata to call details. 
Returns the updated metadata list.""" + if metadata is None: + return client_call_details.metadata + + if client_call_details.metadata is not None: + new_metadata = list(client_call_details.metadata) + else: + new_metadata = [] + + new_metadata.extend(metadata) + return new_metadata + + class DefaultClientInterceptorImpl ( grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor, grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor): @@ -30,24 +57,17 @@ def __init__(self, metadata: list[tuple[str, str]]): self._metadata = metadata def _intercept_call( - self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails: + self, client_call_details: grpc.ClientCallDetails) -> grpc.ClientCallDetails: """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC call details.""" - if self._metadata is None: + new_metadata = _apply_metadata(client_call_details, self._metadata) + if new_metadata is client_call_details.metadata: return client_call_details - if client_call_details.metadata is not None: - metadata = list(client_call_details.metadata) - else: - metadata = [] - - metadata.extend(self._metadata) - client_call_details = _ClientCallDetails( - client_call_details.method, client_call_details.timeout, metadata, + return _ClientCallDetails( + client_call_details.method, client_call_details.timeout, new_metadata, client_call_details.credentials, client_call_details.wait_for_ready, client_call_details.compression) - return client_call_details - def intercept_unary_unary(self, continuation, client_call_details, request): new_client_call_details = self._intercept_call(client_call_details) return continuation(new_client_call_details, request) @@ -63,3 +83,44 @@ def intercept_stream_unary(self, continuation, client_call_details, request): def intercept_stream_stream(self, continuation, client_call_details, request): new_client_call_details = self._intercept_call(client_call_details) return 
continuation(new_client_call_details, request) + + +class DefaultAsyncClientInterceptorImpl( + grpc.aio.UnaryUnaryClientInterceptor, grpc.aio.UnaryStreamClientInterceptor, + grpc.aio.StreamUnaryClientInterceptor, grpc.aio.StreamStreamClientInterceptor): + """Async gRPC interceptor that adds metadata headers to all calls.""" + + def __init__(self, metadata: list[tuple[str, str]]): + self._metadata = metadata + + def _intercept_call( + self, client_call_details: grpc.aio.ClientCallDetails) -> grpc.aio.ClientCallDetails: + """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC + call details.""" + new_metadata = _apply_metadata(client_call_details, self._metadata) + if new_metadata is client_call_details.metadata: + return client_call_details + + return _AsyncClientCallDetails( + client_call_details.method, + client_call_details.timeout, + new_metadata, + client_call_details.credentials, + client_call_details.wait_for_ready, + ) + + async def intercept_unary_unary(self, continuation, client_call_details, request): + new_client_call_details = self._intercept_call(client_call_details) + return await continuation(new_client_call_details, request) + + async def intercept_unary_stream(self, continuation, client_call_details, request): + new_client_call_details = self._intercept_call(client_call_details) + return await continuation(new_client_call_details, request) + + async def intercept_stream_unary(self, continuation, client_call_details, request_iterator): + new_client_call_details = self._intercept_call(client_call_details) + return await continuation(new_client_call_details, request_iterator) + + async def intercept_stream_stream(self, continuation, client_call_details, request_iterator): + new_client_call_details = self._intercept_call(client_call_details) + return await continuation(new_client_call_details, request_iterator) diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index 1872ad45..20ad26f6 100644 
--- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -8,6 +8,7 @@ from typing import Any, Optional, Sequence, Union import grpc +import grpc.aio ClientInterceptor = Union[ grpc.UnaryUnaryClientInterceptor, @@ -16,6 +17,13 @@ grpc.StreamStreamClientInterceptor ] +AsyncClientInterceptor = Union[ + grpc.aio.UnaryUnaryClientInterceptor, + grpc.aio.UnaryStreamClientInterceptor, + grpc.aio.StreamUnaryClientInterceptor, + grpc.aio.StreamStreamClientInterceptor +] + # Field name used to indicate that an object was automatically serialized # and should be deserialized as a SimpleNamespace AUTO_SERIALIZED = "__durabletask_autoobject__" @@ -62,6 +70,38 @@ def get_grpc_channel( return channel +def get_async_grpc_channel( + host_address: Optional[str], + secure_channel: bool = False, + interceptors: Optional[Sequence[AsyncClientInterceptor]] = None) -> grpc.aio.Channel: + + if host_address is None: + host_address = get_default_host_address() + + for protocol in SECURE_PROTOCOLS: + if host_address.lower().startswith(protocol): + secure_channel = True + host_address = host_address[len(protocol):] + break + + for protocol in INSECURE_PROTOCOLS: + if host_address.lower().startswith(protocol): + secure_channel = False + host_address = host_address[len(protocol):] + break + + if secure_channel: + channel = grpc.aio.secure_channel( + host_address, grpc.ssl_channel_credentials(), + interceptors=interceptors) + else: + channel = grpc.aio.insecure_channel( + host_address, + interceptors=interceptors) + + return channel + + def get_logger( name_suffix: str, log_handler: Optional[logging.Handler] = None, diff --git a/pyproject.toml b/pyproject.toml index ec8a511d..70f3b38a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,3 +44,4 @@ testpaths = ["tests"] markers = [ "e2e: mark a test as an end-to-end test that requires a running sidecar" ] +asyncio_mode = "auto" diff --git a/requirements.txt b/requirements.txt index f32d3500..166d0471 100644 --- 
a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ autopep8 grpcio>=1.60.0 # 1.60.0 is the version introducing protobuf 1.25.X support, newer versions are backwards compatible protobuf pytest +pytest-asyncio pytest-cov azure-identity asyncio diff --git a/tests/durabletask/test_client.py b/tests/durabletask/test_client.py index e7501341..e6ea7c35 100644 --- a/tests/durabletask/test_client.py +++ b/tests/durabletask/test_client.py @@ -1,14 +1,25 @@ -from unittest.mock import ANY, patch +from unittest.mock import ANY, MagicMock, patch -from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl -from durabletask.internal.shared import (get_default_host_address, - get_grpc_channel) +from durabletask.client import AsyncTaskHubGrpcClient + +from durabletask.internal.grpc_interceptor import ( + DefaultAsyncClientInterceptorImpl, + DefaultClientInterceptorImpl, +) +from durabletask.internal.shared import ( + get_async_grpc_channel, + get_default_host_address, + get_grpc_channel, +) HOST_ADDRESS = 'localhost:50051' METADATA = [('key1', 'value1'), ('key2', 'value2')] INTERCEPTORS = [DefaultClientInterceptorImpl(METADATA)] +# ==== Sync channel tests ==== + + def test_get_grpc_channel_insecure(): with patch('grpc.insecure_channel') as mock_channel: get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS) @@ -87,3 +98,76 @@ def test_grpc_channel_with_host_name_protocol_stripping(): prefix = "" get_grpc_channel(prefix + host_name, True, interceptors=INTERCEPTORS) mock_secure_channel.assert_called_with(host_name, ANY) + + +# ==== Async channel tests ==== + + +def test_get_async_grpc_channel_insecure(): + with patch('grpc.aio.insecure_channel') as mock_channel: + get_async_grpc_channel(HOST_ADDRESS, False) + mock_channel.assert_called_once_with(HOST_ADDRESS, interceptors=None) + + +def test_get_async_grpc_channel_secure(): + with patch('grpc.aio.secure_channel') as mock_channel, patch( + 'grpc.ssl_channel_credentials') as mock_credentials: + 
get_async_grpc_channel(HOST_ADDRESS, True) + mock_channel.assert_called_once_with( + HOST_ADDRESS, mock_credentials.return_value, interceptors=None) + + +def test_get_async_grpc_channel_default_host_address(): + with patch('grpc.aio.insecure_channel') as mock_channel: + get_async_grpc_channel(None, False) + mock_channel.assert_called_once_with(get_default_host_address(), interceptors=None) + + +def test_get_async_grpc_channel_with_interceptors(): + async_interceptors = [DefaultAsyncClientInterceptorImpl(METADATA)] + with patch('grpc.aio.insecure_channel') as mock_channel: + get_async_grpc_channel(HOST_ADDRESS, False, interceptors=async_interceptors) + mock_channel.assert_called_once_with(HOST_ADDRESS, interceptors=async_interceptors) + + +def test_async_grpc_channel_protocol_stripping(): + with patch('grpc.aio.insecure_channel') as mock_insecure, patch( + 'grpc.aio.secure_channel') as mock_secure: + host_name = "myserver.com:1234" + + get_async_grpc_channel("http://" + host_name) + mock_insecure.assert_called_with(host_name, interceptors=None) + + get_async_grpc_channel("grpc://" + host_name) + mock_insecure.assert_called_with(host_name, interceptors=None) + + get_async_grpc_channel("https://" + host_name) + mock_secure.assert_called_with(host_name, ANY, interceptors=None) + + get_async_grpc_channel("grpcs://" + host_name) + mock_secure.assert_called_with(host_name, ANY, interceptors=None) + + +# ==== Async client construction tests ==== + + +def test_async_client_creates_with_defaults(): + with patch('grpc.aio.insecure_channel') as mock_channel: + mock_channel.return_value = MagicMock() + _ = AsyncTaskHubGrpcClient() + mock_channel.assert_called_once_with( + get_default_host_address(), interceptors=None) + + +def test_async_client_creates_with_metadata(): + with patch('grpc.aio.insecure_channel') as mock_channel: + mock_channel.return_value = MagicMock() + _ = AsyncTaskHubGrpcClient( + host_address=HOST_ADDRESS, metadata=METADATA) + 
mock_channel.assert_called_once() + args, kwargs = mock_channel.call_args + assert args[0] == HOST_ADDRESS + interceptors = kwargs.get('interceptors') + assert interceptors is not None + assert len(interceptors) == 1 + assert isinstance(interceptors[0], DefaultAsyncClientInterceptorImpl) diff --git a/tests/durabletask/test_orchestration_async_e2e.py b/tests/durabletask/test_orchestration_async_e2e.py new file mode 100644 index 00000000..bec3b12e --- /dev/null +++ b/tests/durabletask/test_orchestration_async_e2e.py @@ -0,0 +1,180 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import asyncio +import json + +import pytest + +from durabletask import client, task, worker + +# NOTE: These tests assume a sidecar process is running. Example command: +# go install github.com/microsoft/durabletask-go@main +# durabletask-go --port 4001 +pytestmark = pytest.mark.e2e + + +@pytest.mark.asyncio +async def test_async_empty_orchestration(): + + invoked = False + + def empty_orchestrator(ctx: task.OrchestrationContext, _): + nonlocal invoked # don't do this in a real app! 
+ invoked = True + + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(empty_orchestrator) + w.start() + + c = client.AsyncTaskHubGrpcClient() + id = await c.schedule_new_orchestration(empty_orchestrator, tags={'Tagged': 'true'}) + state = await c.wait_for_orchestration_completion(id, timeout=30) + + assert invoked + assert state is not None + assert state.name == task.get_name(empty_orchestrator) + assert state.instance_id == id + assert state.failure_details is None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_input is None + assert state.serialized_output is None + assert state.serialized_custom_status is None + + +@pytest.mark.asyncio +async def test_async_activity_sequence(): + + def plus_one(_: task.ActivityContext, input: int) -> int: + return input + 1 + + def sequence(ctx: task.OrchestrationContext, start_val: int): + numbers = [start_val] + current = start_val + for _ in range(10): + current = yield ctx.call_activity(plus_one, input=current) + numbers.append(current) + return numbers + + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(sequence) + w.add_activity(plus_one) + w.start() + + c = client.AsyncTaskHubGrpcClient() + id = await c.schedule_new_orchestration(sequence, input=1) + state = await c.wait_for_orchestration_completion(id, timeout=30) + + assert state is not None + assert state.name == task.get_name(sequence) + assert state.instance_id == id + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.failure_details is None + assert state.serialized_input == json.dumps(1) + assert state.serialized_output == json.dumps([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + + +@pytest.mark.asyncio +async def test_async_wait_for_multiple_external_events(): + def orchestrator(ctx: task.OrchestrationContext, _): + a = yield ctx.wait_for_external_event('A') + b = yield ctx.wait_for_external_event('B') + c = yield ctx.wait_for_external_event('C') + return [a, b, c] + + 
with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(orchestrator) + w.start() + + c = client.AsyncTaskHubGrpcClient() + id = await c.schedule_new_orchestration(orchestrator) + await c.raise_orchestration_event(id, 'A', data='a') + await c.raise_orchestration_event(id, 'B', data='b') + await c.raise_orchestration_event(id, 'C', data='c') + state = await c.wait_for_orchestration_completion(id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps(['a', 'b', 'c']) + + +@pytest.mark.asyncio +async def test_async_suspend_and_resume(): + def orchestrator(ctx: task.OrchestrationContext, _): + result = yield ctx.wait_for_external_event("my_event") + return result + + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(orchestrator) + w.start() + + c = client.AsyncTaskHubGrpcClient() + id = await c.schedule_new_orchestration(orchestrator) + state = await c.wait_for_orchestration_start(id, timeout=30) + assert state is not None + + # Suspend the orchestration and wait for it to go into the SUSPENDED state + await c.suspend_orchestration(id) + while state.runtime_status == client.OrchestrationStatus.RUNNING: + await asyncio.sleep(0.1) + state = await c.get_orchestration_state(id) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.SUSPENDED + + # Raise an event and confirm that it does NOT complete while suspended + await c.raise_orchestration_event(id, "my_event", data=42) + try: + state = await c.wait_for_orchestration_completion(id, timeout=3) + assert False, "Orchestration should not have completed" + except TimeoutError: + pass + + # Resume the orchestration and wait for it to complete + await c.resume_orchestration(id) + state = await c.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == 
json.dumps(42) + + +@pytest.mark.asyncio +async def test_async_terminate(): + def orchestrator(ctx: task.OrchestrationContext, _): + result = yield ctx.wait_for_external_event("my_event") + return result + + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(orchestrator) + w.start() + + c = client.AsyncTaskHubGrpcClient() + id = await c.schedule_new_orchestration(orchestrator) + state = await c.wait_for_orchestration_start(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.RUNNING + + await c.terminate_orchestration(id, output="some reason for termination") + state = await c.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.TERMINATED + assert state.serialized_output == json.dumps("some reason for termination") + + +@pytest.mark.asyncio +async def test_async_purge_orchestration(): + def orchestrator(ctx: task.OrchestrationContext, _): + pass + + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(orchestrator) + w.start() + + c = client.AsyncTaskHubGrpcClient() + id = await c.schedule_new_orchestration(orchestrator) + await c.wait_for_orchestration_completion(id, timeout=30) + + result = await c.purge_orchestration(id) + assert result.deleted_instance_count == 1 + + state = await c.get_orchestration_state(id) + assert state is None From e915ce31a4565675a5ccbfa7d3b98fb22c3d5a06 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 20 Feb 2026 11:48:32 -0700 Subject: [PATCH 2/4] PR feedback --- durabletask/client.py | 26 ++++++++++----- durabletask/internal/grpc_interceptor.py | 4 +-- tests/durabletask/test_client.py | 42 ++++++++++++++++-------- 3 files changed, 48 insertions(+), 24 deletions(-) diff --git a/durabletask/client.py b/durabletask/client.py index 3e581ea6..574a8cdd 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -159,10 +159,15 @@ def __init__(self, *, secure_channel=secure_channel, 
interceptors=interceptors ) + self._channel = channel self._stub = stubs.TaskHubSidecarServiceStub(channel) self._logger = shared.get_logger("client", log_handler, log_formatter) self.default_version = default_version + def close(self) -> None: + """Closes the underlying gRPC channel.""" + self._channel.close() + def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *, input: Optional[TInput] = None, instance_id: Optional[str] = None, @@ -239,7 +244,7 @@ def wait_for_orchestration_completion(self, instance_id: str, *, raise def raise_orchestration_event(self, instance_id: str, event_name: str, *, - data: Optional[Any] = None): + data: Optional[Any] = None) -> None: req = build_raise_event_req(instance_id, event_name, data) self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.") @@ -247,18 +252,18 @@ def raise_orchestration_event(self, instance_id: str, event_name: str, *, def terminate_orchestration(self, instance_id: str, *, output: Optional[Any] = None, - recursive: bool = True): + recursive: bool = True) -> None: req = build_terminate_req(instance_id, output, recursive) self._logger.info(f"Terminating instance '{instance_id}'.") self._stub.TerminateInstance(req) - def suspend_orchestration(self, instance_id: str): + def suspend_orchestration(self, instance_id: str) -> None: req = pb.SuspendRequest(instanceId=instance_id) self._logger.info(f"Suspending instance '{instance_id}'.") self._stub.SuspendInstance(req) - def resume_orchestration(self, instance_id: str): + def resume_orchestration(self, instance_id: str) -> None: req = pb.ResumeRequest(instanceId=instance_id) self._logger.info(f"Resuming instance '{instance_id}'.") self._stub.ResumeInstance(req) @@ -370,10 +375,15 @@ def __init__(self, *, secure_channel=secure_channel, interceptors=interceptors ) + self._channel = channel self._stub = stubs.TaskHubSidecarServiceStub(channel) self._logger = shared.get_logger("client", log_handler, 
log_formatter) self.default_version = default_version + async def close(self) -> None: + """Closes the underlying gRPC channel.""" + await self._channel.close() + async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *, input: Optional[TInput] = None, instance_id: Optional[str] = None, @@ -450,7 +460,7 @@ async def wait_for_orchestration_completion(self, instance_id: str, *, raise async def raise_orchestration_event(self, instance_id: str, event_name: str, *, - data: Optional[Any] = None): + data: Optional[Any] = None) -> None: req = build_raise_event_req(instance_id, event_name, data) self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.") @@ -458,18 +468,18 @@ async def raise_orchestration_event(self, instance_id: str, event_name: str, *, async def terminate_orchestration(self, instance_id: str, *, output: Optional[Any] = None, - recursive: bool = True): + recursive: bool = True) -> None: req = build_terminate_req(instance_id, output, recursive) self._logger.info(f"Terminating instance '{instance_id}'.") await self._stub.TerminateInstance(req) - async def suspend_orchestration(self, instance_id: str): + async def suspend_orchestration(self, instance_id: str) -> None: req = pb.SuspendRequest(instanceId=instance_id) self._logger.info(f"Suspending instance '{instance_id}'.") await self._stub.SuspendInstance(req) - async def resume_orchestration(self, instance_id: str): + async def resume_orchestration(self, instance_id: str) -> None: req = pb.ResumeRequest(instanceId=instance_id) self._logger.info(f"Resuming instance '{instance_id}'.") await self._stub.ResumeInstance(req) diff --git a/durabletask/internal/grpc_interceptor.py b/durabletask/internal/grpc_interceptor.py index 232c31c6..b16af5b0 100644 --- a/durabletask/internal/grpc_interceptor.py +++ b/durabletask/internal/grpc_interceptor.py @@ -22,10 +22,10 @@ class _ClientCallDetails( class _AsyncClientCallDetails( namedtuple( 
'_AsyncClientCallDetails', - ['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready']), + ['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression']), grpc.aio.ClientCallDetails): """This is an implementation of the aio ClientCallDetails interface needed for async interceptors. - This class takes five named values and inherits the ClientCallDetails from grpc.aio package. + This class takes six named values and inherits the ClientCallDetails from grpc.aio package. This class encloses the values that describe a RPC to be invoked. """ pass diff --git a/tests/durabletask/test_client.py b/tests/durabletask/test_client.py index e6ea7c35..006c0987 100644 --- a/tests/durabletask/test_client.py +++ b/tests/durabletask/test_client.py @@ -61,43 +61,53 @@ def test_grpc_channel_with_host_name_protocol_stripping(): prefix = "grpc://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) - mock_insecure_channel.assert_called_with(host_name) + mock_insecure_channel.assert_called_once_with(host_name) + mock_insecure_channel.reset_mock() prefix = "http://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) - mock_insecure_channel.assert_called_with(host_name) + mock_insecure_channel.assert_called_once_with(host_name) + mock_insecure_channel.reset_mock() prefix = "HTTP://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) - mock_insecure_channel.assert_called_with(host_name) + mock_insecure_channel.assert_called_once_with(host_name) + mock_insecure_channel.reset_mock() prefix = "GRPC://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) - mock_insecure_channel.assert_called_with(host_name) + mock_insecure_channel.assert_called_once_with(host_name) + mock_insecure_channel.reset_mock() prefix = "" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) - mock_insecure_channel.assert_called_with(host_name) + mock_insecure_channel.assert_called_once_with(host_name) + 
mock_insecure_channel.reset_mock() prefix = "grpcs://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) - mock_secure_channel.assert_called_with(host_name, ANY) + mock_secure_channel.assert_called_once_with(host_name, ANY) + mock_secure_channel.reset_mock() prefix = "https://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) - mock_secure_channel.assert_called_with(host_name, ANY) + mock_secure_channel.assert_called_once_with(host_name, ANY) + mock_secure_channel.reset_mock() prefix = "HTTPS://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) - mock_secure_channel.assert_called_with(host_name, ANY) + mock_secure_channel.assert_called_once_with(host_name, ANY) + mock_secure_channel.reset_mock() prefix = "GRPCS://" get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) - mock_secure_channel.assert_called_with(host_name, ANY) + mock_secure_channel.assert_called_once_with(host_name, ANY) + mock_secure_channel.reset_mock() prefix = "" get_grpc_channel(prefix + host_name, True, interceptors=INTERCEPTORS) - mock_secure_channel.assert_called_with(host_name, ANY) + mock_secure_channel.assert_called_once_with(host_name, ANY) + mock_secure_channel.reset_mock() # ==== Async channel tests ==== @@ -136,16 +146,20 @@ def test_async_grpc_channel_protocol_stripping(): host_name = "myserver.com:1234" get_async_grpc_channel("http://" + host_name) - mock_insecure.assert_called_with(host_name, interceptors=None) + mock_insecure.assert_called_once_with(host_name, interceptors=None) + mock_insecure.reset_mock() get_async_grpc_channel("grpc://" + host_name) - mock_insecure.assert_called_with(host_name, interceptors=None) + mock_insecure.assert_called_once_with(host_name, interceptors=None) + mock_insecure.reset_mock() get_async_grpc_channel("https://" + host_name) - mock_secure.assert_called_with(host_name, ANY, interceptors=None) + mock_secure.assert_called_once_with(host_name, ANY, interceptors=None) + mock_secure.reset_mock() 
get_async_grpc_channel("grpcs://" + host_name) - mock_secure.assert_called_with(host_name, ANY, interceptors=None) + mock_secure.assert_called_once_with(host_name, ANY, interceptors=None) + mock_secure.reset_mock() # ==== Async client construction tests ==== From f5367b6e9fd823baff6a88ad91be69537db520db Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 20 Feb 2026 11:59:45 -0700 Subject: [PATCH 3/4] Support restart in async client --- durabletask/client.py | 20 ++++++ .../test_orchestration_async_e2e.py | 63 +++++++++++++++++++ 2 files changed, 83 insertions(+) diff --git a/durabletask/client.py b/durabletask/client.py index 9b26b3b1..3787ff12 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -504,6 +504,26 @@ async def resume_orchestration(self, instance_id: str) -> None: self._logger.info(f"Resuming instance '{instance_id}'.") await self._stub.ResumeInstance(req) + async def restart_orchestration(self, instance_id: str, *, + restart_with_new_instance_id: bool = False) -> str: + """Restarts an existing orchestration instance. + + Args: + instance_id: The ID of the orchestration instance to restart. + restart_with_new_instance_id: If True, the restarted orchestration will use a new instance ID. + If False (default), the restarted orchestration will reuse the same instance ID. + + Returns: + The instance ID of the restarted orchestration. 
+ """ + req = pb.RestartInstanceRequest( + instanceId=instance_id, + restartWithNewInstanceId=restart_with_new_instance_id) + + self._logger.info(f"Restarting instance '{instance_id}'.") + res: pb.RestartInstanceResponse = await self._stub.RestartInstance(req) + return res.instanceId + async def purge_orchestration(self, instance_id: str, recursive: bool = True) -> PurgeInstancesResult: req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive) self._logger.info(f"Purging instance '{instance_id}'.") diff --git a/tests/durabletask/test_orchestration_async_e2e.py b/tests/durabletask/test_orchestration_async_e2e.py index bec3b12e..ee4a216a 100644 --- a/tests/durabletask/test_orchestration_async_e2e.py +++ b/tests/durabletask/test_orchestration_async_e2e.py @@ -178,3 +178,66 @@ def orchestrator(ctx: task.OrchestrationContext, _): state = await c.get_orchestration_state(id) assert state is None + + +@pytest.mark.asyncio +@pytest.mark.skip(reason="durabletask-go does not yet support RestartInstance") +async def test_async_restart_with_same_instance_id(): + def orchestrator(ctx: task.OrchestrationContext, _): + result = yield ctx.call_activity(say_hello, input="World") + return result + + def say_hello(ctx: task.ActivityContext, input: str): + return f"Hello, {input}!" 
+ + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(orchestrator) + w.add_activity(say_hello) + w.start() + + c = client.AsyncTaskHubGrpcClient() + id = await c.schedule_new_orchestration(orchestrator) + state = await c.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps("Hello, World!") + + # Restart the orchestration with the same instance ID + restarted_id = await c.restart_orchestration(id) + assert restarted_id == id + + state = await c.wait_for_orchestration_completion(restarted_id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps("Hello, World!") + + +@pytest.mark.asyncio +@pytest.mark.skip(reason="durabletask-go does not yet support RestartInstance") +async def test_async_restart_with_new_instance_id(): + def orchestrator(ctx: task.OrchestrationContext, _): + result = yield ctx.call_activity(say_hello, input="World") + return result + + def say_hello(ctx: task.ActivityContext, input: str): + return f"Hello, {input}!" 
+ + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(orchestrator) + w.add_activity(say_hello) + w.start() + + c = client.AsyncTaskHubGrpcClient() + id = await c.schedule_new_orchestration(orchestrator) + state = await c.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + + # Restart the orchestration with a new instance ID + restarted_id = await c.restart_orchestration(id, restart_with_new_instance_id=True) + assert restarted_id != id + + state = await c.wait_for_orchestration_completion(restarted_id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps("Hello, World!") From aa9d198727adb442ccac44245ca98ca049bf117f Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 20 Feb 2026 12:40:43 -0700 Subject: [PATCH 4/4] Async client for durabletask-azuremanged --- durabletask-azuremanaged/CHANGELOG.md | 5 + .../durabletask/azuremanaged/client.py | 65 ++- .../internal/durabletask_grpc_interceptor.py | 43 ++ durabletask/internal/grpc_interceptor.py | 2 +- .../test_dts_async_orchestration_e2e.py | 542 ++++++++++++++++++ 5 files changed, 655 insertions(+), 2 deletions(-) create mode 100644 tests/durabletask-azuremanaged/test_dts_async_orchestration_e2e.py diff --git a/durabletask-azuremanaged/CHANGELOG.md b/durabletask-azuremanaged/CHANGELOG.md index 6fb231be..84e6dbba 100644 --- a/durabletask-azuremanaged/CHANGELOG.md +++ b/durabletask-azuremanaged/CHANGELOG.md @@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
+## Unreleased + +- Added `AsyncDurableTaskSchedulerClient` for async/await usage with `grpc.aio` +- Added `DTSAsyncDefaultClientInterceptorImpl` async gRPC interceptor for DTS authentication + ## v1.3.0 - Updates base dependency to durabletask v1.3.0 diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/client.py b/durabletask-azuremanaged/durabletask/azuremanaged/client.py index 50612e0c..c31c262e 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/client.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/client.py @@ -8,9 +8,10 @@ from azure.core.credentials import TokenCredential from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import ( + DTSAsyncDefaultClientInterceptorImpl, DTSDefaultClientInterceptorImpl, ) -from durabletask.client import TaskHubGrpcClient +from durabletask.client import AsyncTaskHubGrpcClient, TaskHubGrpcClient # Client class used for Durable Task Scheduler (DTS) @@ -39,3 +40,65 @@ def __init__(self, *, log_formatter=log_formatter, interceptors=interceptors, default_version=default_version) + + +# Async client class used for Durable Task Scheduler (DTS) +class AsyncDurableTaskSchedulerClient(AsyncTaskHubGrpcClient): + """An async client implementation for Azure Durable Task Scheduler (DTS). + + This class extends AsyncTaskHubGrpcClient to provide integration with Azure's + Durable Task Scheduler service using async gRPC. It handles authentication via + Azure credentials and configures the necessary gRPC interceptors for DTS + communication. + + Args: + host_address (str): The gRPC endpoint address of the DTS service. + taskhub (str): The name of the task hub. Cannot be empty. + token_credential (Optional[TokenCredential]): Azure credential for authentication. + If None, anonymous authentication will be used. + secure_channel (bool, optional): Whether to use a secure gRPC channel (TLS). + Defaults to True. + default_version (Optional[str], optional): Default version string for orchestrations. 
+        log_handler (Optional[logging.Handler], optional): Custom logging handler for client logs.
+        log_formatter (Optional[logging.Formatter], optional): Custom log formatter for client logs.
+
+    Raises:
+        ValueError: If taskhub is empty or None.
+
+    Example:
+        >>> from azure.identity import DefaultAzureCredential
+        >>> from durabletask.azuremanaged import AsyncDurableTaskSchedulerClient
+        >>>
+        >>> credential = DefaultAzureCredential()
+        >>> client = AsyncDurableTaskSchedulerClient(
+        ...     host_address="my-dts-service.azure.com:443",
+        ...     taskhub="my-task-hub",
+        ...     token_credential=credential)
+        >>> instance_id = await client.schedule_new_orchestration("my_orchestrator")
+        >>> await client.close()
+    """
+
+    def __init__(self, *,
+                 host_address: str,
+                 taskhub: str,
+                 token_credential: Optional[TokenCredential],
+                 secure_channel: bool = True,
+                 default_version: Optional[str] = None,
+                 log_handler: Optional[logging.Handler] = None,
+                 log_formatter: Optional[logging.Formatter] = None):
+
+        if not taskhub:
+            raise ValueError("Taskhub value cannot be empty. 
Please provide a value for your taskhub")
+
+        interceptors = [DTSAsyncDefaultClientInterceptorImpl(token_credential, taskhub)]
+
+        # We pass in None for the metadata so we don't construct an additional interceptor in the parent class
+        # Since the parent class doesn't use the metadata for anything else, we can set it as None
+        super().__init__(
+            host_address=host_address,
+            secure_channel=secure_channel,
+            metadata=None,
+            log_handler=log_handler,
+            log_formatter=log_formatter,
+            interceptors=interceptors,
+            default_version=default_version)
diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/internal/durabletask_grpc_interceptor.py b/durabletask-azuremanaged/durabletask/azuremanaged/internal/durabletask_grpc_interceptor.py
index fa1459f4..e683b180 100644
--- a/durabletask-azuremanaged/durabletask/azuremanaged/internal/durabletask_grpc_interceptor.py
+++ b/durabletask-azuremanaged/durabletask/azuremanaged/internal/durabletask_grpc_interceptor.py
@@ -9,7 +9,9 @@
 from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager
 from durabletask.internal.grpc_interceptor import (
+    DefaultAsyncClientInterceptorImpl,
     DefaultClientInterceptorImpl,
+    _AsyncClientCallDetails,
     _ClientCallDetails,
 )
@@ -52,3 +54,44 @@ def _intercept_call(
             self._metadata[i] = ("authorization", f"Bearer {new_token.token}")  # Update the token
 
         return super()._intercept_call(client_call_details)
+
+
+class DTSAsyncDefaultClientInterceptorImpl(DefaultAsyncClientInterceptorImpl):
+    """Async version of DTSDefaultClientInterceptorImpl for use with grpc.aio channels.
class DTSAsyncDefaultClientInterceptorImpl(DefaultAsyncClientInterceptorImpl):
    """Async (grpc.aio) counterpart of DTSDefaultClientInterceptorImpl.

    Attaches the DTS-specific metadata — task hub name, user agent and, when a
    credential is supplied, a bearer auth token — to every async gRPC call.
    """

    def __init__(self, token_credential: Optional[TokenCredential], taskhub_name: str):
        try:
            # Version of the durabletask-azuremanaged package, used in the user agent.
            sdk_version = version('durabletask-azuremanaged')
        except Exception:
            # Fall back when the package version cannot be determined.
            sdk_version = "unknown"
        self._metadata = [
            ("taskhub", taskhub_name),
            ("x-user-agent", f"durabletask-python/{sdk_version}")]
        super().__init__(self._metadata)

        if token_credential is not None:
            self._token_credential = token_credential
            self._token_manager = AccessTokenManager(token_credential=self._token_credential)
            access_token = self._token_manager.get_access_token()
            if access_token is not None:
                self._metadata.append(("authorization", f"Bearer {access_token.token}"))

    def _intercept_call(
            self, client_call_details: _AsyncClientCallDetails) -> grpc.aio.ClientCallDetails:
        """Refresh the bearer token (when one is present) and delegate to the base
        implementation, which merges ``self._metadata`` into the call details."""
        if self._metadata is not None:
            for index, (key, _) in enumerate(self._metadata):
                # Case-insensitive match: only the auth header ever goes stale.
                if key.lower() != "authorization":
                    continue
                refreshed = self._token_manager.get_access_token()
                if refreshed is not None:
                    self._metadata[index] = ("authorization", f"Bearer {refreshed.token}")

        return super()._intercept_call(client_call_details)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""End-to-end tests for AsyncDurableTaskSchedulerClient against a running sidecar."""

import json
import os
import threading
import uuid
from datetime import timedelta

import pytest

from durabletask import client, task
from durabletask.azuremanaged.client import AsyncDurableTaskSchedulerClient
from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker

# NOTE: These tests assume a sidecar process is running. Example command:
# docker run -i -p 8080:8080 -p 8082:8082 -d mcr.microsoft.com/dts/dts-emulator:latest
pytestmark = [pytest.mark.dts, pytest.mark.asyncio]

# Read the environment variables
taskhub_name = os.getenv("TASKHUB", "default")
endpoint = os.getenv("ENDPOINT", "http://localhost:8080")


def _get_credential():
    """Returns DefaultAzureCredential if endpoint is https, otherwise None (for emulator)."""
    if endpoint.startswith("https://"):
        from azure.identity import DefaultAzureCredential
        return DefaultAzureCredential()
    return None


async def test_empty_orchestration():
    """A no-op orchestrator completes with no input, output, or custom status."""

    invoked = False

    def empty_orchestrator(ctx: task.OrchestrationContext, _):
        nonlocal invoked  # don't do this in a real app!
        invoked = True

    # Start a worker, which will connect to the sidecar in a background thread
    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name, token_credential=None) as w:
        w.add_orchestrator(empty_orchestrator)
        w.start()

        c = AsyncDurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                            taskhub=taskhub_name, token_credential=None)
        id = await c.schedule_new_orchestration(empty_orchestrator)
        state = await c.wait_for_orchestration_completion(id, timeout=30)

        assert invoked
        assert state is not None
        assert state.name == task.get_name(empty_orchestrator)
        assert state.instance_id == id
        assert state.failure_details is None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert state.serialized_input is None
        assert state.serialized_output is None
        assert state.serialized_custom_status is None


async def test_activity_sequence():
    """A chain of activity calls threads its result through each step."""

    def plus_one(_: task.ActivityContext, input: int) -> int:
        return input + 1

    def sequence(ctx: task.OrchestrationContext, start_val: int):
        numbers = [start_val]
        current = start_val
        for _ in range(10):
            current = yield ctx.call_activity(plus_one, input=current)
            numbers.append(current)
        return numbers

    # Start a worker, which will connect to the sidecar in a background thread
    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name, token_credential=None) as w:
        w.add_orchestrator(sequence)
        w.add_activity(plus_one)
        w.start()

        task_hub_client = AsyncDurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                                          taskhub=taskhub_name, token_credential=None)
        id = await task_hub_client.schedule_new_orchestration(sequence, input=1)
        state = await task_hub_client.wait_for_orchestration_completion(id, timeout=30)

        assert state is not None
        assert state.name == task.get_name(sequence)
        assert state.instance_id == id
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert state.failure_details is None
        assert state.serialized_input == json.dumps(1)
        assert state.serialized_output == json.dumps([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
        assert state.serialized_custom_status is None


async def test_activity_error_handling():
    """A failed activity surfaces as TaskFailedError and compensation still runs."""

    def throw(_: task.ActivityContext, input: int) -> int:
        raise RuntimeError("Kah-BOOOOM!!!")

    compensation_counter = 0

    def increment_counter(ctx, _):
        nonlocal compensation_counter
        compensation_counter += 1

    def orchestrator(ctx: task.OrchestrationContext, input: int):
        error_msg = ""
        try:
            yield ctx.call_activity(throw, input=input)
        except task.TaskFailedError as e:
            error_msg = e.details.message

            # compensating actions
            yield ctx.call_activity(increment_counter)
            yield ctx.call_activity(increment_counter)

        return error_msg

    # Start a worker, which will connect to the sidecar in a background thread
    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name, token_credential=None) as w:
        w.add_orchestrator(orchestrator)
        w.add_activity(throw)
        w.add_activity(increment_counter)
        w.start()

        task_hub_client = AsyncDurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                                          taskhub=taskhub_name, token_credential=None)
        id = await task_hub_client.schedule_new_orchestration(orchestrator, input=1)
        state = await task_hub_client.wait_for_orchestration_completion(id, timeout=30)

        assert state is not None
        assert state.name == task.get_name(orchestrator)
        assert state.instance_id == id
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert state.serialized_output == json.dumps("Kah-BOOOOM!!!")
        assert state.failure_details is None
        assert state.serialized_custom_status is None
        assert compensation_counter == 2


async def test_sub_orchestration_fan_out():
    """Fanning out to sub-orchestrations runs every child activity exactly once."""
    threadLock = threading.Lock()
    activity_counter = 0

    def increment(ctx, _):
        with threadLock:
            nonlocal activity_counter
            activity_counter += 1

    def orchestrator_child(ctx: task.OrchestrationContext, activity_count: int):
        for _ in range(activity_count):
            yield ctx.call_activity(increment)

    def parent_orchestrator(ctx: task.OrchestrationContext, count: int):
        # Fan out to multiple sub-orchestrations
        tasks = []
        for _ in range(count):
            tasks.append(ctx.call_sub_orchestrator(orchestrator_child, input=3))
        # Wait for all sub-orchestrations to complete
        yield task.when_all(tasks)

    # Start a worker, which will connect to the sidecar in a background thread
    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name, token_credential=None) as w:
        w.add_activity(increment)
        w.add_orchestrator(orchestrator_child)
        w.add_orchestrator(parent_orchestrator)
        w.start()

        task_hub_client = AsyncDurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                                          taskhub=taskhub_name, token_credential=None)
        id = await task_hub_client.schedule_new_orchestration(parent_orchestrator, input=10)
        state = await task_hub_client.wait_for_orchestration_completion(id, timeout=30)

        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert state.failure_details is None
        assert activity_counter == 30


async def test_sub_orchestrator_by_name():
    """Sub-orchestrators can be invoked by their registered string name."""
    sub_orchestrator_counter = 0

    def orchestrator_child(ctx: task.OrchestrationContext, _):
        nonlocal sub_orchestrator_counter
        sub_orchestrator_counter += 1

    def parent_orchestrator(ctx: task.OrchestrationContext, _):
        yield ctx.call_sub_orchestrator("orchestrator_child")

    # Start a worker, which will connect to the sidecar in a background thread
    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name, token_credential=None) as w:
        w.add_orchestrator(orchestrator_child)
        w.add_orchestrator(parent_orchestrator)
        w.start()

        task_hub_client = AsyncDurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                                          taskhub=taskhub_name, token_credential=None)
        id = await task_hub_client.schedule_new_orchestration(parent_orchestrator, input=None)
        state = await task_hub_client.wait_for_orchestration_completion(id, timeout=30)

        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert state.failure_details is None
        assert sub_orchestrator_counter == 1


async def test_wait_for_multiple_external_events():
    """Events raised before the orchestrator waits are still delivered in order."""
    def orchestrator(ctx: task.OrchestrationContext, _):
        a = yield ctx.wait_for_external_event('A')
        b = yield ctx.wait_for_external_event('B')
        c = yield ctx.wait_for_external_event('C')
        return [a, b, c]

    # Start a worker, which will connect to the sidecar in a background thread
    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name, token_credential=None) as w:
        w.add_orchestrator(orchestrator)
        w.start()

        # Start the orchestration and immediately raise events to it.
        task_hub_client = AsyncDurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                                          taskhub=taskhub_name, token_credential=None)
        id = await task_hub_client.schedule_new_orchestration(orchestrator)
        await task_hub_client.raise_orchestration_event(id, 'A', data='a')
        await task_hub_client.raise_orchestration_event(id, 'B', data='b')
        await task_hub_client.raise_orchestration_event(id, 'C', data='c')
        state = await task_hub_client.wait_for_orchestration_completion(id, timeout=30)

        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert state.serialized_output == json.dumps(['a', 'b', 'c'])


async def test_terminate():
    """Terminating a running orchestration records the supplied output."""
    def orchestrator(ctx: task.OrchestrationContext, _):
        result = yield ctx.wait_for_external_event("my_event")
        return result

    # Start a worker, which will connect to the sidecar in a background thread
    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name, token_credential=None) as w:
        w.add_orchestrator(orchestrator)
        w.start()

        task_hub_client = AsyncDurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                                          taskhub=taskhub_name, token_credential=None)
        id = await task_hub_client.schedule_new_orchestration(orchestrator)
        state = await task_hub_client.wait_for_orchestration_start(id, timeout=30)
        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.RUNNING

        await task_hub_client.terminate_orchestration(id, output="some reason for termination")
        state = await task_hub_client.wait_for_orchestration_completion(id, timeout=30)
        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.TERMINATED
        assert state.serialized_output == json.dumps("some reason for termination")


async def test_terminate_recursive():
    """Terminating a parent also terminates its pending sub-orchestration."""
    def root(ctx: task.OrchestrationContext, _):
        result = yield ctx.call_sub_orchestrator(child)
        return result

    def child(ctx: task.OrchestrationContext, _):
        result = yield ctx.wait_for_external_event("my_event")
        return result

    # Start a worker, which will connect to the sidecar in a background thread
    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name, token_credential=None) as w:
        w.add_orchestrator(root)
        w.add_orchestrator(child)
        w.start()

        task_hub_client = AsyncDurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                                          taskhub=taskhub_name, token_credential=None)
        id = await task_hub_client.schedule_new_orchestration(root)
        state = await task_hub_client.wait_for_orchestration_start(id, timeout=30)
        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.RUNNING

        # Terminate root orchestration (recursive is set to True by default)
        await task_hub_client.terminate_orchestration(id, output="some reason for termination")
        state = await task_hub_client.wait_for_orchestration_completion(id, timeout=30)
        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.TERMINATED

        # Re-query the root orchestration and confirm the termination stuck.
        # BUG FIX: the result of this wait was previously discarded, so the
        # asserts below silently re-checked the stale `state` from above.
        # NOTE(review): the root's TERMINATED status implies the child no longer
        # blocks it; the child's own instance id is not exposed here — confirm
        # whether the sidecar surfaces it if a direct child check is wanted.
        state = await task_hub_client.wait_for_orchestration_completion(id, timeout=30)
        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.TERMINATED

        await task_hub_client.purge_orchestration(id)
        state = await task_hub_client.get_orchestration_state(id)
        assert state is None


async def test_restart_with_same_instance_id():
    """Restarting an orchestration reuses its instance id and reruns it."""
    def orchestrator(ctx: task.OrchestrationContext, _):
        result = yield ctx.call_activity(say_hello, input="World")
        return result

    def say_hello(ctx: task.ActivityContext, input: str):
        return f"Hello, {input}!"

    credential = _get_credential()

    # Start a worker, which will connect to the sidecar in a background thread
    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name, token_credential=credential) as w:
        w.add_orchestrator(orchestrator)
        w.add_activity(say_hello)
        w.start()

        task_hub_client = AsyncDurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                                          taskhub=taskhub_name, token_credential=credential)
        id = await task_hub_client.schedule_new_orchestration(orchestrator)
        state = await task_hub_client.wait_for_orchestration_completion(id, timeout=30)
        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert state.serialized_output == json.dumps("Hello, World!")

        # Restart the orchestration with the same instance ID
        restarted_id = await task_hub_client.restart_orchestration(id)
        assert restarted_id == id

        state = await task_hub_client.wait_for_orchestration_completion(restarted_id, timeout=30)
        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert state.serialized_output == json.dumps("Hello, World!")


async def test_restart_with_new_instance_id():
    """Restarting with a new instance id produces a distinct, successful run."""
    def orchestrator(ctx: task.OrchestrationContext, _):
        result = yield ctx.call_activity(say_hello, input="World")
        return result

    def say_hello(ctx: task.ActivityContext, input: str):
        return f"Hello, {input}!"

    credential = _get_credential()

    # Start a worker, which will connect to the sidecar in a background thread
    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name, token_credential=credential) as w:
        w.add_orchestrator(orchestrator)
        w.add_activity(say_hello)
        w.start()

        task_hub_client = AsyncDurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                                          taskhub=taskhub_name, token_credential=credential)
        id = await task_hub_client.schedule_new_orchestration(orchestrator)
        state = await task_hub_client.wait_for_orchestration_completion(id, timeout=30)
        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED

        # Restart the orchestration with a new instance ID
        restarted_id = await task_hub_client.restart_orchestration(id, restart_with_new_instance_id=True)
        assert restarted_id != id

        state = await task_hub_client.wait_for_orchestration_completion(restarted_id, timeout=30)
        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert state.serialized_output == json.dumps("Hello, World!")


async def test_retry_policies():
    """Nested retry policies retry both the sub-orchestration and its activity."""
    child_orch_counter = 0
    throw_activity_counter = 0

    retry_policy = task.RetryPolicy(
        first_retry_interval=timedelta(seconds=1),
        max_number_of_attempts=3,
        backoff_coefficient=1,
        max_retry_interval=timedelta(seconds=10),
        retry_timeout=timedelta(seconds=30))

    def parent_orchestrator_with_retry(ctx: task.OrchestrationContext, _):
        yield ctx.call_sub_orchestrator(child_orchestrator_with_retry, retry_policy=retry_policy)

    def child_orchestrator_with_retry(ctx: task.OrchestrationContext, _):
        nonlocal child_orch_counter
        if not ctx.is_replaying:
            child_orch_counter += 1
        yield ctx.call_activity(throw_activity_with_retry, retry_policy=retry_policy)

    def throw_activity_with_retry(ctx: task.ActivityContext, _):
        nonlocal throw_activity_counter
        throw_activity_counter += 1
        raise RuntimeError("Kah-BOOOOM!!!")

    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name, token_credential=None) as w:
        w.add_orchestrator(parent_orchestrator_with_retry)
        w.add_orchestrator(child_orchestrator_with_retry)
        w.add_activity(throw_activity_with_retry)
        w.start()

        task_hub_client = AsyncDurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                                          taskhub=taskhub_name, token_credential=None)
        id = await task_hub_client.schedule_new_orchestration(parent_orchestrator_with_retry)
        state = await task_hub_client.wait_for_orchestration_completion(id, timeout=30)
        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.FAILED
        assert state.failure_details is not None
        assert state.failure_details.error_type == "TaskFailedError"
        assert state.failure_details.message.startswith("Sub-orchestration task #1 failed:")
        assert state.failure_details.message.endswith("Activity task #1 failed: Kah-BOOOOM!!!")
        assert state.failure_details.stack_trace is not None
        # 3 child attempts x 3 activity attempts each
        assert throw_activity_counter == 9
        assert child_orch_counter == 3


async def test_retry_timeout():
    """retry_timeout cuts retries short before max_number_of_attempts is reached."""
    throw_activity_counter = 0
    retry_policy = task.RetryPolicy(
        first_retry_interval=timedelta(seconds=1),
        max_number_of_attempts=5,
        backoff_coefficient=2,
        max_retry_interval=timedelta(seconds=10),
        retry_timeout=timedelta(seconds=14))

    def mock_orchestrator(ctx: task.OrchestrationContext, _):
        yield ctx.call_activity(throw_activity, retry_policy=retry_policy)

    def throw_activity(ctx: task.ActivityContext, _):
        nonlocal throw_activity_counter
        throw_activity_counter += 1
        raise RuntimeError("Kah-BOOOOM!!!")

    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name, token_credential=None) as w:
        w.add_orchestrator(mock_orchestrator)
        w.add_activity(throw_activity)
        w.start()

        task_hub_client = AsyncDurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                                          taskhub=taskhub_name, token_credential=None)
        id = await task_hub_client.schedule_new_orchestration(mock_orchestrator)
        state = await task_hub_client.wait_for_orchestration_completion(id, timeout=30)
        assert state is not None
        assert state.runtime_status == client.OrchestrationStatus.FAILED
        assert state.failure_details is not None
        assert state.failure_details.error_type == "TaskFailedError"
        assert state.failure_details.message.endswith("Activity task #1 failed: Kah-BOOOOM!!!")
        assert state.failure_details.stack_trace is not None
        # Retries at +1s, +2s, +4s fit inside the 14s window; the next would not.
        assert throw_activity_counter == 4


async def test_custom_status():
    """set_custom_status round-trips through the orchestration state as JSON."""

    def empty_orchestrator(ctx: task.OrchestrationContext, _):
        ctx.set_custom_status("foobaz")

    # Start a worker, which will connect to the sidecar in a background thread
    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name, token_credential=None) as w:
        w.add_orchestrator(empty_orchestrator)
        w.start()

        c = AsyncDurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                            taskhub=taskhub_name, token_credential=None)
        id = await c.schedule_new_orchestration(empty_orchestrator)
        state = await c.wait_for_orchestration_completion(id, timeout=30)

        assert state is not None
        assert state.name == task.get_name(empty_orchestrator)
        assert state.instance_id == id
        assert state.failure_details is None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert state.serialized_input is None
        assert state.serialized_output is None
        assert state.serialized_custom_status == "\"foobaz\""


async def test_new_uuid():
    """new_uuid yields unique values, including across a replay boundary."""
    def noop(_: task.ActivityContext, _1):
        pass

    def empty_orchestrator(ctx: task.OrchestrationContext, _):
        # Assert that two new_uuid calls return different values
        results = [ctx.new_uuid(), ctx.new_uuid()]
        yield ctx.call_activity("noop")
        # Assert that new_uuid still returns a unique value after replay
        results.append(ctx.new_uuid())
        return results

    # Start a worker, which will connect to the sidecar in a background thread
    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name, token_credential=None) as w:
        w.add_orchestrator(empty_orchestrator)
        w.add_activity(noop)
        w.start()

        c = AsyncDurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                            taskhub=taskhub_name, token_credential=None)
        id = await c.schedule_new_orchestration(empty_orchestrator)
        state = await c.wait_for_orchestration_completion(id, timeout=30)

        assert state is not None
        assert state.name == task.get_name(empty_orchestrator)
        assert state.instance_id == id
        assert state.failure_details is None
        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
        results = json.loads(state.serialized_output or "\"\"")
        assert isinstance(results, list) and len(results) == 3
        assert uuid.UUID(results[0]) != uuid.UUID(results[1])
        assert uuid.UUID(results[0]) != uuid.UUID(results[2])
        assert uuid.UUID(results[1]) != uuid.UUID(results[2])


async def test_orchestration_with_unparsable_output_fails():
    """A non-JSON-serializable return value fails the orchestration cleanly."""
    def test_orchestrator(ctx: task.OrchestrationContext, _):
        return Exception("This is not JSON serializable")

    # Start a worker, which will connect to the sidecar in a background thread
    with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
                                    taskhub=taskhub_name, token_credential=None) as w:
        w.add_orchestrator(test_orchestrator)
        w.start()

        c = AsyncDurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
                                            taskhub=taskhub_name, token_credential=None)
        id = await c.schedule_new_orchestration(test_orchestrator)
        state = await c.wait_for_orchestration_completion(id, timeout=30)

        assert state is not None
        assert state.name == task.get_name(test_orchestrator)
        assert state.instance_id == id
        assert state.failure_details is not None
        assert state.failure_details.error_type == "JsonEncodeOutputException"
        assert state.failure_details.message.startswith("The orchestration result could not be encoded. Object details:")
        assert state.failure_details.message.find("This is not JSON serializable") != -1
        assert state.runtime_status == client.OrchestrationStatus.FAILED