diff --git a/temporalio/client/_nexus.py b/temporalio/client/_nexus.py index 991ab34a3..060235e01 100644 --- a/temporalio/client/_nexus.py +++ b/temporalio/client/_nexus.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Awaitable, Callable, Mapping, Sequence from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Generic, cast, overload @@ -473,7 +473,7 @@ class NexusClient(ABC, Generic[NexusServiceType]): Use :py:meth:`Client.create_nexus_client` to create a client. """ - # Overload for nexusrpc.Operation with input + # Overload for nexusrpc.Operation @overload @abstractmethod async def start_operation( @@ -494,18 +494,18 @@ async def start_operation( rpc_timeout: timedelta | None = None, ) -> NexusOperationHandle[OutputT]: ... - # Overload for Callable with result_type + # Overload for string operation name @overload @abstractmethod async def start_operation( self, - operation: Callable[..., Any], + operation: str, arg: Any, *, id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, - result_type: type[OutputT], + result_type: type[OutputT] | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, start_to_close_timeout: timedelta | None = None, @@ -516,13 +516,16 @@ async def start_operation( rpc_timeout: timedelta | None = None, ) -> NexusOperationHandle[OutputT]: ... - # Overload for Callable without result_type + # Overload for workflow_run_operation methods @overload @abstractmethod async def start_operation( self, - operation: Callable[..., Any], - arg: Any, + operation: Callable[ + [NexusServiceType, temporalio.nexus.WorkflowRunOperationContext, InputT], + Awaitable[temporalio.nexus.WorkflowHandle[OutputT]], + ], + arg: InputT, *, id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, @@ -535,20 +538,22 @@ async def start_operation( headers: Mapping[str, str] | None = None, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, - ) -> NexusOperationHandle[Any]: ... + ) -> NexusOperationHandle[OutputT]: ... - # Overload for str with result_type + # Overload for sync_operation methods (async def) @overload @abstractmethod async def start_operation( self, - operation: str, - arg: Any, + operation: Callable[ + [NexusServiceType, nexusrpc.handler.StartOperationContext, InputT], + Awaitable[OutputT], + ], + arg: InputT, *, id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, - result_type: type[OutputT], schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, start_to_close_timeout: timedelta | None = None, @@ -559,13 +564,16 @@ async def start_operation( rpc_timeout: timedelta | None = None, ) -> NexusOperationHandle[OutputT]: ... - # Overload for str without result_type + # Overload for sync_operation methods (def) @overload @abstractmethod async def start_operation( self, - operation: str, - arg: Any, + operation: Callable[ + [NexusServiceType, nexusrpc.handler.StartOperationContext, InputT], + OutputT, + ], + arg: InputT, *, id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, @@ -578,7 +586,30 @@ async def start_operation( headers: Mapping[str, str] | None = None, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, - ) -> NexusOperationHandle[Any]: ... + ) -> NexusOperationHandle[OutputT]: ... + + # Overload for operation_handler + @overload + @abstractmethod + async def start_operation( + self, + operation: Callable[ + [NexusServiceType], nexusrpc.handler.OperationHandler[InputT, OutputT] + ], + arg: InputT, + *, + id: str, + id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + headers: Mapping[str, str] | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> NexusOperationHandle[OutputT]: ... @abstractmethod async def start_operation( @@ -611,7 +642,8 @@ async def start_operation( id: Unique identifier for this operation. id_reuse_policy: Policy for reusing operation IDs. id_conflict_policy: Policy for handling ID conflicts. - result_type: The result type to deserialize into. + result_type: For string operation names, this can set the specific + result type hint to deserialize into. schedule_to_close_timeout: End-to-end timeout for the Nexus operation. If unset, defaults to the maximum allowed by the Temporal server. @@ -633,7 +665,7 @@ async def start_operation( """ ... - # Overload for nexusrpc.Operation with input + # Overload for nexusrpc.Operation @overload @abstractmethod async def execute_operation( @@ -654,18 +686,18 @@ async def execute_operation( rpc_timeout: timedelta | None = None, ) -> OutputT: ... - # Overload for Callable with result_type + # Overload for string operation name @overload @abstractmethod async def execute_operation( self, - operation: Callable[..., Any], + operation: str, arg: Any, *, id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, - result_type: type[OutputT], + result_type: type[OutputT] | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, start_to_close_timeout: timedelta | None = None, @@ -676,13 +708,16 @@ async def execute_operation( rpc_timeout: timedelta | None = None, ) -> OutputT: ... - # Overload for Callable without result_type + # Overload for workflow_run_operation methods @overload @abstractmethod async def execute_operation( self, - operation: Callable[..., Any], - arg: Any, + operation: Callable[ + [NexusServiceType, temporalio.nexus.WorkflowRunOperationContext, InputT], + Awaitable[temporalio.nexus.WorkflowHandle[OutputT]], + ], + arg: InputT, *, id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, @@ -695,20 +730,22 @@ async def execute_operation( headers: Mapping[str, str] | None = None, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, - ) -> Any: ... + ) -> OutputT: ... - # Overload for str with result_type + # Overload for sync_operation methods (async def) @overload @abstractmethod async def execute_operation( self, - operation: str, - arg: Any, + operation: Callable[ + [NexusServiceType, nexusrpc.handler.StartOperationContext, InputT], + Awaitable[OutputT], + ], + arg: InputT, *, id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, - result_type: type[OutputT], schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, start_to_close_timeout: timedelta | None = None, @@ -719,13 +756,40 @@ async def execute_operation( rpc_timeout: timedelta | None = None, ) -> OutputT: ... - # Overload for str without result_type + # Overload for sync_operation methods (async def) @overload @abstractmethod async def execute_operation( self, - operation: str, - arg: Any, + operation: Callable[ + [NexusServiceType, nexusrpc.handler.StartOperationContext, InputT], + OutputT, + ], + arg: InputT, + *, + id: str, + id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, + id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + search_attributes: temporalio.common.TypedSearchAttributes | None = None, + summary: str | None = None, + headers: Mapping[str, str] | None = None, + rpc_metadata: Mapping[str, str | bytes] = {}, + rpc_timeout: timedelta | None = None, + ) -> OutputT: ... + + # Overload for operation_handler + @overload + @abstractmethod + async def execute_operation( + self, + operation: Callable[ + [NexusServiceType], + nexusrpc.handler.OperationHandler[InputT, OutputT], + ], + arg: InputT, *, id: str, id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE, @@ -738,7 +802,7 @@ async def execute_operation( headers: Mapping[str, str] | None = None, rpc_metadata: Mapping[str, str | bytes] = {}, rpc_timeout: timedelta | None = None, - ) -> Any: ... + ) -> OutputT: ... @abstractmethod async def execute_operation( @@ -773,7 +837,8 @@ async def execute_operation( id: Unique identifier for this operation. id_reuse_policy: Policy for reusing operation IDs. id_conflict_policy: Policy for handling ID conflicts. - result_type: The result type to deserialize into. + result_type: For string operation names, this can set the specific + result type hint to deserialize into. schedule_to_close_timeout: End-to-end timeout for the Nexus operation. If unset, defaults to the maximum allowed by the Temporal server. @@ -860,7 +925,9 @@ async def start_operation( This API is experimental and unstable. """ op_name, output_type = self._resolve_operation(operation) - final_result_type: type | None = result_type or output_type + final_result_type: type | None = ( + result_type if isinstance(operation, str) else output_type + ) return await self._client._impl.start_nexus_operation( StartNexusOperationInput( diff --git a/tests/nexus/test_nexus_type_errors.py b/tests/nexus/test_nexus_type_errors.py index c669f8a5b..f97aeae42 100644 --- a/tests/nexus/test_nexus_type_errors.py +++ b/tests/nexus/test_nexus_type_errors.py @@ -238,14 +238,13 @@ async def standalone_operation_type_tests(): start_to_close_timeout=timedelta(seconds=2), ) - # result_type overrides output type from operation definition - # conflicting result_type and annotation on variable cause type error - # assert-type-error-pyright: 'Type "str" is not assignable to declared type "MyOutput"' - _bad_result_type_output: MyOutput = await nexus_client.execute_operation( # type: ignore - MyServiceHandler.my_sync_operation, + # result_type is not allowed when an operation is provided + await nexus_client.execute_operation( + # assert-type-error-pyright: 'cannot be assigned to parameter "operation" of type "str"' + MyService.my_sync_operation, # type: ignore MyInput(), id="op-1", - result_type=str, # type: ignore + result_type=str, ) # string operation name and result_type infers output type @@ -337,19 +336,14 @@ async def standalone_operation_type_tests(): ) _defn_handle_output: MyOutput = await _defn_handle.result() - # result_type overrides output type from operation definition - # conflicting result_type and annotation on variable cause type error - _result_type_handle: NexusOperationHandle[ - MyOutput - # assert-type-error-pyright: 'Type "NexusOperationHandle\[str\]" is not assignable to declared type "NexusOperationHandle\[MyOutput\]"' - ] = await nexus_client.start_operation( # type: ignore - MyServiceHandler.my_sync_operation, + # result_type is not allowed when an operation is provided + await nexus_client.start_operation( + # assert-type-error-pyright: 'cannot be assigned to parameter "operation" of type "str"' + MyServiceHandler.my_sync_operation, # type: ignore MyInput(), id="op-1", - result_type=str, # type: ignore + result_type=str, ) - # handle still respects type declaration on the variable - _result_type_handle_output: MyOutput = await _result_type_handle.result() # starting with string operation name and result_type infers output type on the handle # and result from the handle @@ -389,7 +383,7 @@ async def standalone_operation_type_tests(): ) ) - # mismatched types on get_nexus_operation_handle produces type error + # mismatched types on get_nexus_operation_handle produce a type error # assert-type-error-pyright: 'Type "NexusOperationHandle\[str\]" is not assignable to declared type "NexusOperationHandle\[MyOutput\]"' _mismatch_handle: NexusOperationHandle[MyOutput] = ( client.get_nexus_operation_handle( # type: ignore @@ -397,3 +391,22 @@ async def standalone_operation_type_tests(): result_type=str, # type: ignore ) ) + + # functions with invalid signatures produce a type error + class InvalidServiceHandler: + async def invalid(self, _ctx: str, _input: str) -> str: + raise NotImplementedError() + + # assert-type-error-pyright: 'No overloads for "start_operation" match' + _invalid_handle: NexusOperationHandle[str] = await nexus_client.start_operation( + InvalidServiceHandler.invalid, # type: ignore + "foo", + id="invalid", + ) + + # assert-type-error-pyright: 'No overloads for "execute_operation" match' + _invalid_result: str = await nexus_client.execute_operation( + InvalidServiceHandler.invalid, # type: ignore + "foo", + id="invalid", + )