Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 103 additions & 36 deletions temporalio/client/_nexus.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping, Sequence
from collections.abc import Awaitable, Callable, Mapping, Sequence
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Generic, cast, overload
Expand Down Expand Up @@ -473,7 +473,7 @@ class NexusClient(ABC, Generic[NexusServiceType]):
Use :py:meth:`Client.create_nexus_client` to create a client.
"""

# Overload for nexusrpc.Operation with input
# Overload for nexusrpc.Operation
@overload
@abstractmethod
async def start_operation(
Expand All @@ -494,18 +494,18 @@ async def start_operation(
rpc_timeout: timedelta | None = None,
) -> NexusOperationHandle[OutputT]: ...

# Overload for Callable with result_type
# Overload for string operation name
@overload
@abstractmethod
async def start_operation(
self,
operation: Callable[..., Any],
operation: str,
arg: Any,
*,
id: str,
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL,
result_type: type[OutputT],
result_type: type[OutputT] | None = None,
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
Expand All @@ -516,13 +516,16 @@ async def start_operation(
rpc_timeout: timedelta | None = None,
) -> NexusOperationHandle[OutputT]: ...

# Overload for Callable without result_type
# Overload for workflow_run_operation methods
@overload
@abstractmethod
async def start_operation(
self,
operation: Callable[..., Any],
arg: Any,
operation: Callable[
[NexusServiceType, temporalio.nexus.WorkflowRunOperationContext, InputT],
Awaitable[temporalio.nexus.WorkflowHandle[OutputT]],
],
arg: InputT,
*,
id: str,
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
Expand All @@ -535,20 +538,22 @@ async def start_operation(
headers: Mapping[str, str] | None = None,
rpc_metadata: Mapping[str, str | bytes] = {},
rpc_timeout: timedelta | None = None,
) -> NexusOperationHandle[Any]: ...
) -> NexusOperationHandle[OutputT]: ...

# Overload for str with result_type
# Overload for sync_operation methods (async def)
@overload
@abstractmethod
async def start_operation(
self,
operation: str,
arg: Any,
operation: Callable[
[NexusServiceType, nexusrpc.handler.StartOperationContext, InputT],
Awaitable[OutputT],
],
arg: InputT,
*,
id: str,
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL,
result_type: type[OutputT],
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
Expand All @@ -559,13 +564,16 @@ async def start_operation(
rpc_timeout: timedelta | None = None,
) -> NexusOperationHandle[OutputT]: ...

# Overload for str without result_type
# Overload for sync_operation methods (def)
@overload
@abstractmethod
async def start_operation(
self,
operation: str,
arg: Any,
operation: Callable[
[NexusServiceType, nexusrpc.handler.StartOperationContext, InputT],
OutputT,
],
arg: InputT,
*,
id: str,
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
Expand All @@ -578,7 +586,30 @@ async def start_operation(
headers: Mapping[str, str] | None = None,
rpc_metadata: Mapping[str, str | bytes] = {},
rpc_timeout: timedelta | None = None,
) -> NexusOperationHandle[Any]: ...
) -> NexusOperationHandle[OutputT]: ...

# Overload for operation_handler
@overload
@abstractmethod
async def start_operation(
self,
operation: Callable[
[NexusServiceType], nexusrpc.handler.OperationHandler[InputT, OutputT]
],
arg: InputT,
*,
id: str,
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL,
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
search_attributes: temporalio.common.TypedSearchAttributes | None = None,
summary: str | None = None,
headers: Mapping[str, str] | None = None,
rpc_metadata: Mapping[str, str | bytes] = {},
rpc_timeout: timedelta | None = None,
) -> NexusOperationHandle[OutputT]: ...

@abstractmethod
async def start_operation(
Expand Down Expand Up @@ -611,7 +642,8 @@ async def start_operation(
id: Unique identifier for this operation.
id_reuse_policy: Policy for reusing operation IDs.
id_conflict_policy: Policy for handling ID conflicts.
result_type: The result type to deserialize into.
result_type: For string operation names, this can set the specific
result type hint to deserialize into.
schedule_to_close_timeout: End-to-end timeout for the Nexus
operation. If unset, defaults to the maximum allowed by the
Temporal server.
Expand All @@ -633,7 +665,7 @@ async def start_operation(
"""
...

# Overload for nexusrpc.Operation with input
# Overload for nexusrpc.Operation
@overload
@abstractmethod
async def execute_operation(
Expand All @@ -654,18 +686,18 @@ async def execute_operation(
rpc_timeout: timedelta | None = None,
) -> OutputT: ...

# Overload for Callable with result_type
# Overload for string operation name
@overload
@abstractmethod
async def execute_operation(
self,
operation: Callable[..., Any],
operation: str,
arg: Any,
*,
id: str,
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL,
result_type: type[OutputT],
result_type: type[OutputT] | None = None,
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
Expand All @@ -676,13 +708,16 @@ async def execute_operation(
rpc_timeout: timedelta | None = None,
) -> OutputT: ...

# Overload for Callable without result_type
# Overload for workflow_run_operation methods
@overload
@abstractmethod
async def execute_operation(
self,
operation: Callable[..., Any],
arg: Any,
operation: Callable[
[NexusServiceType, temporalio.nexus.WorkflowRunOperationContext, InputT],
Awaitable[temporalio.nexus.WorkflowHandle[OutputT]],
],
arg: InputT,
*,
id: str,
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
Expand All @@ -695,20 +730,22 @@ async def execute_operation(
headers: Mapping[str, str] | None = None,
rpc_metadata: Mapping[str, str | bytes] = {},
rpc_timeout: timedelta | None = None,
) -> Any: ...
) -> OutputT: ...

# Overload for str with result_type
# Overload for sync_operation methods (async def)
@overload
@abstractmethod
async def execute_operation(
self,
operation: str,
arg: Any,
operation: Callable[
[NexusServiceType, nexusrpc.handler.StartOperationContext, InputT],
Awaitable[OutputT],
],
arg: InputT,
*,
id: str,
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL,
result_type: type[OutputT],
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
Expand All @@ -719,13 +756,40 @@ async def execute_operation(
rpc_timeout: timedelta | None = None,
) -> OutputT: ...

# Overload for str without result_type
# Overload for sync_operation methods (async def)
@overload
@abstractmethod
async def execute_operation(
self,
operation: str,
arg: Any,
operation: Callable[
[NexusServiceType, nexusrpc.handler.StartOperationContext, InputT],
OutputT,
],
arg: InputT,
*,
id: str,
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL,
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
search_attributes: temporalio.common.TypedSearchAttributes | None = None,
summary: str | None = None,
headers: Mapping[str, str] | None = None,
rpc_metadata: Mapping[str, str | bytes] = {},
rpc_timeout: timedelta | None = None,
) -> OutputT: ...

# Overload for operation_handler
@overload
@abstractmethod
async def execute_operation(
self,
operation: Callable[
[NexusServiceType],
nexusrpc.handler.OperationHandler[InputT, OutputT],
],
arg: InputT,
*,
id: str,
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
Expand All @@ -738,7 +802,7 @@ async def execute_operation(
headers: Mapping[str, str] | None = None,
rpc_metadata: Mapping[str, str | bytes] = {},
rpc_timeout: timedelta | None = None,
) -> Any: ...
) -> OutputT: ...

@abstractmethod
async def execute_operation(
Expand Down Expand Up @@ -773,7 +837,8 @@ async def execute_operation(
id: Unique identifier for this operation.
id_reuse_policy: Policy for reusing operation IDs.
id_conflict_policy: Policy for handling ID conflicts.
result_type: The result type to deserialize into.
result_type: For string operation names, this can set the specific
result type hint to deserialize into.
schedule_to_close_timeout: End-to-end timeout for the Nexus
operation. If unset, defaults to the maximum allowed by the
Temporal server.
Expand Down Expand Up @@ -860,7 +925,9 @@ async def start_operation(
This API is experimental and unstable.
"""
op_name, output_type = self._resolve_operation(operation)
final_result_type: type | None = result_type or output_type
final_result_type: type | None = (
result_type if isinstance(operation, str) else output_type
)

return await self._client._impl.start_nexus_operation(
StartNexusOperationInput(
Expand Down
47 changes: 30 additions & 17 deletions tests/nexus/test_nexus_type_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,14 +238,13 @@ async def standalone_operation_type_tests():
start_to_close_timeout=timedelta(seconds=2),
)

# result_type overrides output type from operation definition
# conflicting result_type and annotation on variable cause type error
# assert-type-error-pyright: 'Type "str" is not assignable to declared type "MyOutput"'
_bad_result_type_output: MyOutput = await nexus_client.execute_operation( # type: ignore
MyServiceHandler.my_sync_operation,
# result_type is not allowed when an operation is provided
await nexus_client.execute_operation(
# assert-type-error-pyright: 'cannot be assigned to parameter "operation" of type "str"'
MyService.my_sync_operation, # type: ignore
MyInput(),
id="op-1",
result_type=str, # type: ignore
result_type=str,
)

# string operation name and result_type infers output type
Expand Down Expand Up @@ -337,19 +336,14 @@ async def standalone_operation_type_tests():
)
_defn_handle_output: MyOutput = await _defn_handle.result()

# result_type overrides output type from operation definition
# conflicting result_type and annotation on variable cause type error
_result_type_handle: NexusOperationHandle[
MyOutput
# assert-type-error-pyright: 'Type "NexusOperationHandle\[str\]" is not assignable to declared type "NexusOperationHandle\[MyOutput\]"'
] = await nexus_client.start_operation( # type: ignore
MyServiceHandler.my_sync_operation,
# result_type is not allowed when an operation is provided
await nexus_client.start_operation(
# assert-type-error-pyright: 'cannot be assigned to parameter "operation" of type "str"'
MyServiceHandler.my_sync_operation, # type: ignore
MyInput(),
id="op-1",
result_type=str, # type: ignore
result_type=str,
)
# handle still respects type declaration on the variable
_result_type_handle_output: MyOutput = await _result_type_handle.result()

# starting with string operation name and result_type infers output type on the handle
# and result from the handle
Expand Down Expand Up @@ -389,11 +383,30 @@ async def standalone_operation_type_tests():
)
)

# mismatched types on get_nexus_operation_handle produces type error
# mismatched types on get_nexus_operation_handle produce a type error
# assert-type-error-pyright: 'Type "NexusOperationHandle\[str\]" is not assignable to declared type "NexusOperationHandle\[MyOutput\]"'
_mismatch_handle: NexusOperationHandle[MyOutput] = (
client.get_nexus_operation_handle( # type: ignore
"op-1",
result_type=str, # type: ignore
)
)

# functions with invalid signatures produce a type error
class InvalidServiceHandler:
async def invalid(self, _ctx: str, _input: str) -> str:
raise NotImplementedError()

# assert-type-error-pyright: 'No overloads for "start_operation" match'
_invalid_handle: NexusOperationHandle[str] = await nexus_client.start_operation(
InvalidServiceHandler.invalid, # type: ignore
"foo",
id="invalid",
)

# assert-type-error-pyright: 'No overloads for "execute_operation" match'
_invalid_result: str = await nexus_client.execute_operation(
InvalidServiceHandler.invalid, # type: ignore
"foo",
id="invalid",
)
Loading