From 6941750bc905b44128ddf860bf243da42dfc4db4 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Thu, 21 May 2026 11:50:58 -0700 Subject: [PATCH 1/7] Narrow overloads on the sano client. --- temporalio/client/_nexus.py | 139 ++++++++++++++++++++------ tests/nexus/test_nexus_type_errors.py | 24 ----- 2 files changed, 106 insertions(+), 57 deletions(-) diff --git a/temporalio/client/_nexus.py b/temporalio/client/_nexus.py index 991ab34a3..0d8a26c40 100644 --- a/temporalio/client/_nexus.py +++ b/temporalio/client/_nexus.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Awaitable, Callable, Mapping, Sequence from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Generic, cast, overload @@ -473,7 +473,7 @@ class NexusClient(ABC, Generic[NexusServiceType]): Use :py:meth:`Client.create_nexus_client` to create a client. """ - # Overload for nexusrpc.Operation with input + # Overload for nexusrpc.Operation @overload @abstractmethod async def start_operation( @@ -484,6 +484,7 @@ async def start_operation( id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, + result_type: type[OutputT] | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, start_to_close_timeout: timedelta | None = None, @@ -494,18 +495,18 @@ async def start_operation( rpc_timeout: timedelta | None = None, ) -> NexusOperationHandle[OutputT]: ... - # Overload for Callable with result_type + # Overload for string operation name @overload @abstractmethod async def start_operation( self, - operation: Callable[..., Any], + operation: str, arg: Any, *, id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, - result_type: type[OutputT], + result_type: type[OutputT] | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, start_to_close_timeout: timedelta | None = None, @@ -516,17 +517,21 @@ async def start_operation( rpc_timeout: timedelta | None = None, ) -> NexusOperationHandle[OutputT]: ... - # Overload for Callable without result_type + # Overload for workflow_run_operation methods @overload @abstractmethod async def start_operation( self, - operation: Callable[..., Any], - arg: Any, + operation: Callable[ + [NexusServiceType, temporalio.nexus.WorkflowRunOperationContext, InputT], + Awaitable[temporalio.nexus.WorkflowHandle[OutputT]], + ], + arg: InputT, *, id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, + result_type: type[OutputT] | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, start_to_close_timeout: timedelta | None = None, @@ -535,20 +540,23 @@ async def start_operation( headers: Mapping[str, str] | None = None, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, - ) -> NexusOperationHandle[Any]: ... + ) -> NexusOperationHandle[OutputT]: ... - # Overload for str with result_type + # Overload for sync_operation methods (async def) @overload @abstractmethod async def start_operation( self, - operation: str, - arg: Any, + operation: Callable[ + [NexusServiceType, nexusrpc.handler.StartOperationContext, InputT], + Awaitable[OutputT], + ], + arg: InputT, *, id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, - result_type: type[OutputT], + result_type: type[OutputT] | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, start_to_close_timeout: timedelta | None = None, @@ -559,17 +567,45 @@ async def start_operation( rpc_timeout: timedelta | None = None, ) -> NexusOperationHandle[OutputT]: ... - # Overload for str without result_type + # Overload for sync_operation methods (def) @overload @abstractmethod async def start_operation( self, - operation: str, - arg: Any, + operation: Callable[ + [NexusServiceType, nexusrpc.handler.StartOperationContext, InputT], + OutputT, + ], + arg: InputT, + *, + id: str, + id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, + result_type: type[OutputT] | None = None, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + headers: Mapping[str, str] | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> NexusOperationHandle[OutputT]: ... + + # Overload for operation_handler + @overload + @abstractmethod + async def start_operation( + self, + operation: Callable[ + [NexusServiceType], nexusrpc.handler.OperationHandler[InputT, OutputT] + ], + arg: InputT, *, id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, + result_type: type[OutputT] | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, start_to_close_timeout: timedelta | None = None, @@ -578,7 +614,7 @@ async def start_operation( headers: Mapping[str, str] | None = None, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, - ) -> NexusOperationHandle[Any]: ... + ) -> NexusOperationHandle[OutputT]: ... @abstractmethod async def start_operation( @@ -633,7 +669,7 @@ async def start_operation( """ ... - # Overload for nexusrpc.Operation with input + # Overload for nexusrpc.Operation @overload @abstractmethod async def execute_operation( @@ -644,6 +680,7 @@ async def execute_operation( id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, + result_type: type[OutputT] | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, start_to_close_timeout: timedelta | None = None, @@ -654,18 +691,18 @@ async def execute_operation( rpc_timeout: timedelta | None = None, ) -> OutputT: ... - # Overload for Callable with result_type + # Overload for string operation name @overload @abstractmethod async def execute_operation( self, - operation: Callable[..., Any], + operation: str, arg: Any, *, id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, - result_type: type[OutputT], + result_type: type[OutputT] | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, start_to_close_timeout: timedelta | None = None, @@ -676,17 +713,21 @@ async def execute_operation( rpc_timeout: timedelta | None = None, ) -> OutputT: ... - # Overload for Callable without result_type + # Overload for workflow_run_operation methods @overload @abstractmethod async def execute_operation( self, - operation: Callable[..., Any], - arg: Any, + operation: Callable[ + [NexusServiceType, temporalio.nexus.WorkflowRunOperationContext, InputT], + Awaitable[temporalio.nexus.WorkflowHandle[OutputT]], + ], + arg: InputT, *, id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, + result_type: type[OutputT] | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, start_to_close_timeout: timedelta | None = None, @@ -695,20 +736,23 @@ async def execute_operation( headers: Mapping[str, str] | None = None, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, - ) -> Any: ... + ) -> OutputT: ... - # Overload for str with result_type + # Overload for sync_operation methods (async def) @overload @abstractmethod async def execute_operation( self, - operation: str, - arg: Any, + operation: Callable[ + [NexusServiceType, nexusrpc.handler.StartOperationContext, InputT], + Awaitable[OutputT], + ], + arg: InputT, *, id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, - result_type: type[OutputT], + result_type: type[OutputT] | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, start_to_close_timeout: timedelta | None = None, @@ -719,17 +763,46 @@ async def execute_operation( rpc_timeout: timedelta | None = None, ) -> OutputT: ... - # Overload for str without result_type + # Overload for sync_operation methods (async def) @overload @abstractmethod async def execute_operation( self, - operation: str, - arg: Any, + operation: Callable[ + [NexusServiceType, nexusrpc.handler.StartOperationContext, InputT], + OutputT, + ], + arg: InputT, + *, + id: str, + id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, + result_type: type[OutputT] | None = None, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + headers: Mapping[str, str] | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> OutputT: ... + + # Overload for operation_handler + @overload + @abstractmethod + async def execute_operation( + self, + operation: Callable[ + [NexusServiceType], + nexusrpc.handler.OperationHandler[InputT, OutputT], + ], + arg: InputT, *, id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, + result_type: type[OutputT] | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, start_to_close_timeout: timedelta | None = None, @@ -738,7 +811,7 @@ async def execute_operation( headers: Mapping[str, str] | None = None, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, - ) -> Any: ... + ) -> OutputT: ... @abstractmethod async def execute_operation( diff --git a/tests/nexus/test_nexus_type_errors.py b/tests/nexus/test_nexus_type_errors.py index c669f8a5b..7a1ca64fa 100644 --- a/tests/nexus/test_nexus_type_errors.py +++ b/tests/nexus/test_nexus_type_errors.py @@ -238,16 +238,6 @@ async def standalone_operation_type_tests(): start_to_close_timeout=timedelta(seconds=2), ) - # result_type overrides output type from operation definition - # conflicting result_type and annotation on variable cause type error - # assert-type-error-pyright: 'Type "str" is not assignable to declared type "MyOutput"' - _bad_result_type_output: MyOutput = await nexus_client.execute_operation( # type: ignore - MyServiceHandler.my_sync_operation, - MyInput(), - id="op-1", - result_type=str, # type: ignore - ) - # string operation name and result_type infers output type _str_op_result_type_output: MyOutput = await nexus_client.execute_operation( "my_sync_operation", MyInput(), id="op-1", result_type=MyOutput @@ -337,20 +327,6 @@ async def standalone_operation_type_tests(): ) _defn_handle_output: MyOutput = await _defn_handle.result() - # result_type overrides output type from operation definition - # conflicting result_type and annotation on variable cause type error - _result_type_handle: NexusOperationHandle[ - MyOutput - # assert-type-error-pyright: 'Type "NexusOperationHandle\[str\]" is not assignable to declared type "NexusOperationHandle\[MyOutput\]"' - ] = await nexus_client.start_operation( # type: ignore - MyServiceHandler.my_sync_operation, - MyInput(), - id="op-1", - result_type=str, # type: ignore - ) - # handle still respects type declaration on the variable - _result_type_handle_output: MyOutput = await _result_type_handle.result() - # starting with string operation name and result_type infers output type on the handle # and result from the handle _str_op_result_type_handle: NexusOperationHandle[ From fc95b9913d6fb2638513d46dbd70c013cdbee4f5 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Thu, 21 May 2026 12:23:16 -0700 Subject: [PATCH 2/7] remove result_type params for overloads that don't need them --- temporalio/client/_nexus.py | 20 +++++++------------- tests/nexus/test_nexus_type_errors.py | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/temporalio/client/_nexus.py b/temporalio/client/_nexus.py index 0d8a26c40..060235e01 100644 --- a/temporalio/client/_nexus.py +++ b/temporalio/client/_nexus.py @@ -484,7 +484,6 @@ async def start_operation( id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, - result_type: type[OutputT] | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, start_to_close_timeout: timedelta | None = None, @@ -531,7 +530,6 @@ async def start_operation( id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, - result_type: type[OutputT] | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, start_to_close_timeout: timedelta | None = None, @@ -556,7 +554,6 @@ async def start_operation( id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, - result_type: type[OutputT] | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, start_to_close_timeout: timedelta | None = None, @@ -581,7 +578,6 @@ async def start_operation( id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, - result_type: type[OutputT] | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, start_to_close_timeout: timedelta | None = None, @@ -605,7 +601,6 @@ async def start_operation( id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, - result_type: type[OutputT] | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, start_to_close_timeout: timedelta | None = None, @@ -647,7 +642,8 @@ async def start_operation( id: Unique identifier for this operation. id_reuse_policy: Policy for reusing operation IDs. id_conflict_policy: Policy for handling ID conflicts. - result_type: The result type to deserialize into. + result_type: For string operation names, this can set the specific + result type hint to deserialize into. schedule_to_close_timeout: End-to-end timeout for the Nexus operation. If unset, defaults to the maximum allowed by the Temporal server. @@ -680,7 +676,6 @@ async def execute_operation( id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, - result_type: type[OutputT] | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, start_to_close_timeout: timedelta | None = None, @@ -727,7 +722,6 @@ async def execute_operation( id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, - result_type: type[OutputT] | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, start_to_close_timeout: timedelta | None = None, @@ -752,7 +746,6 @@ async def execute_operation( id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, - result_type: type[OutputT] | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, start_to_close_timeout: timedelta | None = None, @@ -777,7 +770,6 @@ async def execute_operation( id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, - result_type: type[OutputT] | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, start_to_close_timeout: timedelta | None = None, @@ -802,7 +794,6 @@ async def execute_operation( id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, - result_type: type[OutputT] | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, start_to_close_timeout: timedelta | None = None, @@ -846,7 +837,8 @@ async def execute_operation( id: Unique identifier for this operation. id_reuse_policy: Policy for reusing operation IDs. id_conflict_policy: Policy for handling ID conflicts. - result_type: The result type to deserialize into. + result_type: For string operation names, this can set the specific + result type hint to deserialize into. schedule_to_close_timeout: End-to-end timeout for the Nexus operation. If unset, defaults to the maximum allowed by the Temporal server. @@ -933,7 +925,9 @@ async def start_operation( This API is experimental and unstable. """ op_name, output_type = self._resolve_operation(operation) - final_result_type: type | None = result_type or output_type + final_result_type: type | None = ( + result_type if isinstance(operation, str) else output_type + ) return await self._client._impl.start_nexus_operation( StartNexusOperationInput( diff --git a/tests/nexus/test_nexus_type_errors.py b/tests/nexus/test_nexus_type_errors.py index 7a1ca64fa..4a013e8a9 100644 --- a/tests/nexus/test_nexus_type_errors.py +++ b/tests/nexus/test_nexus_type_errors.py @@ -238,6 +238,15 @@ async def standalone_operation_type_tests(): start_to_close_timeout=timedelta(seconds=2), ) + # result_type is not allowed when an operation is provided + await nexus_client.execute_operation( + # assert-type-error-pyright: 'cannot be assigned to parameter "operation" of type "str"' + MyService.my_sync_operation, # type: ignore + MyInput(), + id="op-1", + result_type=str, + ) + # string operation name and result_type infers output type _str_op_result_type_output: MyOutput = await nexus_client.execute_operation( "my_sync_operation", MyInput(), id="op-1", result_type=MyOutput @@ -327,6 +336,15 @@ async def standalone_operation_type_tests(): ) _defn_handle_output: MyOutput = await _defn_handle.result() + # result_type is not allowed when an operation is provided + await nexus_client.start_operation( + # assert-type-error-pyright: 'cannot be assigned to parameter "operation" of type "str"' + MyServiceHandler.my_sync_operation, # type: ignore + MyInput(), + id="op-1", + result_type=str, + ) + # starting with string operation name and result_type infers output type on the handle # and result from the handle _str_op_result_type_handle: NexusOperationHandle[ From 945418862955788eac5f9346f232b8645a17708a Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Thu, 21 May 2026 09:47:45 -0700 Subject: [PATCH 3/7] Add Temporal Nexus operation handler --- temporalio/nexus/__init__.py | 8 +- temporalio/nexus/_decorators.py | 175 +++++- temporalio/nexus/_operation_context.py | 171 +++-- temporalio/nexus/_operation_handlers.py | 65 +- temporalio/nexus/_temporal_client.py | 299 +++++++++ temporalio/nexus/_token.py | 154 +++-- temporalio/nexus/_util.py | 105 +++- temporalio/workflow/__init__.py | 3 - temporalio/workflow/_nexus.py | 61 +- tests/helpers/__init__.py | 38 ++ .../test_handler_operation_definitions.py | 99 +++ tests/nexus/test_nexus_client_updates.py | 2 +- tests/nexus/test_nexus_type_errors.py | 207 +++++- tests/nexus/test_operation_token.py | 115 ++++ tests/nexus/test_temporal_operation.py | 588 ++++++++++++++++++ ...test_workflow_caller_cancellation_types.py | 37 +- ...llation_types_when_cancel_handler_fails.py | 3 +- tests/test_workflow_exports.py | 1 - 18 files changed, 1946 insertions(+), 185 deletions(-) create mode 100644 temporalio/nexus/_temporal_client.py create mode 100644 tests/nexus/test_operation_token.py create mode 100644 tests/nexus/test_temporal_operation.py diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index ea049d90e..d8129e601 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -3,11 +3,12 @@ See https://github.com/temporalio/sdk-python/tree/main#nexus """ -from ._decorators import workflow_run_operation +from ._decorators import temporal_operation, workflow_run_operation from ._operation_context import ( Info, LoggerAdapter, NexusCallback, + TemporalStartOperationContext, WorkflowRunOperationContext, client, in_operation, @@ -18,6 +19,7 @@ wait_for_worker_shutdown, wait_for_worker_shutdown_sync, ) +from ._temporal_client import TemporalNexusClient, TemporalOperationResult from ._token import WorkflowHandle __all__ = ( @@ -26,6 +28,7 @@ "LoggerAdapter", "NexusCallback", "WorkflowRunOperationContext", + "TemporalStartOperationContext", "client", "in_operation", "info", @@ -35,4 +38,7 @@ "wait_for_worker_shutdown", "wait_for_worker_shutdown_sync", "WorkflowHandle", + "TemporalNexusClient", + "TemporalOperationResult", + "temporal_operation", ) diff --git a/temporalio/nexus/_decorators.py b/temporalio/nexus/_decorators.py index 6dfd3daff..7bbd689f0 100644 --- a/temporalio/nexus/_decorators.py +++ b/temporalio/nexus/_decorators.py @@ -2,7 +2,6 @@ from collections.abc import Awaitable, Callable from typing import ( - TypeVar, overload, ) @@ -13,26 +12,37 @@ StartOperationContext, ) -from ._operation_context import WorkflowRunOperationContext -from ._operation_handlers import WorkflowRunOperationHandler +from temporalio.nexus._temporal_client import ( + TemporalNexusClient, + TemporalOperationResult, +) +from temporalio.types import NexusServiceType + +from ._operation_context import ( + TemporalStartOperationContext, + WorkflowRunOperationContext, +) +from ._operation_handlers import ( + TemporalNexusOperationHandler, + WorkflowRunOperationHandler, +) from ._token import WorkflowHandle from ._util import ( get_callable_name, + get_temporal_operation_start_method_input_and_output_type_annotations, get_workflow_run_start_method_input_and_output_type_annotations, set_operation_factory, ) -ServiceHandlerT = TypeVar("ServiceHandlerT") - @overload def workflow_run_operation( start: Callable[ - [ServiceHandlerT, WorkflowRunOperationContext, InputT], + [NexusServiceType, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ], ) -> Callable[ - [ServiceHandlerT, WorkflowRunOperationContext, InputT], + [NexusServiceType, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ]: ... @@ -44,12 +54,12 @@ def workflow_run_operation( ) -> Callable[ [ Callable[ - [ServiceHandlerT, WorkflowRunOperationContext, InputT], + [NexusServiceType, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ] ], Callable[ - [ServiceHandlerT, WorkflowRunOperationContext, InputT], + [NexusServiceType, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ], ]: ... @@ -59,7 +69,7 @@ def workflow_run_operation( start: None | ( Callable[ - [ServiceHandlerT, WorkflowRunOperationContext, InputT], + [NexusServiceType, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ] ) = None, @@ -67,18 +77,18 @@ def workflow_run_operation( name: str | None = None, ) -> ( Callable[ - [ServiceHandlerT, WorkflowRunOperationContext, InputT], + [NexusServiceType, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ] | Callable[ [ Callable[ - [ServiceHandlerT, WorkflowRunOperationContext, InputT], + [NexusServiceType, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ] ], Callable[ - [ServiceHandlerT, WorkflowRunOperationContext, InputT], + [NexusServiceType, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ], ] @@ -87,11 +97,11 @@ def workflow_run_operation( def decorator( start: Callable[ - [ServiceHandlerT, WorkflowRunOperationContext, InputT], + [NexusServiceType, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ], ) -> Callable[ - [ServiceHandlerT, WorkflowRunOperationContext, InputT], + [NexusServiceType, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ]: ( @@ -100,7 +110,7 @@ def decorator( ) = get_workflow_run_start_method_input_and_output_type_annotations(start) def operation_handler_factory( - self: ServiceHandlerT, + self: NexusServiceType, ) -> OperationHandler[InputT, OutputT]: async def _start( ctx: StartOperationContext, input: InputT @@ -130,3 +140,136 @@ async def _start( return decorator return decorator(start) + + +@overload +def temporal_operation( + start: Callable[ + [NexusServiceType, TemporalStartOperationContext, TemporalNexusClient, InputT], + Awaitable[TemporalOperationResult[OutputT]], + ], +) -> Callable[ + [NexusServiceType, TemporalStartOperationContext, TemporalNexusClient, InputT], + Awaitable[TemporalOperationResult[OutputT]], +]: ... + + +@overload +def temporal_operation( + *, + name: str | None = None, +) -> Callable[ + [ + Callable[ + [ + NexusServiceType, + TemporalStartOperationContext, + TemporalNexusClient, + InputT, + ], + Awaitable[TemporalOperationResult[OutputT]], + ] + ], + Callable[ + [NexusServiceType, TemporalStartOperationContext, TemporalNexusClient, InputT], + Awaitable[TemporalOperationResult[OutputT]], + ], +]: ... + + +def temporal_operation( + start: None + | ( + Callable[ + [ + NexusServiceType, + TemporalStartOperationContext, + TemporalNexusClient, + InputT, + ], + Awaitable[TemporalOperationResult[OutputT]], + ] + ) = None, + *, + name: str | None = None, +) -> ( + Callable[ + [NexusServiceType, TemporalStartOperationContext, TemporalNexusClient, InputT], + Awaitable[TemporalOperationResult[OutputT]], + ] + | Callable[ + [ + Callable[ + [ + NexusServiceType, + TemporalStartOperationContext, + TemporalNexusClient, + InputT, + ], + Awaitable[TemporalOperationResult[OutputT]], + ] + ], + Callable[ + [ + NexusServiceType, + TemporalStartOperationContext, + TemporalNexusClient, + InputT, + ], + Awaitable[TemporalOperationResult[OutputT]], + ], + ] +): + """Decorator marking a method as the start method for an operation that interacts with Temporal.""" + + def decorator( + start: Callable[ + [ + NexusServiceType, + TemporalStartOperationContext, + TemporalNexusClient, + InputT, + ], + Awaitable[TemporalOperationResult[OutputT]], + ], + ) -> Callable[ + [NexusServiceType, TemporalStartOperationContext, TemporalNexusClient, InputT], + Awaitable[TemporalOperationResult[OutputT]], + ]: + ( + input_type, + output_type, + ) = get_temporal_operation_start_method_input_and_output_type_annotations(start) + + def operation_handler_factory( + self: NexusServiceType, + ) -> OperationHandler[InputT, OutputT]: + async def _start( + ctx: StartOperationContext, client: TemporalNexusClient, input: InputT + ) -> TemporalOperationResult[OutputT]: + return await start( + self, + TemporalStartOperationContext._from_start_operation_context(ctx), + client, + input, + ) + + _start.__doc__ = start.__doc__ + return TemporalNexusOperationHandler(_start) + + method_name = get_callable_name(start) + op = nexusrpc.Operation( + name=name or method_name, + input_type=input_type, + output_type=output_type, + ) + op.method_name = method_name + nexusrpc.set_operation(operation_handler_factory, op) + + set_operation_factory(start, operation_handler_factory) + return start + + if start is None: + return decorator + + return decorator(start) diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index 069fd65d3..b714b1cb8 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -28,6 +28,7 @@ OperationContext, StartOperationContext, ) +from typing_extensions import Self import temporalio.api.common.v1 import temporalio.api.workflowservice.v1 @@ -279,6 +280,26 @@ def _add_outbound_links( return workflow_handle +class TemporalStartOperationContext(StartOperationContext): + """Context received by a Temporal operation.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize the Temporal operation context.""" + super().__init__(*args, **kwargs) + self._temporal_context = _TemporalStartOperationContext.get() + + @classmethod + def _from_start_operation_context(cls, ctx: StartOperationContext) -> Self: + return cls( + **{f.name: getattr(ctx, f.name) for f in dataclasses.fields(ctx)}, + ) + + @property + def metric_meter(self) -> temporalio.common.MetricMeter: + """The metric meter""" + return self._temporal_context.metric_meter + + class WorkflowRunOperationContext(StartOperationContext): """Context received by a workflow run operation.""" @@ -492,50 +513,34 @@ async def start_workflow( Nexus caller is itself a workflow, this means that the workflow in the caller namespace web UI will contain links to the started workflow, and vice versa. """ - # We must pass nexus_completion_callbacks, event_links, and request_id, - # but these are deliberately not exposed in overloads, hence the type-check - # violation. - - # Here we are starting a "nexus-backing" workflow. That means that the StartWorkflow request - # contains nexus-specific data such as a completion callback (used by the handler server - # namespace to deliver the result to the caller namespace when the workflow reaches a - # terminal state) and inbound links to the caller workflow (attached to history events of - # the workflow started in the handler namespace, and displayed in the UI). - with _nexus_backing_workflow_start_context(): - wf_handle = await self._temporal_context.client.start_workflow( # type: ignore - workflow=workflow, - arg=arg, - args=args, - id=id, - task_queue=task_queue or self._temporal_context.info().task_queue, - result_type=result_type, - execution_timeout=execution_timeout, - run_timeout=run_timeout, - task_timeout=task_timeout, - id_reuse_policy=id_reuse_policy, - id_conflict_policy=id_conflict_policy, - retry_policy=retry_policy, - cron_schedule=cron_schedule, - memo=memo, - search_attributes=search_attributes, - static_summary=static_summary, - static_details=static_details, - start_delay=start_delay, - start_signal=start_signal, - start_signal_args=start_signal_args, - rpc_metadata=rpc_metadata, - rpc_timeout=rpc_timeout, - request_eager_start=request_eager_start, - priority=priority, - versioning_override=versioning_override, - callbacks=self._temporal_context._get_callbacks(), - links=self._temporal_context._get_links(), - request_id=self._temporal_context.nexus_context.request_id, - ) - - self._temporal_context._add_outbound_links(wf_handle) - - return WorkflowHandle[ReturnType]._unsafe_from_client_workflow_handle(wf_handle) + return await _start_nexus_backing_workflow( + temporal_context=self._temporal_context, + workflow=workflow, + arg=arg, + args=args, + id=id, + task_queue=task_queue, + result_type=result_type, + execution_timeout=execution_timeout, + run_timeout=run_timeout, + task_timeout=task_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + cron_schedule=cron_schedule, + memo=memo, + search_attributes=search_attributes, + static_summary=static_summary, + static_details=static_details, + start_delay=start_delay, + start_signal=start_signal, + start_signal_args=start_signal_args, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + request_eager_start=request_eager_start, + priority=priority, + versioning_override=versioning_override, + ) @dataclass(frozen=True) @@ -586,3 +591,81 @@ def process( logger = LoggerAdapter(logging.getLogger("temporalio.nexus"), None) """Logger that emits additional data describing the current Nexus operation.""" + + +async def _start_nexus_backing_workflow( + temporal_context: _TemporalStartOperationContext, + workflow: str | Callable[..., Awaitable[ReturnType]], + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: str | None = None, + result_type: type | None = None, + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: Mapping[str, Any] | None = None, + search_attributes: None + | ( + temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes + ) = None, + static_summary: str | None = None, + static_details: str | None = None, + start_delay: timedelta | None = None, + start_signal: str | None = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: temporalio.common.VersioningOverride | None = None, +) -> WorkflowHandle[ReturnType]: + # We must pass nexus_completion_callbacks, workflow_event_links, and request_id, + # but these are deliberately not exposed in overloads, hence the type-check + # violation. + + # Here we are starting a "nexus-backing" workflow. That means that the StartWorkflow request + # contains nexus-specific data such as a completion callback (used by the handler server + # namespace to deliver the result to the caller namespace when the workflow reaches a + # terminal state) and inbound links to the caller workflow (attached to history events of + # the workflow started in the handler namespace, and displayed in the UI). + with _nexus_backing_workflow_start_context(): + wf_handle = await temporal_context.client.start_workflow( # type: ignore + workflow=workflow, + arg=arg, + args=args, + id=id, + task_queue=task_queue or temporal_context.info().task_queue, + result_type=result_type, + execution_timeout=execution_timeout, + run_timeout=run_timeout, + task_timeout=task_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + cron_schedule=cron_schedule, + memo=memo, + search_attributes=search_attributes, + static_summary=static_summary, + static_details=static_details, + start_delay=start_delay, + start_signal=start_signal, + start_signal_args=start_signal_args, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + request_eager_start=request_eager_start, + priority=priority, + versioning_override=versioning_override, + callbacks=temporal_context._get_callbacks(), + links=temporal_context._get_links(), + request_id=temporal_context.nexus_context.request_id, + ) + + temporal_context._add_outbound_links(wf_handle) + + return WorkflowHandle[ReturnType]._unsafe_from_client_workflow_handle(wf_handle) diff --git a/temporalio/nexus/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py index 68035ca41..ac0642b14 100644 --- a/temporalio/nexus/_operation_handlers.py +++ b/temporalio/nexus/_operation_handlers.py @@ -1,9 +1,7 @@ from __future__ import annotations from collections.abc import Awaitable, Callable -from typing import ( - Any, -) +from typing import Any from nexusrpc import ( HandlerError, @@ -16,12 +14,18 @@ OperationHandler, StartOperationContext, StartOperationResultAsync, + StartOperationResultSync, ) from temporalio.nexus._operation_context import ( _temporal_cancel_operation_context, + _TemporalCancelOperationContext, +) +from temporalio.nexus._temporal_client import ( + TemporalNexusClient, + TemporalOperationResult, ) -from temporalio.nexus._token import WorkflowHandle +from temporalio.nexus._token import OperationToken, OperationTokenType, WorkflowHandle from ._util import ( is_async_callable, @@ -112,3 +116,56 @@ async def _cancel_workflow( type=HandlerErrorType.NOT_FOUND, ) from err await client_workflow_handle.cancel(**kwargs) + + +class TemporalNexusOperationHandler(OperationHandler[InputT, OutputT]): + """Operation handler for Nexus operations that interact with Temporal.""" + + def __init__( + self, + start: Callable[ + [StartOperationContext, TemporalNexusClient, InputT], + Awaitable[TemporalOperationResult[OutputT]], + ], + ) -> None: + """Initialize the Temporal operation handler.""" + if not is_async_callable(start): + raise RuntimeError( + f"{start} is not an `async def` method. " + "TemporalNexusOperationHandler must be initialized with an " + "`async def` start method." + ) + self._start = start + if start.__doc__: + if start_func := getattr(self.start, "__func__", None): + start_func.__doc__ = start.__doc__ + + async def start( + self, ctx: StartOperationContext, input: InputT + ) -> StartOperationResultSync[OutputT] | StartOperationResultAsync: + """Start the Nexus operation using a Nexus-aware Temporal client.""" + nexus_client = TemporalNexusClient() + result = await self._start(ctx, nexus_client, input) + return result._to_nexus_result() + + async def cancel(self, ctx: CancelOperationContext, token: str) -> None: + """Cancel a Nexus operation using its operation token.""" + temporal_context = _TemporalCancelOperationContext.get() + client = temporal_context.client + + operation_token = OperationToken.decode(token) + if client.namespace != operation_token.namespace: + raise ValueError( + f"Client namespace {client.namespace} does not match " + f"operation token namespace {operation_token.namespace}" + ) + + match operation_token.type: + case OperationTokenType.WORKFLOW: + await self.cancel_workflow_run(ctx, operation_token.workflow_id) + + async def cancel_workflow_run(self, _ctx: CancelOperationContext, workflow_id: str): + """Cancels the workflow identified by workflow_id""" + temporal_context = _TemporalCancelOperationContext.get() + workflow_handle = temporal_context.client.get_workflow_handle(workflow_id) + await workflow_handle.cancel() diff --git a/temporalio/nexus/_temporal_client.py b/temporalio/nexus/_temporal_client.py new file mode 100644 index 000000000..372d779ca --- /dev/null +++ b/temporalio/nexus/_temporal_client.py @@ -0,0 +1,299 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Callable, Iterator, Mapping, Sequence +from contextlib import contextmanager +from dataclasses import dataclass +from datetime import timedelta +from typing import ( + TYPE_CHECKING, + Any, + Concatenate, + Generic, + TypeVar, + cast, + overload, +) + +from nexusrpc import HandlerError, HandlerErrorType +from nexusrpc.handler import StartOperationResultAsync, StartOperationResultSync +from typing_extensions import Self + +import temporalio.common +from temporalio.nexus._operation_context import ( + _start_nexus_backing_workflow, + _TemporalStartOperationContext, +) +from temporalio.types import ( + MethodAsyncNoParam, + MethodAsyncSingleParam, + MultiParamSpec, + ParamType, + ReturnType, + SelfType, +) + +if TYPE_CHECKING: + import temporalio.client + + +_ResultT = TypeVar("_ResultT") + + +@dataclass(frozen=True) +class TemporalOperationResult(Generic[_ResultT]): + """Unified result: sync value or async token.""" + + value: _ResultT | object = temporalio.common._arg_unset + token: str | None = None + + @classmethod + def sync(cls, value: _ResultT) -> "TemporalOperationResult[_ResultT]": + """Create a result that completes the Nexus operation synchronously.""" + return cls(value=value) + + @classmethod + def async_token(cls, token: str) -> Self: + """Create a result that completes the Nexus operation asynchronously.""" + return cls(token=token) + + def _to_nexus_result( + self, + ) -> StartOperationResultSync[_ResultT] | StartOperationResultAsync: + if self.token is not None: + return StartOperationResultAsync(self.token) + elif self.value is not temporalio.common._arg_unset: + return StartOperationResultSync(cast(_ResultT, self.value)) + else: + raise RuntimeError( + "Invalid TemporalOperationResult. Neither token nor value are set." + ) + + +class TemporalNexusClient: + """Nexus-aware wrapper around a Temporal Client.""" + + def __init__(self) -> None: + """Initialize the client wrapper from the active Nexus operation context.""" + self._temporal_context = _TemporalStartOperationContext.get() + self._started_async = False + + @property + def client(self) -> temporalio.client.Client: + """Return the Temporal client for the active Nexus operation.""" + return self._temporal_context.client + + @contextmanager + def _reserve_async_start(self) -> Iterator[None]: + if self._started_async: + raise HandlerError( + "Only one async operation can be started per operation handler invocation. Use TemporalNexusClient.client for additional workflow interactions", + type=HandlerErrorType.BAD_REQUEST, + ) + + # Reserve the started flag before sending to prevent concurrent starts + self._started_async = True + try: + yield + except BaseException: + self._started_async = False + raise + + # Overload for no-param workflow + @overload + async def start_workflow( + self, + workflow: MethodAsyncNoParam[SelfType, ReturnType], + *, + id: str, + task_queue: str | None = None, + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: Mapping[str, Any] | None = None, + search_attributes: None + | ( + temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes + ) = None, + static_summary: str | None = None, + static_details: str | None = None, + start_delay: timedelta | None = None, + start_signal: str | None = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: temporalio.common.VersioningOverride | None = None, + ) -> TemporalOperationResult[ReturnType]: ... + + # Overload for single-param workflow + @overload + async def start_workflow( + self, + workflow: MethodAsyncSingleParam[SelfType, ParamType, ReturnType], + arg: ParamType, + *, + id: str, + task_queue: str | None = None, + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: Mapping[str, Any] | None = None, + search_attributes: None + | ( + temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes + ) = None, + static_summary: str | None = None, + static_details: str | None = None, + start_delay: timedelta | None = None, + start_signal: str | None = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: temporalio.common.VersioningOverride | None = None, + ) -> TemporalOperationResult[ReturnType]: ... + + # Overload for multi-param workflow + @overload + async def start_workflow( + self, + workflow: Callable[ + Concatenate[SelfType, MultiParamSpec], Awaitable[ReturnType] + ], + *, + args: Sequence[Any], + id: str, + task_queue: str | None = None, + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: Mapping[str, Any] | None = None, + search_attributes: None + | ( + temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes + ) = None, + static_summary: str | None = None, + static_details: str | None = None, + start_delay: timedelta | None = None, + start_signal: str | None = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: temporalio.common.VersioningOverride | None = None, + ) -> TemporalOperationResult[ReturnType]: ... + + # Overload for string-name workflow + @overload + async def start_workflow( + self, + workflow: str, + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: str | None = None, + result_type: type[ReturnType] | None = None, + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: Mapping[str, Any] | None = None, + search_attributes: None + | ( + temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes + ) = None, + static_summary: str | None = None, + static_details: str | None = None, + start_delay: timedelta | None = None, + start_signal: str | None = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: temporalio.common.VersioningOverride | None = None, + ) -> TemporalOperationResult[ReturnType]: ... + + async def start_workflow( + self, + workflow: str | Callable[..., Awaitable[ReturnType]], + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: str | None = None, + result_type: type | None = None, + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: Mapping[str, Any] | None = None, + search_attributes: None + | ( + temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes + ) = None, + static_summary: str | None = None, + static_details: str | None = None, + start_delay: timedelta | None = None, + start_signal: str | None = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: temporalio.common.VersioningOverride | None = None, + ) -> TemporalOperationResult[ReturnType]: + """Start a workflow as the backing asynchronous Nexus operation.""" + with self._reserve_async_start(): + wf_handle = await _start_nexus_backing_workflow( + temporal_context=self._temporal_context, + workflow=workflow, + arg=arg, + args=args, + id=id, + task_queue=task_queue, + result_type=result_type, + execution_timeout=execution_timeout, + run_timeout=run_timeout, + task_timeout=task_timeout, + id_reuse_policy=id_reuse_policy, + id_conflict_policy=id_conflict_policy, + retry_policy=retry_policy, + cron_schedule=cron_schedule, + memo=memo, + search_attributes=search_attributes, + static_summary=static_summary, + static_details=static_details, + start_delay=start_delay, + start_signal=start_signal, + start_signal_args=start_signal_args, + rpc_metadata=rpc_metadata, + rpc_timeout=rpc_timeout, + request_eager_start=request_eager_start, + priority=priority, + versioning_override=versioning_override, + ) + + return TemporalOperationResult.async_token(wf_handle.to_token()) diff --git a/temporalio/nexus/_token.py b/temporalio/nexus/_token.py index 0a3d27375..f3e5758cf 100644 --- a/temporalio/nexus/_token.py +++ b/temporalio/nexus/_token.py @@ -3,17 +3,110 @@ import base64 import json from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic, Literal +from enum import IntEnum +from typing import TYPE_CHECKING, Any, Generic from nexusrpc import OutputT +from typing_extensions import Self + + +class OperationTokenType(IntEnum): + """Type discriminator for Nexus operation tokens.""" + + WORKFLOW = 1 -OperationTokenType = Literal[1] -OPERATION_TOKEN_TYPE_WORKFLOW: OperationTokenType = 1 if TYPE_CHECKING: import temporalio.client +@dataclass(frozen=True, kw_only=True) +class OperationToken: + """Serializable token identifying a Nexus operation target.""" + + version: int | None = None + type: OperationTokenType + namespace: str + workflow_id: str + + def encode(self) -> str: + """Convert handle to a base64url-encoded token string.""" + token_details: dict[str, Any] = { + "t": self.type, + "ns": self.namespace, + "wid": self.workflow_id, + } + if self.version is not None: + token_details["v"] = self.version + return _base64url_encode_no_padding( + json.dumps( + token_details, + separators=(",", ":"), + ).encode("utf-8") + ) + + @classmethod + def decode(cls, token: str) -> Self: + """Decodes and validates a token from its base64url-encoded string representation.""" + if not token: + raise TypeError("invalid token: token is empty") + try: + decoded_bytes = _base64url_decode_no_padding(token) + except Exception as err: + raise TypeError("failed to decode token as base64url") from err + try: + token_details = json.loads(decoded_bytes.decode("utf-8")) + except Exception as err: + raise TypeError("failed to unmarshal operation token") from err + + if not isinstance(token_details, dict): + raise TypeError(f"invalid token: expected dict, got {type(token_details)}") + + raw_token_type = token_details.get("t") + if not isinstance(raw_token_type, int): + raise TypeError( + f"invalid token: expected token type to be an int, got {type(raw_token_type)}" + ) + + try: + token_type = OperationTokenType(raw_token_type) + except ValueError as err: + raise TypeError( + f"invalid token: unknown token type, got {raw_token_type}.", + f"Valid values: {', '.join([f'{t.value} ({t.name})' for t in OperationTokenType])}", + ) from err + + version = token_details.get("v") + if version is not None and not isinstance(version, int): + raise TypeError( + f"invalid token: expected version to be an int or null, got {type(version)}" + ) + + workflow_id = token_details.get("wid") + if not isinstance(workflow_id, str): + raise TypeError( + f"invalid token: expected workflow id to be a string, got {type(workflow_id)}" + ) + + if token_type == OperationTokenType.WORKFLOW and not workflow_id: + raise TypeError( + "invalid token: expected non-empty workflow id for token type `WORKFLOW`" + ) + + namespace = token_details.get("ns") + if not isinstance(namespace, str) or not namespace: + raise TypeError( + f"invalid token: expected namespace to be a non-empty string, got {type(namespace)}" + ) + + return cls( + type=OperationTokenType(token_type), + namespace=namespace, + workflow_id=workflow_id, + version=version, + ) + + @dataclass(frozen=True) class WorkflowHandle(Generic[OutputT]): """A handle to a workflow that is backing a Nexus operation. @@ -59,65 +152,36 @@ def _unsafe_from_client_workflow_handle( def to_token(self) -> str: """Convert handle to a base64url-encoded token string.""" - return _base64url_encode_no_padding( - json.dumps( - { - "t": OPERATION_TOKEN_TYPE_WORKFLOW, - "ns": self.namespace, - "wid": self.workflow_id, - }, - separators=(",", ":"), - ).encode("utf-8") - ) + return OperationToken( + type=OperationTokenType.WORKFLOW, + namespace=self.namespace, + workflow_id=self.workflow_id, + ).encode() @classmethod def from_token(cls, token: str) -> WorkflowHandle[OutputT]: """Decodes and validates a token from its base64url-encoded string representation.""" - if not token: - raise TypeError("invalid workflow token: token is empty") - try: - decoded_bytes = _base64url_decode_no_padding(token) - except Exception as err: - raise TypeError("failed to decode token as base64url") from err - try: - workflow_operation_token = json.loads(decoded_bytes.decode("utf-8")) - except Exception as err: - raise TypeError("failed to unmarshal workflow operation token") from err - - if not isinstance(workflow_operation_token, dict): + op_token = OperationToken.decode(token) + if op_token.type != OperationTokenType.WORKFLOW: raise TypeError( - f"invalid workflow token: expected dict, got {type(workflow_operation_token)}" + f"invalid workflow token type: {op_token.type}, expected: {OperationTokenType.WORKFLOW}" ) - token_type = workflow_operation_token.get("t") - if token_type != OPERATION_TOKEN_TYPE_WORKFLOW: - raise TypeError( - f"invalid workflow token type: {token_type}, expected: {OPERATION_TOKEN_TYPE_WORKFLOW}" - ) - - version = workflow_operation_token.get("v") - if version is not None and version != 0: + if op_token.version is not None and op_token.version != 0: raise TypeError( "invalid workflow token: 'v' field, if present, must be 0 or null/absent" ) - workflow_id = workflow_operation_token.get("wid") - if not workflow_id or not isinstance(workflow_id, str): - raise TypeError( - "invalid workflow token: missing, empty, or non-string workflow ID (wid)" - ) - - namespace = workflow_operation_token.get("ns") - if namespace is None or not isinstance(namespace, str): + if not isinstance(op_token.namespace, str): # Allow empty string for ns, but it must be present and a string raise TypeError( "invalid workflow token: missing or non-string namespace (ns)" ) return cls( - namespace=namespace, - workflow_id=workflow_id, - version=version, + namespace=op_token.namespace, + workflow_id=op_token.workflow_id, + version=op_token.version, ) diff --git a/temporalio/nexus/_util.py b/temporalio/nexus/_util.py index 48d3ad644..ef7ccf78f 100644 --- a/temporalio/nexus/_util.py +++ b/temporalio/nexus/_util.py @@ -7,7 +7,6 @@ from collections.abc import Awaitable, Callable from typing import ( Any, - TypeVar, ) import nexusrpc @@ -16,18 +15,24 @@ OutputT, ) -from temporalio.nexus._operation_context import WorkflowRunOperationContext +from temporalio.nexus._operation_context import ( + TemporalStartOperationContext, + WorkflowRunOperationContext, +) +from temporalio.nexus._temporal_client import ( + TemporalNexusClient, + TemporalOperationResult, +) +from temporalio.types import NexusServiceType from ._token import ( WorkflowHandle as WorkflowHandle, ) -ServiceHandlerT = TypeVar("ServiceHandlerT") - def get_workflow_run_start_method_input_and_output_type_annotations( start: Callable[ - [ServiceHandlerT, WorkflowRunOperationContext, InputT], + [NexusServiceType, WorkflowRunOperationContext, InputT], Awaitable[WorkflowHandle[OutputT]], ], ) -> tuple[ @@ -39,13 +44,54 @@ def get_workflow_run_start_method_input_and_output_type_annotations( ``start`` must be a type-annotated start method that returns a :py:class:`temporalio.nexus.WorkflowHandle`. """ - input_type, output_type = _get_start_method_input_and_output_type_annotations(start) + return _get_wrapped_start_method_input_and_output_type_annotations( + start, + expected_param_types=(WorkflowRunOperationContext,), + expected_return_origin=WorkflowHandle, + ) + + +def get_temporal_operation_start_method_input_and_output_type_annotations( + start: Callable[ + [NexusServiceType, TemporalStartOperationContext, TemporalNexusClient, InputT], + Awaitable[TemporalOperationResult[OutputT]], + ], +) -> tuple[ + type[InputT] | None, + type[OutputT] | None, +]: + """Return operation input and output types. + + ``start`` must be a type-annotated start method that returns a + :py:class:`temporalio.nexus.TemporalOperationResult`. + """ + return _get_wrapped_start_method_input_and_output_type_annotations( + start, + expected_param_types=(TemporalStartOperationContext, TemporalNexusClient), + expected_return_origin=TemporalOperationResult, + ) + + +def _get_wrapped_start_method_input_and_output_type_annotations( + start: Callable[..., Any], + *, + expected_param_types: tuple[type[Any], ...], + expected_return_origin: type[Any], +) -> tuple[ + type[Any] | None, + type[Any] | None, +]: + input_type, output_type = _get_start_method_input_and_output_type_annotations( + start, + expected_param_types=expected_param_types, + ) origin_type = typing.get_origin(output_type) if not origin_type: output_type = None - elif not issubclass(origin_type, WorkflowHandle): + elif not _is_subclass(origin_type, expected_return_origin): warnings.warn( - f"Expected return type of {start.__name__} to be a subclass of WorkflowHandle, " + f"Expected return type of {start.__name__} to be a subclass of " + f"{expected_return_origin.__name__}, " f"but is {output_type}" ) output_type = None @@ -65,13 +111,12 @@ def get_workflow_run_start_method_input_and_output_type_annotations( def _get_start_method_input_and_output_type_annotations( - start: Callable[ - [ServiceHandlerT, WorkflowRunOperationContext, InputT], - Awaitable[WorkflowHandle[OutputT]], - ], + start: Callable[..., Any], + *, + expected_param_types: tuple[type[Any], ...], ) -> tuple[ - type[InputT] | None, - type[OutputT] | None, + type[Any] | None, + type[Any] | None, ]: try: type_annotations = typing.get_type_hints(start) @@ -81,27 +126,39 @@ def _get_start_method_input_and_output_type_annotations( ) return None, None output_type = type_annotations.pop("return", None) + expected_parameter_count = len(expected_param_types) + 1 - if len(type_annotations) != 2: + if len(type_annotations) != expected_parameter_count: suffix = f": {type_annotations}" if type_annotations else "" warnings.warn( - f"Expected decorated start method {start} to have exactly 2 " - f"type-annotated parameters (ctx and input), but it has {len(type_annotations)}" + f"Expected decorated start method {start} to have exactly " + f"{expected_parameter_count} type-annotated parameters, " + f"but it has {len(type_annotations)}" f"{suffix}." ) input_type = None else: - ctx_type, input_type = type_annotations.values() - if not issubclass(ctx_type, WorkflowRunOperationContext): - warnings.warn( - f"Expected first parameter of {start} to be an instance of " - f"WorkflowRunOperationContext, but is {ctx_type}." - ) - input_type = None + *param_types, input_type = type_annotations.values() + for index, (param_type, expected_param_type) in enumerate( + zip(param_types, expected_param_types), start=1 + ): + if not _is_subclass(expected_param_type, param_type): + warnings.warn( + f"Expected parameter {index} of {start} to be an instance of " + f"{expected_param_type.__name__}, but is {param_type}." + ) + input_type = None return input_type, output_type +def _is_subclass(cls: Any, class_or_tuple: type[Any]) -> bool: + try: + return issubclass(cls, class_or_tuple) + except TypeError: + return False + + def get_callable_name(fn: Callable[..., Any]) -> str: """Return the name of a callable object.""" method_name = getattr(fn, "__name__", None) diff --git a/temporalio/workflow/__init__.py b/temporalio/workflow/__init__.py index f8002366b..ec74299c2 100644 --- a/temporalio/workflow/__init__.py +++ b/temporalio/workflow/__init__.py @@ -2,8 +2,6 @@ from __future__ import annotations -from temporalio.nexus._util import ServiceHandlerT - from ..types import ( AnyType, CallableAsyncNoParam, @@ -293,7 +291,6 @@ "_sandbox_unrestricted", # Re-export Temporal-owned names that old temporalio/workflow.py imported # at module scope so explicit imports from temporalio.workflow keep working. - "ServiceHandlerT", "AnyType", "CallableAsyncNoParam", "CallableAsyncSingleParam", diff --git a/temporalio/workflow/_nexus.py b/temporalio/workflow/_nexus.py index 0b80e6d91..29bd10715 100644 --- a/temporalio/workflow/_nexus.py +++ b/temporalio/workflow/_nexus.py @@ -12,7 +12,6 @@ import temporalio.bridge.proto.nexus import temporalio.nexus -from temporalio.nexus._util import ServiceHandlerT from temporalio.types import NexusServiceType from ._context import _Runtime @@ -138,7 +137,7 @@ async def start_operation( async def start_operation( self, operation: Callable[ - [ServiceHandlerT, temporalio.nexus.WorkflowRunOperationContext, InputT], + [NexusServiceType, temporalio.nexus.WorkflowRunOperationContext, InputT], Awaitable[temporalio.nexus.WorkflowHandle[OutputT]], ], input: InputT, @@ -158,7 +157,7 @@ async def start_operation( async def start_operation( self, operation: Callable[ - [ServiceHandlerT, nexusrpc.handler.StartOperationContext, InputT], + [NexusServiceType, nexusrpc.handler.StartOperationContext, InputT], Awaitable[OutputT], ], input: InputT, @@ -178,7 +177,7 @@ async def start_operation( async def start_operation( self, operation: Callable[ - [ServiceHandlerT, nexusrpc.handler.StartOperationContext, InputT], + [NexusServiceType, nexusrpc.handler.StartOperationContext, InputT], OutputT, ], input: InputT, @@ -198,7 +197,32 @@ async def start_operation( async def start_operation( self, operation: Callable[ - [ServiceHandlerT], nexusrpc.handler.OperationHandler[InputT, OutputT] + [NexusServiceType], nexusrpc.handler.OperationHandler[InputT, OutputT] + ], + input: InputT, + *, + output_type: type[OutputT] | None = None, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, + headers: Mapping[str, str] | None = None, + summary: str | None = None, + ) -> NexusOperationHandle[OutputT]: ... + + # Overload for temporal_operation methods + @overload + @abstractmethod + async def start_operation( + self, + operation: Callable[ + [ + NexusServiceType, + temporalio.nexus.TemporalStartOperationContext, + temporalio.nexus.TemporalNexusClient, + InputT, + ], + Awaitable[temporalio.nexus.TemporalOperationResult[OutputT]], ], input: InputT, *, @@ -284,7 +308,7 @@ async def execute_operation( async def execute_operation( self, operation: Callable[ - [ServiceHandlerT, temporalio.nexus.WorkflowRunOperationContext, InputT], + [NexusServiceType, temporalio.nexus.WorkflowRunOperationContext, InputT], Awaitable[temporalio.nexus.WorkflowHandle[OutputT]], ], input: InputT, @@ -358,6 +382,31 @@ async def execute_operation( summary: str | None = None, ) -> OutputT: ... + # Overload for temporal_operation methods + @overload + @abstractmethod + async def execute_operation( + self, + operation: Callable[ + [ + NexusServiceType, + temporalio.nexus.TemporalStartOperationContext, + temporalio.nexus.TemporalNexusClient, + InputT, + ], + Awaitable[temporalio.nexus.TemporalOperationResult[OutputT]], + ], + input: InputT, + *, + output_type: type[OutputT] | None = None, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, + headers: Mapping[str, str] | None = None, + summary: str | None = None, + ) -> OutputT: ... + @abstractmethod async def execute_operation( self, diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index f467f8aa3..fe37296e9 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -254,6 +254,44 @@ async def check() -> PendingActivityInfo: return await assert_eventually(check, timeout=timeout) +async def assert_event_subsequence( + wf_handle: WorkflowHandle, + expected_events: list[EventType.ValueType], + timeout: timedelta = timedelta(seconds=5), +) -> None: + """ + Given a workflow handle and a sequence of event types, assert that the workflow's history + contains that subsequence of events in the order specified. + """ + + async def check(): + history = await wf_handle.fetch_history() + + _all_events = iter(history.events) + _expected_events = iter(expected_events) + + previous_expected_event_type_name = None + for expected_event_type in _expected_events: + expected_event_type_name = EventType.Name(expected_event_type).removeprefix( + "EVENT_TYPE_" + ) + has_expected = next( + (e for e in _all_events if e.event_type == expected_event_type), + None, + ) + if not has_expected: + if previous_expected_event_type_name is not None: + prefix = f"After {previous_expected_event_type_name}, " + else: + prefix = "" + raise AssertionError( + f"{prefix}expected {expected_event_type_name} in workflow {wf_handle.id}" + ) + previous_expected_event_type_name = expected_event_type_name + + await assert_eventually(check, timeout=timeout) + + async def get_pending_activity_info( handle: WorkflowHandle, activity_id: str, diff --git a/tests/nexus/test_handler_operation_definitions.py b/tests/nexus/test_handler_operation_definitions.py index 8a0d6262a..f47e3f47f 100644 --- a/tests/nexus/test_handler_operation_definitions.py +++ b/tests/nexus/test_handler_operation_definitions.py @@ -3,6 +3,7 @@ and input/output types. """ +import warnings from dataclasses import dataclass from typing import Any @@ -99,3 +100,101 @@ async def test_collected_operation_names( assert actual_op.name == expected_op.name assert actual_op.input_type == expected_op.input_type assert actual_op.output_type == expected_op.output_type + + +def test_unsafe_narrow_context_annotations_warn_and_drop_input_type(): + """Unsafe context annotations warn and prevent input type inference. + + Decorators construct a specific context type at runtime. If a handler annotates a + narrower or unrelated context type, the decorator cannot safely call it, so we + should warn and avoid using the handler annotation to infer operation input type. + """ + + with pytest.warns( + UserWarning, + match="Expected parameter 1 .* TemporalStartOperationContext", + ): + + class MyTemporalOpCtx(nexus.TemporalStartOperationContext): + def custom_method(self): + raise NotImplementedError + + class TemporalOperationHandler: + @nexus.temporal_operation # type: ignore[arg-type] + async def op( + self, + _ctx: MyTemporalOpCtx, + _client: nexus.TemporalNexusClient, + _input: Input, + ) -> nexus.TemporalOperationResult[Output]: + raise NotImplementedError + + _, temporal_op = get_operation_factory(TemporalOperationHandler.op) + assert isinstance(temporal_op, nexusrpc.Operation) + assert temporal_op.input_type is None + assert temporal_op.output_type == Output + + with pytest.warns( + UserWarning, + match="Expected parameter 1 .* WorkflowRunOperationContext", + ): + + class MyWorkflowRunOpCtx(nexus.WorkflowRunOperationContext): + def custom_method(self): + raise NotImplementedError + + class WorkflowRunOperationHandler: + @workflow_run_operation # type: ignore[arg-type] + async def op( + self, + _ctx: MyWorkflowRunOpCtx, + _input: Input, + ) -> nexus.WorkflowHandle[Output]: + raise NotImplementedError + + _, workflow_op = get_operation_factory(WorkflowRunOperationHandler.op) + assert isinstance(workflow_op, nexusrpc.Operation) + assert workflow_op.input_type is None + assert workflow_op.output_type == Output + + +def test_safe_broader_context_annotations_preserve_input_type_without_warnings(): + """Safe context annotations preserve input type inference without warnings. + + A handler can safely annotate a context parameter with the exact runtime context + type or a broader base type. These cases should keep handler-derived operation + input metadata intact. + """ + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + + class TemporalOperationHandler: + @nexus.temporal_operation + async def op( + self, + _ctx: nexusrpc.handler.StartOperationContext, + _client: nexus.TemporalNexusClient, + _input: Input, + ) -> nexus.TemporalOperationResult[Output]: + raise NotImplementedError + + class WorkflowRunStartContextHandler: + @workflow_run_operation + async def op( + self, + _ctx: nexusrpc.handler.StartOperationContext, + _input: Input, + ) -> nexus.WorkflowHandle[Output]: + raise NotImplementedError + + assert not caught + + for method in ( + TemporalOperationHandler.op, + WorkflowRunStartContextHandler.op, + ): + _, op = get_operation_factory(method) + assert isinstance(op, nexusrpc.Operation) + assert op.input_type == Input + assert op.output_type == Output diff --git a/tests/nexus/test_nexus_client_updates.py b/tests/nexus/test_nexus_client_updates.py index 323f21f05..f63d5482c 100644 --- a/tests/nexus/test_nexus_client_updates.py +++ b/tests/nexus/test_nexus_client_updates.py @@ -62,7 +62,7 @@ async def test_nexus_client_updates_when_worker_client_changes( handler_task_queue = f"handler-{uuid.uuid4()}" # Create Nexus endpoint - endpoint_name = "test-endpoint" + endpoint_name = f"test-endpoint-{uuid.uuid4()}" await env.create_nexus_endpoint(endpoint_name, handler_task_queue) # Caller worker diff --git a/tests/nexus/test_nexus_type_errors.py b/tests/nexus/test_nexus_type_errors.py index 4a013e8a9..0ea5b60ff 100644 --- a/tests/nexus/test_nexus_type_errors.py +++ b/tests/nexus/test_nexus_type_errors.py @@ -26,10 +26,55 @@ class MyOutput: pass +@workflow.defn +class MyNoArgProcWorkflow: + @workflow.run + async def run(self) -> None: + pass + + +@workflow.defn +class MyOneArgProcWorkflow: + @workflow.run + async def run(self, _input: MyInput) -> None: + pass + + +@workflow.defn +class MyTwoArgProcWorkflow: + @workflow.run + async def run(self, _input: MyInput, _arg2: int) -> None: + pass + + +@workflow.defn +class MyThreeArgProcWorkflow: + @workflow.run + async def run(self, _input: MyInput, _arg2: int, _arg3: int) -> None: + pass + + +@workflow.defn +class MyFourArgProcWorkflow: + @workflow.run + async def run(self, _input: MyInput, _arg2: int, _arg3: int, _arg4: int) -> None: + pass + + +@workflow.defn +class MyFiveArgProcWorkflow: + @workflow.run + async def run( + self, _input: MyInput, _arg2: int, _arg3: int, _arg4: int, _arg5: int + ) -> None: + pass + + @nexusrpc.service class MyService: my_sync_operation: nexusrpc.Operation[MyInput, MyOutput] my_workflow_run_operation: nexusrpc.Operation[MyInput, MyOutput] + my_temporal_operation: nexusrpc.Operation[int, None] @nexusrpc.service @@ -51,6 +96,71 @@ async def my_workflow_run_operation( ) -> temporalio.nexus.WorkflowHandle[MyOutput]: raise NotImplementedError + @temporalio.nexus.temporal_operation + async def my_temporal_operation( + self, + _ctx: temporalio.nexus.TemporalStartOperationContext, + client: temporalio.nexus.TemporalNexusClient, + input: int, + ) -> temporalio.nexus.TemporalOperationResult[None]: + """ + Typed proc workflow starts from a generic Temporal Nexus operation handler + infer TemporalOperationResult[None] for 0 to 5 workflow parameters. + """ + if input == 0: + result_0: temporalio.nexus.TemporalOperationResult[ + None + ] = await client.start_workflow(MyNoArgProcWorkflow.run, id="proc-0") + return result_0 + if input == 1: + result_1: temporalio.nexus.TemporalOperationResult[ + None + ] = await client.start_workflow( + MyOneArgProcWorkflow.run, MyInput(), id="proc-1" + ) + return result_1 + if input == 2: + result_2: temporalio.nexus.TemporalOperationResult[ + None + ] = await client.start_workflow( + MyTwoArgProcWorkflow.run, args=[MyInput(), 2], id="proc-2" + ) + return result_2 + if input == 3: + result_3: temporalio.nexus.TemporalOperationResult[ + None + ] = await client.start_workflow( + MyThreeArgProcWorkflow.run, + args=[MyInput(), 2, 3], + id="proc-3", + ) + return result_3 + if input == 4: + result_4: temporalio.nexus.TemporalOperationResult[ + None + ] = await client.start_workflow( + MyFourArgProcWorkflow.run, + args=[MyInput(), 2, 3, 4], + id="proc-4", + ) + return result_4 + if input == 5: + result_5: temporalio.nexus.TemporalOperationResult[ + None + ] = await client.start_workflow( + MyFiveArgProcWorkflow.run, + args=[MyInput(), 2, 3, 4, 5], + id="proc-5", + ) + return result_5 + # assert-type-error-pyright: 'No overloads for "start_workflow" match' + return await client.start_workflow( # type: ignore + MyOneArgProcWorkflow.run, + # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter' + "wrong-input-type", # type: ignore + id="proc-wrong-input", + ) + @nexusrpc.handler.service_handler(service=MyService) class MyServiceHandler2: @@ -66,6 +176,15 @@ async def my_workflow_run_operation( ) -> temporalio.nexus.WorkflowHandle[MyOutput]: raise NotImplementedError + @temporalio.nexus.temporal_operation + async def my_temporal_operation( + self, + _ctx: temporalio.nexus.TemporalStartOperationContext, + _client: temporalio.nexus.TemporalNexusClient, + _input: int, + ) -> temporalio.nexus.TemporalOperationResult[None]: + raise NotImplementedError + @nexusrpc.handler.service_handler class MyServiceHandlerWithoutServiceDefinition: @@ -81,6 +200,52 @@ async def my_workflow_run_operation( ) -> temporalio.nexus.WorkflowHandle[MyOutput]: raise NotImplementedError + @temporalio.nexus.temporal_operation + async def my_temporal_operation( + self, + _ctx: temporalio.nexus.TemporalStartOperationContext, + _client: temporalio.nexus.TemporalNexusClient, + _input: int, + ) -> temporalio.nexus.TemporalOperationResult[None]: + raise NotImplementedError + + +class MyUnsafeContextAnnotationServiceHandler: + # A temporal operation receives TemporalStartOperationContext at runtime, so + # requiring an arbitrary user subclass is not safe. + class MyCustomTemporalStartOperationContext( + temporalio.nexus.TemporalStartOperationContext + ): + def custom_state(self) -> str: + raise NotImplementedError + + # assert-type-error-pyright: 'cannot be assigned to parameter "start".+temporal_operation' + @temporalio.nexus.temporal_operation # type: ignore + async def my_temporal_operation_with_workflow_run_context( + self, + _ctx: MyCustomTemporalStartOperationContext, + _client: temporalio.nexus.TemporalNexusClient, + _input: int, + ) -> temporalio.nexus.TemporalOperationResult[None]: + raise NotImplementedError + + # A workflow run operation receives WorkflowRunOperationContext at runtime, + # so requiring an arbitrary user subclass is not safe. + class MyCustomWorkflowRunOperationContext( + temporalio.nexus.WorkflowRunOperationContext + ): + def custom_state(self) -> str: + raise NotImplementedError + + # assert-type-error-pyright: 'cannot be assigned to parameter "start".+workflow_run_operation' + @temporalio.nexus.workflow_run_operation # type: ignore + async def my_workflow_run_operation_with_custom_context( + self, + _ctx: MyCustomWorkflowRunOperationContext, + _input: MyInput, + ) -> temporalio.nexus.WorkflowHandle[MyOutput]: + raise NotImplementedError + @workflow.defn class MyWorkflow1: @@ -116,6 +281,15 @@ async def test_invoke_by_operation_definition_happy_path(self) -> None: ) _output_2_1: MyOutput = await _handle_2 + # temporal operation + _output_3: None = await nexus_client.execute_operation( # type: ignore + MyService.my_temporal_operation, 0 + ) + _handle_3: workflow.NexusOperationHandle[ + None + ] = await nexus_client.start_operation(MyService.my_temporal_operation, 0) + _output_3_1: None = await _handle_3 # type: ignore + @workflow.defn class MyWorkflow2: @@ -153,6 +327,17 @@ async def test_invoke_by_operation_handler_happy_path(self) -> None: ) _output_2_1: MyOutput = await _handle_2 + # temporal operation + _output_3: None = await nexus_client.execute_operation( # type: ignore + MyServiceHandler.my_temporal_operation, 0 + ) + _handle_3: workflow.NexusOperationHandle[ + None + ] = await nexus_client.start_operation( + MyServiceHandler.my_temporal_operation, 0 + ) + _output_3_1: None = await _handle_3 # type: ignore + @workflow.defn class MyWorkflow3: @@ -172,6 +357,12 @@ async def test_invoke_by_operation_definition_wrong_input_type(self) -> None: # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter "input"' "wrong-input-type", # type: ignore ) + # assert-type-error-pyright: 'No overloads for "execute_operation" match' + await nexus_client.execute_operation( # type: ignore + MyService.my_temporal_operation, + # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter "input"' + "wrong-input-type", # type: ignore + ) @workflow.defn @@ -192,6 +383,12 @@ async def test_invoke_by_operation_handler_wrong_input_type(self) -> None: # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter "input"' "wrong-input-type", # type: ignore ) + # assert-type-error-pyright: 'No overloads for "execute_operation" match' + await nexus_client.execute_operation( # type: ignore + MyServiceHandler.my_temporal_operation, # type: ignore[arg-type] + # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter "input"' + "wrong-input-type", # type: ignore + ) @workflow.defn @@ -216,8 +413,14 @@ async def test_invoke_by_operation_handler_method_on_wrong_service(self) -> None MyInput(), ) + # assert-type-error-pyright: 'No overloads for "execute_operation" match' + await nexus_client.execute_operation( # type: ignore + # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter "operation"' + MyServiceHandler2.my_temporal_operation, # type: ignore + 0, + ) + -# ── Standalone Nexus Operation type tests ── async def standalone_operation_type_tests(): client = Client(service_client=Mock(spec=ServiceClient)) nexus_client = client.create_nexus_client( @@ -366,7 +569,7 @@ async def standalone_operation_type_tests(): client.get_nexus_operation_handle("op-1", result_type=MyOutput) ) - # getting a handle with an operation defintion produces a handle of the operation + # getting a handle with an operation definition produces a handle of the operation # output type _op_defn_get_handle: NexusOperationHandle[MyOutput] = ( client.get_nexus_operation_handle("op-1", operation=MyService.my_sync_operation) diff --git a/tests/nexus/test_operation_token.py b/tests/nexus/test_operation_token.py new file mode 100644 index 000000000..d2d6b617b --- /dev/null +++ b/tests/nexus/test_operation_token.py @@ -0,0 +1,115 @@ +import base64 +import json +from typing import Any + +import pytest + +from temporalio.nexus._token import ( + OperationToken, + OperationTokenType, + WorkflowHandle, +) + + +def _encode_json_token(value: Any) -> str: + return _encode_bytes(json.dumps(value, separators=(",", ":")).encode("utf-8")) + + +def _encode_bytes(value: bytes) -> str: + return base64.urlsafe_b64encode(value).decode("utf-8").rstrip("=") + + +def test_operation_token_encode_decode_round_trip(): + token = OperationToken( + type=OperationTokenType.WORKFLOW, + namespace="default", + workflow_id="workflow-id", + version=0, + ).encode() + + assert "=" not in token + assert OperationToken.decode(token) == OperationToken( + type=OperationTokenType.WORKFLOW, + namespace="default", + workflow_id="workflow-id", + version=0, + ) + + +def test_workflow_handle_to_from_token_round_trip(): + handle = WorkflowHandle[str](namespace="default", workflow_id="workflow-id") + + assert WorkflowHandle[str].from_token(handle.to_token()) == handle + + +@pytest.mark.parametrize( + ("token", "message"), + [ + ("", "invalid token: token is empty"), + ("not+a-base64url-token", "failed to decode token as base64url"), + (_encode_bytes(b"not json"), "failed to unmarshal operation token"), + (_encode_json_token(["not", "a", "dict"]), "expected dict"), + ( + _encode_json_token({"ns": "default", "wid": "workflow-id"}), + "expected token type to be an int", + ), + ( + _encode_json_token({"t": "1", "ns": "default", "wid": "workflow-id"}), + "expected token type to be an int", + ), + ( + _encode_json_token({"t": 999, "ns": "default", "wid": "workflow-id"}), + "unknown token type", + ), + ( + _encode_json_token({"t": 1, "ns": "default"}), + "expected workflow id to be a string", + ), + ( + _encode_json_token({"t": 1, "ns": "default", "wid": 123}), + "expected workflow id to be a string", + ), + ( + _encode_json_token({"t": 1, "ns": "default", "wid": ""}), + "expected non-empty workflow id", + ), + ( + _encode_json_token({"t": 1, "wid": "workflow-id"}), + "expected namespace to be a non-empty string", + ), + ( + _encode_json_token({"t": 1, "ns": 123, "wid": "workflow-id"}), + "expected namespace to be a non-empty string", + ), + ( + _encode_json_token({"t": 1, "ns": "", "wid": "workflow-id"}), + "expected namespace to be a non-empty string", + ), + ( + _encode_json_token( + {"t": 1, "ns": "default", "wid": "workflow-id", "v": "0"} + ), + "expected version to be an int or null", + ), + ], +) +def test_operation_token_decode_rejects_invalid_tokens(token: str, message: str): + with pytest.raises(TypeError, match=message): + OperationToken.decode(token) + + +def test_workflow_handle_from_token_accepts_version_zero(): + token = _encode_json_token({"t": 1, "ns": "default", "wid": "workflow-id", "v": 0}) + + assert WorkflowHandle[str].from_token(token) == WorkflowHandle[str]( + namespace="default", + workflow_id="workflow-id", + version=0, + ) + + +def test_workflow_handle_from_token_rejects_unsupported_version(): + token = _encode_json_token({"t": 1, "ns": "default", "wid": "workflow-id", "v": 1}) + + with pytest.raises(TypeError, match="'v' field, if present, must be 0"): + WorkflowHandle[str].from_token(token) diff --git a/tests/nexus/test_temporal_operation.py b/tests/nexus/test_temporal_operation.py new file mode 100644 index 000000000..8ac79b28b --- /dev/null +++ b/tests/nexus/test_temporal_operation.py @@ -0,0 +1,588 @@ +import asyncio +import uuid +from dataclasses import dataclass + +import nexusrpc +import pytest +from nexusrpc import HandlerErrorType, Operation, service +from nexusrpc.handler import ( + service_handler, +) + +import temporalio.exceptions +from temporalio import nexus, workflow +from temporalio.client import Client, WorkflowFailureError +from temporalio.common import WorkflowIDConflictPolicy +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import Worker +from tests.helpers import EventType, assert_event_subsequence +from tests.helpers.nexus import make_nexus_endpoint_name + + +@dataclass +class Input: + value: str + task_queue: str + + +@workflow.defn +class EchoWorkflow: + @workflow.run + async def run(self, input: Input) -> str: + return input.value + + +@service +class TestService: + echo: Operation[Input, str] + blocking: Operation[None, None] + double_start: Operation[Input, None] + concurrent_start: Operation[Input, str] + retry_after_failed_start: Operation[Input, str] + sync_result: Operation[Input, str] + + +@service_handler(service=TestService) +class EchoServiceHandler: + @nexus.temporal_operation + async def echo( + self, + _ctx: nexus.TemporalStartOperationContext, + client: nexus.TemporalNexusClient, + input: Input, + ) -> nexus.TemporalOperationResult[str]: + return await client.start_workflow( + EchoWorkflow.run, input, id=f"echo-{input.value}" + ) + + @nexus.temporal_operation + async def blocking( + self, + _ctx: nexus.TemporalStartOperationContext, + client: nexus.TemporalNexusClient, + _input: None, + ) -> nexus.TemporalOperationResult[None]: + return await client.start_workflow( + BlockingWorkflow.run, id=f"blocking-{uuid.uuid4()}" + ) + + @nexus.temporal_operation + async def double_start( + self, + _ctx: nexus.TemporalStartOperationContext, + client: nexus.TemporalNexusClient, + input: Input, + ) -> nexus.TemporalOperationResult[None]: + await client.start_workflow( + EchoWorkflow.run, input, id=f"double-start-{uuid.uuid4()}" + ) + await client.start_workflow( + EchoWorkflow.run, input, id=f"double-start-{uuid.uuid4()}" + ) + return nexus.TemporalOperationResult.sync(None) + + @nexus.temporal_operation + async def concurrent_start( + self, + _ctx: nexus.TemporalStartOperationContext, + client: nexus.TemporalNexusClient, + input: Input, + ) -> nexus.TemporalOperationResult[str]: + results = await asyncio.gather( + client.start_workflow( + EchoWorkflow.run, + input, + id=f"concurrent-start-1-{uuid.uuid4()}", + ), + client.start_workflow( + EchoWorkflow.run, + input, + id=f"concurrent-start-2-{uuid.uuid4()}", + ), + return_exceptions=True, + ) + + async_results: list[nexus.TemporalOperationResult[str]] = [] + handler_errors: list[nexusrpc.HandlerError] = [] + for result in results: + if isinstance(result, nexus.TemporalOperationResult): + async_results.append(result) + elif isinstance(result, nexusrpc.HandlerError): + handler_errors.append(result) + elif isinstance(result, BaseException): + raise result + else: + raise RuntimeError(f"Unexpected concurrent start result: {result}") + + if ( + len(async_results) == 1 + and len(handler_errors) == 1 + and handler_errors[0].type == HandlerErrorType.BAD_REQUEST + ): + return async_results[0] + + raise RuntimeError( + "Expected one async workflow start and one BAD_REQUEST HandlerError, " + f"got {len(async_results)} starts and {len(handler_errors)} handler errors" + ) + + @nexus.temporal_operation + async def retry_after_failed_start( + self, + _ctx: nexus.TemporalStartOperationContext, + client: nexus.TemporalNexusClient, + input: Input, + ) -> nexus.TemporalOperationResult[str]: + try: + await client.start_workflow( + BlockingWorkflow.run, + id=input.value, + id_conflict_policy=WorkflowIDConflictPolicy.FAIL, + ) + except temporalio.exceptions.WorkflowAlreadyStartedError: + return await client.start_workflow( + EchoWorkflow.run, + input, + id=f"retry-after-failed-start-{uuid.uuid4()}", + ) + + raise RuntimeError("Expected first workflow start to fail") + + @nexus.temporal_operation + async def sync_result( + self, + _ctx: nexus.TemporalStartOperationContext, + _client: nexus.TemporalNexusClient, + input: Input, + ) -> nexus.TemporalOperationResult[str]: + return nexus.TemporalOperationResult.sync(input.value) + + +@workflow.defn +class EchoWorkflowCaller: + @workflow.run + async def run(self, input: Input) -> str: + client = workflow.create_nexus_client( + service=TestService, endpoint=make_nexus_endpoint_name(input.task_queue) + ) + return await client.execute_operation(TestService.echo, input) + + +async def test_temporal_operation_start_workflow( + client: Client, env: WorkflowEnvironment +): + task_queue = str(uuid.uuid4()) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[EchoServiceHandler()], + workflows=[EchoWorkflow, EchoWorkflowCaller], + ): + wf_handle = await client.start_workflow( + EchoWorkflowCaller.run, + Input(value="test", task_queue=task_queue), + task_queue=task_queue, + id=str(uuid.uuid4()), + ) + result = await wf_handle.result() + assert result == "test" + + await assert_event_subsequence( + wf_handle, + [ + EventType.EVENT_TYPE_NEXUS_OPERATION_SCHEDULED, + EventType.EVENT_TYPE_NEXUS_OPERATION_STARTED, + EventType.EVENT_TYPE_NEXUS_OPERATION_COMPLETED, + ], + ) + + +@workflow.defn +class BlockingWorkflow: + done: bool = False + + @workflow.run + async def run(self) -> None: + await workflow.wait_condition(lambda: self.done) + + @workflow.update + async def unblock(self): + self.done = True + + +@workflow.defn +class CancelBlockingWorkflowCaller: + op_started = False + + @workflow.run + async def run(self, input: Input) -> None: + client = workflow.create_nexus_client( + service=TestService, endpoint=make_nexus_endpoint_name(input.task_queue) + ) + op_handle = await client.start_operation(TestService.blocking, None) + self.op_started = True + return await op_handle + + @workflow.update + async def wait_operation_started(self): + await workflow.wait_condition(lambda: self.op_started) + + +async def test_temporal_operation_cancel_workflow( + client: Client, env: WorkflowEnvironment +): + task_queue = str(uuid.uuid4()) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[EchoServiceHandler()], + workflows=[BlockingWorkflow, CancelBlockingWorkflowCaller], + ): + wf_handle = await client.start_workflow( + CancelBlockingWorkflowCaller.run, + Input(value="test", task_queue=task_queue), + task_queue=task_queue, + id=f"blocking-{uuid.uuid4()}", + ) + + await wf_handle.execute_update( + CancelBlockingWorkflowCaller.wait_operation_started + ) + + await wf_handle.cancel() + + await assert_event_subsequence( + wf_handle, + [ + EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED, + EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED, + EventType.EVENT_TYPE_NEXUS_OPERATION_CANCELED, + ], + ) + + +@workflow.defn +class DoubleStartWorkflowCaller: + @workflow.run + async def run(self, input: Input) -> None: + client = workflow.create_nexus_client( + service=TestService, endpoint=make_nexus_endpoint_name(input.task_queue) + ) + op_handle = await client.start_operation(TestService.double_start, input) + return await op_handle + + +@workflow.defn +class ConcurrentStartWorkflowCaller: + @workflow.run + async def run(self, input: Input) -> str: + client = workflow.create_nexus_client( + service=TestService, endpoint=make_nexus_endpoint_name(input.task_queue) + ) + return await client.execute_operation(TestService.concurrent_start, input) + + +@workflow.defn +class FailedStartRollbackWorkflowCaller: + @workflow.run + async def run(self, input: Input) -> str: + client = workflow.create_nexus_client( + service=TestService, endpoint=make_nexus_endpoint_name(input.task_queue) + ) + return await client.execute_operation( + TestService.retry_after_failed_start, + input, + ) + + +async def test_temporal_operation_double_start_raises_handler_err( + client: Client, env: WorkflowEnvironment +): + task_queue = str(uuid.uuid4()) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[EchoServiceHandler()], + workflows=[EchoWorkflow, DoubleStartWorkflowCaller], + ): + with pytest.raises(WorkflowFailureError) as err: + await client.execute_workflow( + DoubleStartWorkflowCaller.run, + Input(value="test", task_queue=task_queue), + task_queue=task_queue, + id=f"double-start-{uuid.uuid4()}", + ) + + assert isinstance(err.value.cause, temporalio.exceptions.NexusOperationError) + assert isinstance(err.value.cause.cause, nexusrpc.HandlerError) + assert err.value.cause.cause.type == HandlerErrorType.BAD_REQUEST + assert ( + "Only one async operation can be started per operation handler invocation" + in err.value.cause.cause.message + ) + + +async def test_temporal_operation_concurrent_start_raises_handler_err( + client: Client, env: WorkflowEnvironment +): + task_queue = str(uuid.uuid4()) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[EchoServiceHandler()], + workflows=[EchoWorkflow, ConcurrentStartWorkflowCaller], + ): + result = await client.execute_workflow( + ConcurrentStartWorkflowCaller.run, + Input(value="test", task_queue=task_queue), + task_queue=task_queue, + id=f"concurrent-start-{uuid.uuid4()}", + ) + + assert result == "test" + + +async def test_temporal_operation_failed_start_allows_retry( + client: Client, env: WorkflowEnvironment +): + task_queue = str(uuid.uuid4()) + endpoint_name = make_nexus_endpoint_name(task_queue) + conflict_id = f"failed-start-rollback-{uuid.uuid4()}" + await env.create_nexus_endpoint(endpoint_name, task_queue) + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[EchoServiceHandler()], + workflows=[ + BlockingWorkflow, + EchoWorkflow, + FailedStartRollbackWorkflowCaller, + ], + ): + conflict_handle = await client.start_workflow( + BlockingWorkflow.run, + id=conflict_id, + task_queue=task_queue, + ) + + try: + result = await client.execute_workflow( + FailedStartRollbackWorkflowCaller.run, + Input(value=conflict_id, task_queue=task_queue), + task_queue=task_queue, + id=f"failed-start-rollback-caller-{uuid.uuid4()}", + ) + assert result == conflict_id + finally: + await conflict_handle.cancel() + + +@workflow.defn +class SyncResultCaller: + @workflow.run + async def run(self, input: Input) -> str: + client = workflow.create_nexus_client( + service=TestService, endpoint=make_nexus_endpoint_name(input.task_queue) + ) + return await client.execute_operation(TestService.sync_result, input) + + +async def test_temporal_operation_sync_result(client: Client, env: WorkflowEnvironment): + task_queue = str(uuid.uuid4()) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[EchoServiceHandler()], + workflows=[SyncResultCaller], + ): + wf_handle = await client.start_workflow( + SyncResultCaller.run, + Input(value="test", task_queue=task_queue), + task_queue=task_queue, + id=str(uuid.uuid4()), + ) + result = await wf_handle.result() + assert result == "test" + + # Sync results do not produce a NEXUS_OPERATION_STARTED event, + await assert_event_subsequence( + wf_handle, + [ + EventType.EVENT_TYPE_NEXUS_OPERATION_SCHEDULED, + EventType.EVENT_TYPE_NEXUS_OPERATION_COMPLETED, + ], + ) + + +@dataclass +class TemporalOperationOverloadTestValue: + value: int + + +@workflow.defn +class TemporalOperationOverloadTestWorkflow: + @workflow.run + async def run( + self, input: TemporalOperationOverloadTestValue + ) -> TemporalOperationOverloadTestValue: + return TemporalOperationOverloadTestValue(value=input.value * 2) + + +@workflow.defn +class TemporalOperationOverloadTestWorkflowNoParam: + @workflow.run + async def run(self) -> TemporalOperationOverloadTestValue: + return TemporalOperationOverloadTestValue(value=0) + + +@service_handler +class TemporalOperationOverloadTestServiceHandler: + @nexus.temporal_operation + async def no_param( + self, + _ctx: nexus.TemporalStartOperationContext, + client: nexus.TemporalNexusClient, + _input: TemporalOperationOverloadTestValue, + ) -> nexus.TemporalOperationResult[TemporalOperationOverloadTestValue]: + return await client.start_workflow( + TemporalOperationOverloadTestWorkflowNoParam.run, + id=str(uuid.uuid4()), + ) + + @nexus.temporal_operation + async def single_param( + self, + _ctx: nexus.TemporalStartOperationContext, + client: nexus.TemporalNexusClient, + input: TemporalOperationOverloadTestValue, + ) -> nexus.TemporalOperationResult[TemporalOperationOverloadTestValue]: + return await client.start_workflow( + TemporalOperationOverloadTestWorkflow.run, + input, + id=str(uuid.uuid4()), + ) + + @nexus.temporal_operation + async def multi_param( + self, + _ctx: nexus.TemporalStartOperationContext, + client: nexus.TemporalNexusClient, + input: TemporalOperationOverloadTestValue, + ) -> nexus.TemporalOperationResult[TemporalOperationOverloadTestValue]: + return await client.start_workflow( + TemporalOperationOverloadTestWorkflow.run, + args=[input], + id=str(uuid.uuid4()), + ) + + @nexus.temporal_operation + async def by_name( + self, + _ctx: nexus.TemporalStartOperationContext, + client: nexus.TemporalNexusClient, + input: TemporalOperationOverloadTestValue, + ) -> nexus.TemporalOperationResult[TemporalOperationOverloadTestValue]: + return await client.start_workflow( + "TemporalOperationOverloadTestWorkflow", + input, + id=str(uuid.uuid4()), + result_type=TemporalOperationOverloadTestValue, + ) + + @nexus.temporal_operation + async def by_name_multi_param( + self, + _ctx: nexus.TemporalStartOperationContext, + client: nexus.TemporalNexusClient, + input: TemporalOperationOverloadTestValue, + ) -> nexus.TemporalOperationResult[TemporalOperationOverloadTestValue]: + return await client.start_workflow( + "TemporalOperationOverloadTestWorkflow", + args=[input], + id=str(uuid.uuid4()), + result_type=TemporalOperationOverloadTestValue, + ) + + +@workflow.defn +class TemporalOperationOverloadTestCallerWorkflow: + @workflow.run + async def run( + self, op: str, input: TemporalOperationOverloadTestValue + ) -> TemporalOperationOverloadTestValue: + client = workflow.create_nexus_client( + service=TemporalOperationOverloadTestServiceHandler, + endpoint=make_nexus_endpoint_name(workflow.info().task_queue), + ) + + if op == "no_param": + return await client.execute_operation( + TemporalOperationOverloadTestServiceHandler.no_param, input + ) + elif op == "single_param": + return await client.execute_operation( + TemporalOperationOverloadTestServiceHandler.single_param, input + ) + elif op == "multi_param": + return await client.execute_operation( + TemporalOperationOverloadTestServiceHandler.multi_param, input + ) + elif op == "by_name": + return await client.execute_operation( + TemporalOperationOverloadTestServiceHandler.by_name, input + ) + elif op == "by_name_multi_param": + return await client.execute_operation( + TemporalOperationOverloadTestServiceHandler.by_name_multi_param, input + ) + else: + raise ValueError(f"Unknown op: {op}") + + +@pytest.mark.parametrize( + "op", + [ + "no_param", + "single_param", + "multi_param", + "by_name", + "by_name_multi_param", + ], +) +async def test_temporal_operation_overloads( + client: Client, env: WorkflowEnvironment, op: str +): + task_queue = str(uuid.uuid4()) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) + async with Worker( + client, + task_queue=task_queue, + workflows=[ + TemporalOperationOverloadTestCallerWorkflow, + TemporalOperationOverloadTestWorkflow, + TemporalOperationOverloadTestWorkflowNoParam, + ], + nexus_service_handlers=[TemporalOperationOverloadTestServiceHandler()], + ): + result = await client.execute_workflow( + TemporalOperationOverloadTestCallerWorkflow.run, + args=[op, TemporalOperationOverloadTestValue(value=2)], + id=str(uuid.uuid4()), + task_queue=task_queue, + ) + assert result == ( + TemporalOperationOverloadTestValue(value=0) + if op == "no_param" + else TemporalOperationOverloadTestValue(value=4) + ) diff --git a/tests/nexus/test_workflow_caller_cancellation_types.py b/tests/nexus/test_workflow_caller_cancellation_types.py index 59a989d34..ad652e815 100644 --- a/tests/nexus/test_workflow_caller_cancellation_types.py +++ b/tests/nexus/test_workflow_caller_cancellation_types.py @@ -20,7 +20,7 @@ from temporalio.common import WorkflowIDConflictPolicy from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker -from tests.helpers import assert_eventually +from tests.helpers import assert_event_subsequence, assert_eventually from tests.helpers.nexus import make_nexus_endpoint_name @@ -485,38 +485,3 @@ async def get_event_time( return event.event_time.ToDatetime().replace(tzinfo=timezone.utc) event_type_name = EventType.Name(event_type).removeprefix("EVENT_TYPE_") assert False, f"Event {event_type_name} not found in {wf_handle.id}" - - -async def assert_event_subsequence( - wf_handle: WorkflowHandle, - expected_events: list[EventType.ValueType], -) -> None: - """ - Given a workflow handle and a sequence of event types, assert that the workflow's history - contains that subsequence of events in the order specified. - """ - all_events = [] - async for e in wf_handle.fetch_history_events(): - all_events.append(e) - - _all_events = iter(all_events) - _expected_events = iter(expected_events) - - previous_expected_event_type_name = None - for expected_event_type in _expected_events: - expected_event_type_name = EventType.Name(expected_event_type).removeprefix( - "EVENT_TYPE_" - ) - has_expected = next( - (e for e in _all_events if e.event_type == expected_event_type), - None, - ) - if not has_expected: - if previous_expected_event_type_name is not None: - prefix = f"After {previous_expected_event_type_name}, " - else: - prefix = "" - pytest.fail( - f"{prefix}expected {expected_event_type_name} in workflow {wf_handle.id}" - ) - previous_expected_event_type_name = expected_event_type_name diff --git a/tests/nexus/test_workflow_caller_cancellation_types_when_cancel_handler_fails.py b/tests/nexus/test_workflow_caller_cancellation_types_when_cancel_handler_fails.py index 3418e290f..4cdeeeb15 100644 --- a/tests/nexus/test_workflow_caller_cancellation_types_when_cancel_handler_fails.py +++ b/tests/nexus/test_workflow_caller_cancellation_types_when_cancel_handler_fails.py @@ -23,10 +23,9 @@ from temporalio.common import WorkflowIDConflictPolicy from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker -from tests.helpers import assert_eventually +from tests.helpers import assert_event_subsequence, assert_eventually from tests.helpers.nexus import make_nexus_endpoint_name from tests.nexus.test_workflow_caller_cancellation_types import ( - assert_event_subsequence, get_event_time, has_event, ) diff --git a/tests/test_workflow_exports.py b/tests/test_workflow_exports.py index e67040b64..8788addc5 100644 --- a/tests/test_workflow_exports.py +++ b/tests/test_workflow_exports.py @@ -46,7 +46,6 @@ "RootInfo", "SandboxImportNotificationPolicy", "SelfType", - "ServiceHandlerT", "UnfinishedSignalHandlersWarning", "UnfinishedUpdateHandlersWarning", "UpdateInfo", From 6cabb132420c292ac94cb2227b779acc2a46b63f Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Thu, 21 May 2026 15:03:52 -0700 Subject: [PATCH 4/7] Add overloads to sano client and increase type test coverage --- temporalio/client/_nexus.py | 58 ++++++++++++++++++++++ tests/nexus/test_nexus_type_errors.py | 70 +++++++++++++++++++++++++++ 2 files changed, 128 insertions(+) diff --git a/temporalio/client/_nexus.py b/temporalio/client/_nexus.py index 060235e01..7eea155a9 100644 --- a/temporalio/client/_nexus.py +++ b/temporalio/client/_nexus.py @@ -611,6 +611,35 @@ async def start_operation( rpc_timeout: timedelta | None = None, ) -> NexusOperationHandle[OutputT]: ... + # Overload for temporal_operation methods + @overload + @abstractmethod + async def start_operation( + self, + operation: Callable[ + [ + NexusServiceType, + temporalio.nexus.TemporalStartOperationContext, + temporalio.nexus.TemporalNexusClient, + InputT, + ], + Awaitable[temporalio.nexus.TemporalOperationResult[OutputT]], + ], + arg: InputT, + *, + id: str, + id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + headers: Mapping[str, str] | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> NexusOperationHandle[OutputT]: ... + @abstractmethod async def start_operation( self, @@ -804,6 +833,35 @@ async def execute_operation( rpc_timeout: timedelta | None = None, ) -> OutputT: ... + # Overload for temporal_operation methods + @overload + @abstractmethod + async def execute_operation( + self, + operation: Callable[ + [ + NexusServiceType, + temporalio.nexus.TemporalStartOperationContext, + temporalio.nexus.TemporalNexusClient, + InputT, + ], + Awaitable[temporalio.nexus.TemporalOperationResult[OutputT]], + ], + arg: InputT, + *, + id: str, + id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + headers: Mapping[str, str] | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> OutputT: ... + @abstractmethod async def execute_operation( self, diff --git a/tests/nexus/test_nexus_type_errors.py b/tests/nexus/test_nexus_type_errors.py index 0ea5b60ff..7b7e694cb 100644 --- a/tests/nexus/test_nexus_type_errors.py +++ b/tests/nexus/test_nexus_type_errors.py @@ -431,6 +431,10 @@ async def standalone_operation_type_tests(): MyNoInputService, endpoint="fake-endpoint", ) + handler_nexus_client = client.create_nexus_client( + MyServiceHandler, + endpoint="fake-endpoint", + ) # execute with an operation definition infers output type _op_defn_output: MyOutput = await nexus_client.execute_operation( @@ -455,6 +459,20 @@ async def standalone_operation_type_tests(): "my_sync_operation", MyInput(), id="op-1", result_type=MyOutput ) + # execute with workflow run handler infers output type + _workflow_run_output: MyOutput = await handler_nexus_client.execute_operation( + MyServiceHandler.my_workflow_run_operation, + MyInput(), + id="op-1", + ) + + # execute with temporal operation handler infers output type + _temporal_output: None = await handler_nexus_client.execute_operation( # type: ignore[func-returns-value] + MyServiceHandler.my_temporal_operation, + 0, + id="op-1", + ) + # omitting arg for string operation names is not supported # assert-type-error-pyright: 'No overloads for "execute_operation" match' await nexus_client.execute_operation( # type: ignore @@ -559,6 +577,58 @@ async def standalone_operation_type_tests(): await _str_op_result_type_handle.result() ) + # starting with workflow run handler infers output type on the handle + # and result from the handle + _workflow_run_handle: NexusOperationHandle[ + MyOutput + ] = await handler_nexus_client.start_operation( + MyServiceHandler.my_workflow_run_operation, + MyInput(), + id="op-1", + ) + + # starting with temporal operation handler infers output type on the handle + # and result from the handle + _workflow_run_handle_output: MyOutput = await _workflow_run_handle.result() + _temporal_handle: NexusOperationHandle[ + None + ] = await handler_nexus_client.start_operation( + MyServiceHandler.my_temporal_operation, + 0, + id="op-1", + ) + _temporal_handle_output: None = await _temporal_handle.result() # type: ignore[func-returns-value] + + # workflow run and temporal operation handlers reject wrong input types + # assert-type-error-pyright: 'No overloads for "execute_operation" match' + await handler_nexus_client.execute_operation( # type: ignore + MyServiceHandler.my_workflow_run_operation, # type: ignore[arg-type] + # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter "arg"' + "wrong-input-type", # type: ignore + id="op-1", + ) + # assert-type-error-pyright: 'No overloads for "start_operation" match' + await handler_nexus_client.start_operation( # type: ignore + MyServiceHandler.my_workflow_run_operation, # type: ignore[arg-type] + # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter "arg"' + "wrong-input-type", # type: ignore + id="op-1", + ) + # assert-type-error-pyright: 'No overloads for "execute_operation" match' + await handler_nexus_client.execute_operation( # type: ignore + MyServiceHandler.my_temporal_operation, # type: ignore[arg-type] + # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter "arg"' + "wrong-input-type", # type: ignore + id="op-1", + ) + # assert-type-error-pyright: 'No overloads for "start_operation" match' + await handler_nexus_client.start_operation( # type: ignore + MyServiceHandler.my_temporal_operation, # type: ignore[arg-type] + # assert-type-error-pyright: 'Argument of type .+ cannot be assigned to parameter "arg"' + "wrong-input-type", # type: ignore + id="op-1", + ) + # getting a handle with a string produces a handle to Any _str_op_handle: NexusOperationHandle[Any] = client.get_nexus_operation_handle( "op-1" From eb3c560701ef12067ae26199ca26f42fbf856730 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Fri, 22 May 2026 15:43:05 -0700 Subject: [PATCH 5/7] Expose customizable Temporal Nexus operation handlers. Rename the Temporal operation start context to TemporalNexusStartOperationContext and add TemporalNexusCancelOperationContext with access to the worker client. Make TemporalNexusOperationHandler a public abstract base with overrideable start_operation and cancel_workflow_run hooks, while keeping the decorator-backed implementation private. Update type annotations, exports, and tests, including coverage for custom Temporal operation cancellation. --- temporalio/client/_nexus.py | 4 +- temporalio/nexus/__init__.py | 8 +- temporalio/nexus/_decorators.py | 58 ++++++--- temporalio/nexus/_operation_context.py | 78 +++++++++--- temporalio/nexus/_operation_handlers.py | 113 ++++++++++++----- temporalio/nexus/_util.py | 11 +- temporalio/workflow/_nexus.py | 4 +- .../test_handler_operation_definitions.py | 4 +- tests/nexus/test_nexus_type_errors.py | 8 +- tests/nexus/test_temporal_operation.py | 119 ++++++++++++++---- 10 files changed, 303 insertions(+), 104 deletions(-) diff --git a/temporalio/client/_nexus.py b/temporalio/client/_nexus.py index 7eea155a9..8cd95b26f 100644 --- a/temporalio/client/_nexus.py +++ b/temporalio/client/_nexus.py @@ -619,7 +619,7 @@ async def start_operation( operation: Callable[ [ NexusServiceType, - temporalio.nexus.TemporalStartOperationContext, + temporalio.nexus.TemporalNexusStartOperationContext, temporalio.nexus.TemporalNexusClient, InputT, ], @@ -841,7 +841,7 @@ async def execute_operation( operation: Callable[ [ NexusServiceType, - temporalio.nexus.TemporalStartOperationContext, + temporalio.nexus.TemporalNexusStartOperationContext, temporalio.nexus.TemporalNexusClient, InputT, ], diff --git a/temporalio/nexus/__init__.py b/temporalio/nexus/__init__.py index d8129e601..5c3d3b7a8 100644 --- a/temporalio/nexus/__init__.py +++ b/temporalio/nexus/__init__.py @@ -8,7 +8,8 @@ Info, LoggerAdapter, NexusCallback, - TemporalStartOperationContext, + TemporalNexusCancelOperationContext, + TemporalNexusStartOperationContext, WorkflowRunOperationContext, client, in_operation, @@ -19,6 +20,7 @@ wait_for_worker_shutdown, wait_for_worker_shutdown_sync, ) +from ._operation_handlers import TemporalNexusOperationHandler from ._temporal_client import TemporalNexusClient, TemporalOperationResult from ._token import WorkflowHandle @@ -28,7 +30,8 @@ "LoggerAdapter", "NexusCallback", "WorkflowRunOperationContext", - "TemporalStartOperationContext", + "TemporalNexusCancelOperationContext", + "TemporalNexusStartOperationContext", "client", "in_operation", "info", @@ -39,6 +42,7 @@ "wait_for_worker_shutdown_sync", "WorkflowHandle", "TemporalNexusClient", + "TemporalNexusOperationHandler", "TemporalOperationResult", "temporal_operation", ) diff --git a/temporalio/nexus/_decorators.py b/temporalio/nexus/_decorators.py index 7bbd689f0..e139f2820 100644 --- a/temporalio/nexus/_decorators.py +++ b/temporalio/nexus/_decorators.py @@ -19,12 +19,12 @@ from temporalio.types import NexusServiceType from ._operation_context import ( - TemporalStartOperationContext, + TemporalNexusStartOperationContext, WorkflowRunOperationContext, ) from ._operation_handlers import ( - TemporalNexusOperationHandler, WorkflowRunOperationHandler, + _TemporalNexusOperationHandler, ) from ._token import WorkflowHandle from ._util import ( @@ -145,11 +145,16 @@ async def _start( @overload def temporal_operation( start: Callable[ - [NexusServiceType, TemporalStartOperationContext, TemporalNexusClient, InputT], + [ + NexusServiceType, + TemporalNexusStartOperationContext, + TemporalNexusClient, + InputT, + ], Awaitable[TemporalOperationResult[OutputT]], ], ) -> Callable[ - [NexusServiceType, TemporalStartOperationContext, TemporalNexusClient, InputT], + [NexusServiceType, TemporalNexusStartOperationContext, TemporalNexusClient, InputT], Awaitable[TemporalOperationResult[OutputT]], ]: ... @@ -163,7 +168,7 @@ def temporal_operation( Callable[ [ NexusServiceType, - TemporalStartOperationContext, + TemporalNexusStartOperationContext, TemporalNexusClient, InputT, ], @@ -171,7 +176,12 @@ def temporal_operation( ] ], Callable[ - [NexusServiceType, TemporalStartOperationContext, TemporalNexusClient, InputT], + [ + NexusServiceType, + TemporalNexusStartOperationContext, + TemporalNexusClient, + InputT, + ], Awaitable[TemporalOperationResult[OutputT]], ], ]: ... @@ -183,7 +193,7 @@ def temporal_operation( Callable[ [ NexusServiceType, - TemporalStartOperationContext, + TemporalNexusStartOperationContext, TemporalNexusClient, InputT, ], @@ -194,7 +204,12 @@ def temporal_operation( name: str | None = None, ) -> ( Callable[ - [NexusServiceType, TemporalStartOperationContext, TemporalNexusClient, InputT], + [ + NexusServiceType, + TemporalNexusStartOperationContext, + TemporalNexusClient, + InputT, + ], Awaitable[TemporalOperationResult[OutputT]], ] | Callable[ @@ -202,7 +217,7 @@ def temporal_operation( Callable[ [ NexusServiceType, - TemporalStartOperationContext, + TemporalNexusStartOperationContext, TemporalNexusClient, InputT, ], @@ -212,7 +227,7 @@ def temporal_operation( Callable[ [ NexusServiceType, - TemporalStartOperationContext, + TemporalNexusStartOperationContext, TemporalNexusClient, InputT, ], @@ -220,20 +235,29 @@ def temporal_operation( ], ] ): - """Decorator marking a method as the start method for an operation that interacts with Temporal.""" + """Decorator marking a method as the start method for an operation that interacts with Temporal. + + .. warning:: + This API is experimental and unstable. + """ def decorator( start: Callable[ [ NexusServiceType, - TemporalStartOperationContext, + TemporalNexusStartOperationContext, TemporalNexusClient, InputT, ], Awaitable[TemporalOperationResult[OutputT]], ], ) -> Callable[ - [NexusServiceType, TemporalStartOperationContext, TemporalNexusClient, InputT], + [ + NexusServiceType, + TemporalNexusStartOperationContext, + TemporalNexusClient, + InputT, + ], Awaitable[TemporalOperationResult[OutputT]], ]: ( @@ -245,17 +269,19 @@ def operation_handler_factory( self: NexusServiceType, ) -> OperationHandler[InputT, OutputT]: async def _start( - ctx: StartOperationContext, client: TemporalNexusClient, input: InputT + ctx: TemporalNexusStartOperationContext, + client: TemporalNexusClient, + input: InputT, ) -> TemporalOperationResult[OutputT]: return await start( self, - TemporalStartOperationContext._from_start_operation_context(ctx), + ctx, client, input, ) _start.__doc__ = start.__doc__ - return TemporalNexusOperationHandler(_start) + return _TemporalNexusOperationHandler(_start) method_name = get_callable_name(start) op = nexusrpc.Operation( diff --git a/temporalio/nexus/_operation_context.py b/temporalio/nexus/_operation_context.py index b714b1cb8..6280a1cfb 100644 --- a/temporalio/nexus/_operation_context.py +++ b/temporalio/nexus/_operation_context.py @@ -280,26 +280,6 @@ def _add_outbound_links( return workflow_handle -class TemporalStartOperationContext(StartOperationContext): - """Context received by a Temporal operation.""" - - def __init__(self, *args: Any, **kwargs: Any) -> None: - """Initialize the Temporal operation context.""" - super().__init__(*args, **kwargs) - self._temporal_context = _TemporalStartOperationContext.get() - - @classmethod - def _from_start_operation_context(cls, ctx: StartOperationContext) -> Self: - return cls( - **{f.name: getattr(ctx, f.name) for f in dataclasses.fields(ctx)}, - ) - - @property - def metric_meter(self) -> temporalio.common.MetricMeter: - """The metric meter""" - return self._temporal_context.metric_meter - - class WorkflowRunOperationContext(StartOperationContext): """Context received by a workflow run operation.""" @@ -569,6 +549,64 @@ def set(self) -> None: _temporal_cancel_operation_context.set(self) +class TemporalNexusStartOperationContext(StartOperationContext): + """Start context received by a Temporal operation. + + .. warning:: + This API is experimental and unstable. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize the Temporal operation context.""" + super().__init__(*args, **kwargs) + self._temporal_context = _TemporalStartOperationContext.get() + + @classmethod + def _from_start_operation_context(cls, ctx: StartOperationContext) -> Self: + return cls( + **{f.name: getattr(ctx, f.name) for f in dataclasses.fields(ctx)}, + ) + + @property + def metric_meter(self) -> temporalio.common.MetricMeter: + """The metric meter""" + return self._temporal_context.metric_meter + + @property + def client(self) -> temporalio.client.Client: + """The Temporal client used by the worker handling the current Nexus operation.""" + return self._temporal_context.client + + +class TemporalNexusCancelOperationContext(CancelOperationContext): + """Cancellation context received by a Temporal operation. + + .. warning:: + This API is experimental and unstable. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize the Temporal operation context.""" + super().__init__(*args, **kwargs) + self._temporal_context = _TemporalCancelOperationContext.get() + + @classmethod + def _from_cancel_operation_context(cls, ctx: CancelOperationContext) -> Self: + return cls( + **{f.name: getattr(ctx, f.name) for f in dataclasses.fields(ctx)}, + ) + + @property + def metric_meter(self) -> temporalio.common.MetricMeter: + """The metric meter""" + return self._temporal_context.metric_meter + + @property + def client(self) -> temporalio.client.Client: + """The Temporal client used by the worker handling the current Nexus operation.""" + return self._temporal_context.client + + class LoggerAdapter(logging.LoggerAdapter): """Logger adapter that adds Nexus operation context information.""" diff --git a/temporalio/nexus/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py index ac0642b14..70ec2fee2 100644 --- a/temporalio/nexus/_operation_handlers.py +++ b/temporalio/nexus/_operation_handlers.py @@ -1,5 +1,6 @@ from __future__ import annotations +from abc import ABC, abstractmethod from collections.abc import Awaitable, Callable from typing import Any @@ -16,8 +17,11 @@ StartOperationResultAsync, StartOperationResultSync, ) +from typing_extensions import override from temporalio.nexus._operation_context import ( + TemporalNexusCancelOperationContext, + TemporalNexusStartOperationContext, _temporal_cancel_operation_context, _TemporalCancelOperationContext, ) @@ -118,54 +122,105 @@ async def _cancel_workflow( await client_workflow_handle.cancel(**kwargs) -class TemporalNexusOperationHandler(OperationHandler[InputT, OutputT]): - """Operation handler for Nexus operations that interact with Temporal.""" +class TemporalNexusOperationHandler(OperationHandler[InputT, OutputT], ABC): + """Operation handler for Nexus operations that interact with Temporal. + Implementations override the start_operation method. - def __init__( + .. warning:: + This API is experimental and unstable. + """ + + @abstractmethod + async def start_operation( self, - start: Callable[ - [StartOperationContext, TemporalNexusClient, InputT], - Awaitable[TemporalOperationResult[OutputT]], - ], - ) -> None: - """Initialize the Temporal operation handler.""" - if not is_async_callable(start): - raise RuntimeError( - f"{start} is not an `async def` method. " - "TemporalNexusOperationHandler must be initialized with an " - "`async def` start method." - ) - self._start = start - if start.__doc__: - if start_func := getattr(self.start, "__func__", None): - start_func.__doc__ = start.__doc__ + ctx: TemporalNexusStartOperationContext, + client: TemporalNexusClient, + input: InputT, + ) -> TemporalOperationResult[OutputT]: + """Start the Temporal-backed Nexus operation.""" + ... async def start( self, ctx: StartOperationContext, input: InputT ) -> StartOperationResultSync[OutputT] | StartOperationResultAsync: - """Start the Nexus operation using a Nexus-aware Temporal client.""" + """Start the Nexus operation using a Nexus-aware Temporal client. + + .. warning:: + This API is experimental and unstable. + """ nexus_client = TemporalNexusClient() - result = await self._start(ctx, nexus_client, input) + temporal_ctx = TemporalNexusStartOperationContext._from_start_operation_context( + ctx + ) + result = await self.start_operation(temporal_ctx, nexus_client, input) return result._to_nexus_result() async def cancel(self, ctx: CancelOperationContext, token: str) -> None: - """Cancel a Nexus operation using its operation token.""" - temporal_context = _TemporalCancelOperationContext.get() - client = temporal_context.client + """Cancel a Nexus operation using its operation token. + + .. warning:: + This API is experimental and unstable. + """ + cancel_ctx = TemporalNexusCancelOperationContext._from_cancel_operation_context( + ctx + ) operation_token = OperationToken.decode(token) - if client.namespace != operation_token.namespace: + if cancel_ctx.client.namespace != operation_token.namespace: raise ValueError( - f"Client namespace {client.namespace} does not match " + f"Client namespace {cancel_ctx.client.namespace} does not match " f"operation token namespace {operation_token.namespace}" ) match operation_token.type: case OperationTokenType.WORKFLOW: - await self.cancel_workflow_run(ctx, operation_token.workflow_id) + await self.cancel_workflow_run(cancel_ctx, operation_token.workflow_id) + + async def cancel_workflow_run( + self, _ctx: TemporalNexusCancelOperationContext, workflow_id: str + ): + """Cancels the workflow identified by workflow_id. - async def cancel_workflow_run(self, _ctx: CancelOperationContext, workflow_id: str): - """Cancels the workflow identified by workflow_id""" + .. warning:: + This API is experimental and unstable. + """ temporal_context = _TemporalCancelOperationContext.get() workflow_handle = temporal_context.client.get_workflow_handle(workflow_id) await workflow_handle.cancel() + + +class _TemporalNexusOperationHandler(TemporalNexusOperationHandler[InputT, OutputT]): # pyright: ignore[reportUnusedClass] + """Default implementation of TemporalNexusHandler that uses the provided callable + to start the Temporal operation. + + .. warning:: + This API is experimental and unstable. + """ + + def __init__( + self, + start: Callable[ + [TemporalNexusStartOperationContext, TemporalNexusClient, InputT], + Awaitable[TemporalOperationResult[OutputT]], + ], + ) -> None: + """Initialize the Temporal operation handler.""" + if not is_async_callable(start): + raise RuntimeError( + f"{start} is not an `async def` method. " + "TemporalNexusOperationHandler must be initialized with an " + "`async def` start method." + ) + self._start = start + if start.__doc__: + if start_func := getattr(self.start, "__func__", None): + start_func.__doc__ = start.__doc__ + + @override + async def start_operation( + self, + ctx: TemporalNexusStartOperationContext, + client: TemporalNexusClient, + input: InputT, + ) -> TemporalOperationResult[OutputT]: + return await self._start(ctx, client, input) diff --git a/temporalio/nexus/_util.py b/temporalio/nexus/_util.py index ef7ccf78f..737ff1684 100644 --- a/temporalio/nexus/_util.py +++ b/temporalio/nexus/_util.py @@ -16,7 +16,7 @@ ) from temporalio.nexus._operation_context import ( - TemporalStartOperationContext, + TemporalNexusStartOperationContext, WorkflowRunOperationContext, ) from temporalio.nexus._temporal_client import ( @@ -53,7 +53,12 @@ def get_workflow_run_start_method_input_and_output_type_annotations( def get_temporal_operation_start_method_input_and_output_type_annotations( start: Callable[ - [NexusServiceType, TemporalStartOperationContext, TemporalNexusClient, InputT], + [ + NexusServiceType, + TemporalNexusStartOperationContext, + TemporalNexusClient, + InputT, + ], Awaitable[TemporalOperationResult[OutputT]], ], ) -> tuple[ @@ -67,7 +72,7 @@ def get_temporal_operation_start_method_input_and_output_type_annotations( """ return _get_wrapped_start_method_input_and_output_type_annotations( start, - expected_param_types=(TemporalStartOperationContext, TemporalNexusClient), + expected_param_types=(TemporalNexusStartOperationContext, TemporalNexusClient), expected_return_origin=TemporalOperationResult, ) diff --git a/temporalio/workflow/_nexus.py b/temporalio/workflow/_nexus.py index 29bd10715..b8c8e88a1 100644 --- a/temporalio/workflow/_nexus.py +++ b/temporalio/workflow/_nexus.py @@ -218,7 +218,7 @@ async def start_operation( operation: Callable[ [ NexusServiceType, - temporalio.nexus.TemporalStartOperationContext, + temporalio.nexus.TemporalNexusStartOperationContext, temporalio.nexus.TemporalNexusClient, InputT, ], @@ -390,7 +390,7 @@ async def execute_operation( operation: Callable[ [ NexusServiceType, - temporalio.nexus.TemporalStartOperationContext, + temporalio.nexus.TemporalNexusStartOperationContext, temporalio.nexus.TemporalNexusClient, InputT, ], diff --git a/tests/nexus/test_handler_operation_definitions.py b/tests/nexus/test_handler_operation_definitions.py index f47e3f47f..4a6e644b9 100644 --- a/tests/nexus/test_handler_operation_definitions.py +++ b/tests/nexus/test_handler_operation_definitions.py @@ -112,10 +112,10 @@ def test_unsafe_narrow_context_annotations_warn_and_drop_input_type(): with pytest.warns( UserWarning, - match="Expected parameter 1 .* TemporalStartOperationContext", + match="Expected parameter 1 .* TemporalNexusStartOperationContext", ): - class MyTemporalOpCtx(nexus.TemporalStartOperationContext): + class MyTemporalOpCtx(nexus.TemporalNexusStartOperationContext): def custom_method(self): raise NotImplementedError diff --git a/tests/nexus/test_nexus_type_errors.py b/tests/nexus/test_nexus_type_errors.py index 597d6e040..215c01f09 100644 --- a/tests/nexus/test_nexus_type_errors.py +++ b/tests/nexus/test_nexus_type_errors.py @@ -99,7 +99,7 @@ async def my_workflow_run_operation( @temporalio.nexus.temporal_operation async def my_temporal_operation( self, - _ctx: temporalio.nexus.TemporalStartOperationContext, + _ctx: temporalio.nexus.TemporalNexusStartOperationContext, client: temporalio.nexus.TemporalNexusClient, input: int, ) -> temporalio.nexus.TemporalOperationResult[None]: @@ -179,7 +179,7 @@ async def my_workflow_run_operation( @temporalio.nexus.temporal_operation async def my_temporal_operation( self, - _ctx: temporalio.nexus.TemporalStartOperationContext, + _ctx: temporalio.nexus.TemporalNexusStartOperationContext, _client: temporalio.nexus.TemporalNexusClient, _input: int, ) -> temporalio.nexus.TemporalOperationResult[None]: @@ -203,7 +203,7 @@ async def my_workflow_run_operation( @temporalio.nexus.temporal_operation async def my_temporal_operation( self, - _ctx: temporalio.nexus.TemporalStartOperationContext, + _ctx: temporalio.nexus.TemporalNexusStartOperationContext, _client: temporalio.nexus.TemporalNexusClient, _input: int, ) -> temporalio.nexus.TemporalOperationResult[None]: @@ -214,7 +214,7 @@ class MyUnsafeContextAnnotationServiceHandler: # A temporal operation receives TemporalStartOperationContext at runtime, so # requiring an arbitrary user subclass is not safe. class MyCustomTemporalStartOperationContext( - temporalio.nexus.TemporalStartOperationContext + temporalio.nexus.TemporalNexusStartOperationContext ): def custom_state(self) -> str: raise NotImplementedError diff --git a/tests/nexus/test_temporal_operation.py b/tests/nexus/test_temporal_operation.py index 8ac79b28b..2781785f6 100644 --- a/tests/nexus/test_temporal_operation.py +++ b/tests/nexus/test_temporal_operation.py @@ -5,17 +5,16 @@ import nexusrpc import pytest from nexusrpc import HandlerErrorType, Operation, service -from nexusrpc.handler import ( - service_handler, -) +from nexusrpc.handler import operation_handler, service_handler +from typing_extensions import override import temporalio.exceptions from temporalio import nexus, workflow -from temporalio.client import Client, WorkflowFailureError -from temporalio.common import WorkflowIDConflictPolicy +from temporalio.client import Client, WorkflowExecutionStatus, WorkflowFailureError +from temporalio.common import NexusOperationExecutionStatus, WorkflowIDConflictPolicy from temporalio.testing import WorkflowEnvironment from temporalio.worker import Worker -from tests.helpers import EventType, assert_event_subsequence +from tests.helpers import EventType, assert_event_subsequence, assert_eventually from tests.helpers.nexus import make_nexus_endpoint_name @@ -40,14 +39,21 @@ class TestService: concurrent_start: Operation[Input, str] retry_after_failed_start: Operation[Input, str] sync_result: Operation[Input, str] + custom_cancel: Operation[str, None] @service_handler(service=TestService) -class EchoServiceHandler: +class TestServiceHandler: + # tell Pytest this is not a test class + __test__ = False + + def __init__(self) -> None: + self.started_custom_cancel_workflow = asyncio.Event() + @nexus.temporal_operation async def echo( self, - _ctx: nexus.TemporalStartOperationContext, + _ctx: nexus.TemporalNexusStartOperationContext, client: nexus.TemporalNexusClient, input: Input, ) -> nexus.TemporalOperationResult[str]: @@ -58,7 +64,7 @@ async def echo( @nexus.temporal_operation async def blocking( self, - _ctx: nexus.TemporalStartOperationContext, + _ctx: nexus.TemporalNexusStartOperationContext, client: nexus.TemporalNexusClient, _input: None, ) -> nexus.TemporalOperationResult[None]: @@ -69,7 +75,7 @@ async def blocking( @nexus.temporal_operation async def double_start( self, - _ctx: nexus.TemporalStartOperationContext, + _ctx: nexus.TemporalNexusStartOperationContext, client: nexus.TemporalNexusClient, input: Input, ) -> nexus.TemporalOperationResult[None]: @@ -84,7 +90,7 @@ async def double_start( @nexus.temporal_operation async def concurrent_start( self, - _ctx: nexus.TemporalStartOperationContext, + _ctx: nexus.TemporalNexusStartOperationContext, client: nexus.TemporalNexusClient, input: Input, ) -> nexus.TemporalOperationResult[str]: @@ -129,7 +135,7 @@ async def concurrent_start( @nexus.temporal_operation async def retry_after_failed_start( self, - _ctx: nexus.TemporalStartOperationContext, + _ctx: nexus.TemporalNexusStartOperationContext, client: nexus.TemporalNexusClient, input: Input, ) -> nexus.TemporalOperationResult[str]: @@ -151,12 +157,42 @@ async def retry_after_failed_start( @nexus.temporal_operation async def sync_result( self, - _ctx: nexus.TemporalStartOperationContext, + _ctx: nexus.TemporalNexusStartOperationContext, _client: nexus.TemporalNexusClient, input: Input, ) -> nexus.TemporalOperationResult[str]: return nexus.TemporalOperationResult.sync(input.value) + @operation_handler + def custom_cancel(self) -> nexus.TemporalNexusOperationHandler[str, None]: + event = self.started_custom_cancel_workflow + + class CustomCancelNexusOpHandler( + nexus.TemporalNexusOperationHandler[str, None] + ): + @override + async def start_operation( + self, + ctx: nexus.TemporalNexusStartOperationContext, + client: nexus.TemporalNexusClient, + input: str, + ) -> nexus.TemporalOperationResult[None]: + result = await client.start_workflow(BlockingWorkflow.run, id=input) + event.set() + return result + + @override + async def cancel_workflow_run( + self, ctx: nexus.TemporalNexusCancelOperationContext, workflow_id: str + ): + # get a handle to the workflow + handle = ctx.client.get_workflow_handle(workflow_id) + + # cancel the workflow + await handle.cancel() + + return CustomCancelNexusOpHandler() + @workflow.defn class EchoWorkflowCaller: @@ -177,7 +213,7 @@ async def test_temporal_operation_start_workflow( async with Worker( env.client, task_queue=task_queue, - nexus_service_handlers=[EchoServiceHandler()], + nexus_service_handlers=[TestServiceHandler()], workflows=[EchoWorkflow, EchoWorkflowCaller], ): wf_handle = await client.start_workflow( @@ -239,7 +275,7 @@ async def test_temporal_operation_cancel_workflow( async with Worker( env.client, task_queue=task_queue, - nexus_service_handlers=[EchoServiceHandler()], + nexus_service_handlers=[TestServiceHandler()], workflows=[BlockingWorkflow, CancelBlockingWorkflowCaller], ): wf_handle = await client.start_workflow( @@ -265,6 +301,41 @@ async def test_temporal_operation_cancel_workflow( ) +async def test_customized_temporal_operation_cancel_workflow( + client: Client, env: WorkflowEnvironment +): + task_queue = str(uuid.uuid4()) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) + + service_handler = TestServiceHandler() + async with Worker( + env.client, + task_queue=task_queue, + nexus_service_handlers=[service_handler], + workflows=[BlockingWorkflow, CancelBlockingWorkflowCaller], + ): + nexus_client = client.create_nexus_client(TestService, endpoint_name) + + wf_id = f"custom-cancel-{uuid.uuid4()}" + op_handle = await nexus_client.start_operation( + TestService.custom_cancel, wf_id, id=str(uuid.uuid4()) + ) + + await service_handler.started_custom_cancel_workflow.wait() + + await op_handle.cancel() + + async def check_cancelled(): + wf_handle = client.get_workflow_handle(wf_id) + wf_desc = await wf_handle.describe() + assert wf_desc.status is WorkflowExecutionStatus.CANCELED + op_desc = await op_handle.describe() + assert op_desc.status is NexusOperationExecutionStatus.CANCELED + + await assert_eventually(check_cancelled) + + @workflow.defn class DoubleStartWorkflowCaller: @workflow.run @@ -308,7 +379,7 @@ async def test_temporal_operation_double_start_raises_handler_err( async with Worker( env.client, task_queue=task_queue, - nexus_service_handlers=[EchoServiceHandler()], + nexus_service_handlers=[TestServiceHandler()], workflows=[EchoWorkflow, DoubleStartWorkflowCaller], ): with pytest.raises(WorkflowFailureError) as err: @@ -337,7 +408,7 @@ async def test_temporal_operation_concurrent_start_raises_handler_err( async with Worker( env.client, task_queue=task_queue, - nexus_service_handlers=[EchoServiceHandler()], + nexus_service_handlers=[TestServiceHandler()], workflows=[EchoWorkflow, ConcurrentStartWorkflowCaller], ): result = await client.execute_workflow( @@ -360,7 +431,7 @@ async def test_temporal_operation_failed_start_allows_retry( async with Worker( env.client, task_queue=task_queue, - nexus_service_handlers=[EchoServiceHandler()], + nexus_service_handlers=[TestServiceHandler()], workflows=[ BlockingWorkflow, EchoWorkflow, @@ -402,7 +473,7 @@ async def test_temporal_operation_sync_result(client: Client, env: WorkflowEnvir async with Worker( env.client, task_queue=task_queue, - nexus_service_handlers=[EchoServiceHandler()], + nexus_service_handlers=[TestServiceHandler()], workflows=[SyncResultCaller], ): wf_handle = await client.start_workflow( @@ -450,7 +521,7 @@ class TemporalOperationOverloadTestServiceHandler: @nexus.temporal_operation async def no_param( self, - _ctx: nexus.TemporalStartOperationContext, + _ctx: nexus.TemporalNexusStartOperationContext, client: nexus.TemporalNexusClient, _input: TemporalOperationOverloadTestValue, ) -> nexus.TemporalOperationResult[TemporalOperationOverloadTestValue]: @@ -462,7 +533,7 @@ async def no_param( @nexus.temporal_operation async def single_param( self, - _ctx: nexus.TemporalStartOperationContext, + _ctx: nexus.TemporalNexusStartOperationContext, client: nexus.TemporalNexusClient, input: TemporalOperationOverloadTestValue, ) -> nexus.TemporalOperationResult[TemporalOperationOverloadTestValue]: @@ -475,7 +546,7 @@ async def single_param( @nexus.temporal_operation async def multi_param( self, - _ctx: nexus.TemporalStartOperationContext, + _ctx: nexus.TemporalNexusStartOperationContext, client: nexus.TemporalNexusClient, input: TemporalOperationOverloadTestValue, ) -> nexus.TemporalOperationResult[TemporalOperationOverloadTestValue]: @@ -488,7 +559,7 @@ async def multi_param( @nexus.temporal_operation async def by_name( self, - _ctx: nexus.TemporalStartOperationContext, + _ctx: nexus.TemporalNexusStartOperationContext, client: nexus.TemporalNexusClient, input: TemporalOperationOverloadTestValue, ) -> nexus.TemporalOperationResult[TemporalOperationOverloadTestValue]: @@ -502,7 +573,7 @@ async def by_name( @nexus.temporal_operation async def by_name_multi_param( self, - _ctx: nexus.TemporalStartOperationContext, + _ctx: nexus.TemporalNexusStartOperationContext, client: nexus.TemporalNexusClient, input: TemporalOperationOverloadTestValue, ) -> nexus.TemporalOperationResult[TemporalOperationOverloadTestValue]: From e8905f380df58c8e49e7a1bf3953b0b0c3835660 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Fri, 22 May 2026 16:06:52 -0700 Subject: [PATCH 6/7] Add time-skippping check in test that leverages sano --- tests/nexus/test_temporal_operation.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/nexus/test_temporal_operation.py b/tests/nexus/test_temporal_operation.py index 2781785f6..b93c38677 100644 --- a/tests/nexus/test_temporal_operation.py +++ b/tests/nexus/test_temporal_operation.py @@ -304,6 +304,11 @@ async def test_temporal_operation_cancel_workflow( async def test_customized_temporal_operation_cancel_workflow( client: Client, env: WorkflowEnvironment ): + if env.supports_time_skipping: + pytest.skip( + "Standalone Nexus Operation tests don't work with time-skipping server" + ) + task_queue = str(uuid.uuid4()) endpoint_name = make_nexus_endpoint_name(task_queue) await env.create_nexus_endpoint(endpoint_name, task_queue) From 16f2af101f2ff288ed6c004f47d714071f9307c1 Mon Sep 17 00:00:00 2001 From: Alex Mazzeo Date: Fri, 22 May 2026 16:37:35 -0700 Subject: [PATCH 7/7] Address claude/codex review suggestions. Make TemporalNexusClient ABC to prevent users from instantiating. --- temporalio/nexus/_operation_handlers.py | 23 +++-- temporalio/nexus/_temporal_client.py | 117 +++++++++++++++++++----- tests/nexus/test_temporal_operation.py | 3 +- 3 files changed, 109 insertions(+), 34 deletions(-) diff --git a/temporalio/nexus/_operation_handlers.py b/temporalio/nexus/_operation_handlers.py index 70ec2fee2..3bad90874 100644 --- a/temporalio/nexus/_operation_handlers.py +++ b/temporalio/nexus/_operation_handlers.py @@ -23,11 +23,11 @@ TemporalNexusCancelOperationContext, TemporalNexusStartOperationContext, _temporal_cancel_operation_context, - _TemporalCancelOperationContext, ) from temporalio.nexus._temporal_client import ( TemporalNexusClient, TemporalOperationResult, + _TemporalNexusClient, ) from temporalio.nexus._token import OperationToken, OperationTokenType, WorkflowHandle @@ -148,7 +148,7 @@ async def start( .. warning:: This API is experimental and unstable. """ - nexus_client = TemporalNexusClient() + nexus_client = _TemporalNexusClient() temporal_ctx = TemporalNexusStartOperationContext._from_start_operation_context( ctx ) @@ -165,11 +165,19 @@ async def cancel(self, ctx: CancelOperationContext, token: str) -> None: ctx ) - operation_token = OperationToken.decode(token) + try: + operation_token = OperationToken.decode(token) + except Exception as err: + raise HandlerError( + "Unable to decode operation token to cancel", + type=HandlerErrorType.INTERNAL, + ) from err + if cancel_ctx.client.namespace != operation_token.namespace: - raise ValueError( + raise HandlerError( f"Client namespace {cancel_ctx.client.namespace} does not match " - f"operation token namespace {operation_token.namespace}" + f"operation token namespace {operation_token.namespace}", + type=HandlerErrorType.BAD_REQUEST, ) match operation_token.type: @@ -177,15 +185,14 @@ async def cancel(self, ctx: CancelOperationContext, token: str) -> None: await self.cancel_workflow_run(cancel_ctx, operation_token.workflow_id) async def cancel_workflow_run( - self, _ctx: TemporalNexusCancelOperationContext, workflow_id: str + self, ctx: TemporalNexusCancelOperationContext, workflow_id: str ): """Cancels the workflow identified by workflow_id. .. warning:: This API is experimental and unstable. """ - temporal_context = _TemporalCancelOperationContext.get() - workflow_handle = temporal_context.client.get_workflow_handle(workflow_id) + workflow_handle = ctx.client.get_workflow_handle(workflow_id) await workflow_handle.cancel() diff --git a/temporalio/nexus/_temporal_client.py b/temporalio/nexus/_temporal_client.py index 372d779ca..0505d2ad4 100644 --- a/temporalio/nexus/_temporal_client.py +++ b/temporalio/nexus/_temporal_client.py @@ -1,5 +1,6 @@ from __future__ import annotations +from abc import ABC, abstractmethod from collections.abc import Awaitable, Callable, Iterator, Mapping, Sequence from contextlib import contextmanager from dataclasses import dataclass @@ -41,13 +42,17 @@ @dataclass(frozen=True) class TemporalOperationResult(Generic[_ResultT]): - """Unified result: sync value or async token.""" + """Unified result: sync value or async token. + + .. warning:: + This API is experimental and unstable. + """ value: _ResultT | object = temporalio.common._arg_unset token: str | None = None @classmethod - def sync(cls, value: _ResultT) -> "TemporalOperationResult[_ResultT]": + def sync(cls, value: _ResultT) -> Self: """Create a result that completes the Nexus operation synchronously.""" return cls(value=value) @@ -69,34 +74,22 @@ def _to_nexus_result( ) -class TemporalNexusClient: - """Nexus-aware wrapper around a Temporal Client.""" +class TemporalNexusClient(ABC): + """Nexus-aware wrapper around a Temporal Client. - def __init__(self) -> None: - """Initialize the client wrapper from the active Nexus operation context.""" - self._temporal_context = _TemporalStartOperationContext.get() - self._started_async = False + .. warning:: + This API is experimental and unstable. + """ @property + @abstractmethod def client(self) -> temporalio.client.Client: - """Return the Temporal client for the active Nexus operation.""" - return self._temporal_context.client - - @contextmanager - def _reserve_async_start(self) -> Iterator[None]: - if self._started_async: - raise HandlerError( - "Only one async operation can be started per operation handler invocation. Use TemporalNexusClient.client for additional workflow interactions", - type=HandlerErrorType.BAD_REQUEST, - ) + """The underlying Temporal Client - # Reserve the started flag before sending to prevent concurrent starts - self._started_async = True - try: - yield - except BaseException: - self._started_async = False - raise + .. warning:: + This API is experimental and unstable. + """ + ... # Overload for no-param workflow @overload @@ -233,6 +226,80 @@ async def start_workflow( versioning_override: temporalio.common.VersioningOverride | None = None, ) -> TemporalOperationResult[ReturnType]: ... + @abstractmethod + async def start_workflow( + self, + workflow: str | Callable[..., Awaitable[ReturnType]], + arg: Any = temporalio.common._arg_unset, + *, + args: Sequence[Any] = [], + id: str, + task_queue: str | None = None, + result_type: type | None = None, + execution_timeout: timedelta | None = None, + run_timeout: timedelta | None = None, + task_timeout: timedelta | None = None, + id_reuse_policy: temporalio.common.WorkflowIDReusePolicy = temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.WorkflowIDConflictPolicy = temporalio.common.WorkflowIDConflictPolicy.UNSPECIFIED, + retry_policy: temporalio.common.RetryPolicy | None = None, + cron_schedule: str = "", + memo: Mapping[str, Any] | None = None, + search_attributes: None + | ( + temporalio.common.TypedSearchAttributes | temporalio.common.SearchAttributes + ) = None, + static_summary: str | None = None, + static_details: str | None = None, + start_delay: timedelta | None = None, + start_signal: str | None = None, + start_signal_args: Sequence[Any] = [], + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + request_eager_start: bool = False, + priority: temporalio.common.Priority = temporalio.common.Priority.default, + versioning_override: temporalio.common.VersioningOverride | None = None, + ) -> TemporalOperationResult[ReturnType]: + """Start a workflow as the backing asynchronous Nexus operation. + + .. warning:: + This API is experimental and unstable. + """ + ... + + +class _TemporalNexusClient(TemporalNexusClient): # pyright: ignore[reportUnusedClass] + """Nexus-aware wrapper around a Temporal Client. + + .. warning:: + This API is experimental and unstable. + """ + + def __init__(self) -> None: + """Initialize the client wrapper from the active Nexus operation context.""" + self._temporal_context = _TemporalStartOperationContext.get() + self._started_async = False + + @property + def client(self) -> temporalio.client.Client: + """Return the Temporal client for the active Nexus operation.""" + return self._temporal_context.client + + @contextmanager + def _reserve_async_start(self) -> Iterator[None]: + if self._started_async: + raise HandlerError( + "Only one async operation can be started per operation handler invocation. Use TemporalNexusClient.client for additional workflow interactions", + type=HandlerErrorType.BAD_REQUEST, + ) + + # Reserve the started flag before sending to prevent concurrent starts + self._started_async = True + try: + yield + except BaseException: + self._started_async = False + raise + async def start_workflow( self, workflow: str | Callable[..., Awaitable[ReturnType]], diff --git a/tests/nexus/test_temporal_operation.py b/tests/nexus/test_temporal_operation.py index b93c38677..57a3b58d4 100644 --- a/tests/nexus/test_temporal_operation.py +++ b/tests/nexus/test_temporal_operation.py @@ -237,7 +237,8 @@ async def test_temporal_operation_start_workflow( @workflow.defn class BlockingWorkflow: - done: bool = False + def __init__(self) -> None: + self.done: bool = False @workflow.run async def run(self) -> None: