Skip to content
58 changes: 58 additions & 0 deletions temporalio/client/_nexus.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,35 @@ async def start_operation(
rpc_timeout: timedelta | None = None,
) -> NexusOperationHandle[OutputT]: ...

# Overload for temporal_operation methods
@overload
@abstractmethod
async def start_operation(
self,
operation: Callable[
[
NexusServiceType,
temporalio.nexus.TemporalNexusStartOperationContext,
temporalio.nexus.TemporalNexusClient,
InputT,
],
Awaitable[temporalio.nexus.TemporalOperationResult[OutputT]],
],
arg: InputT,
*,
id: str,
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL,
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
search_attributes: temporalio.common.TypedSearchAttributes | None = None,
summary: str | None = None,
headers: Mapping[str, str] | None = None,
rpc_metadata: Mapping[str, str | bytes] = {},
rpc_timeout: timedelta | None = None,
) -> NexusOperationHandle[OutputT]: ...

@abstractmethod
async def start_operation(
self,
Expand Down Expand Up @@ -804,6 +833,35 @@ async def execute_operation(
rpc_timeout: timedelta | None = None,
) -> OutputT: ...

# Overload for temporal_operation methods
@overload
@abstractmethod
async def execute_operation(
self,
operation: Callable[
[
NexusServiceType,
temporalio.nexus.TemporalNexusStartOperationContext,
temporalio.nexus.TemporalNexusClient,
InputT,
],
Awaitable[temporalio.nexus.TemporalOperationResult[OutputT]],
],
arg: InputT,
*,
id: str,
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL,
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
search_attributes: temporalio.common.TypedSearchAttributes | None = None,
summary: str | None = None,
headers: Mapping[str, str] | None = None,
rpc_metadata: Mapping[str, str | bytes] = {},
rpc_timeout: timedelta | None = None,
) -> OutputT: ...

@abstractmethod
async def execute_operation(
self,
Expand Down
12 changes: 11 additions & 1 deletion temporalio/nexus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
See https://github.com/temporalio/sdk-python/tree/main#nexus
"""

from ._decorators import workflow_run_operation
from ._decorators import temporal_operation, workflow_run_operation
from ._operation_context import (
Info,
LoggerAdapter,
NexusCallback,
TemporalNexusCancelOperationContext,
TemporalNexusStartOperationContext,
WorkflowRunOperationContext,
client,
in_operation,
Expand All @@ -18,6 +20,8 @@
wait_for_worker_shutdown,
wait_for_worker_shutdown_sync,
)
from ._operation_handlers import TemporalNexusOperationHandler
from ._temporal_client import TemporalNexusClient, TemporalOperationResult
from ._token import WorkflowHandle

__all__ = (
Expand All @@ -26,6 +30,8 @@
"LoggerAdapter",
"NexusCallback",
"WorkflowRunOperationContext",
"TemporalNexusCancelOperationContext",
"TemporalNexusStartOperationContext",
"client",
"in_operation",
"info",
Expand All @@ -35,4 +41,8 @@
"wait_for_worker_shutdown",
"wait_for_worker_shutdown_sync",
"WorkflowHandle",
"TemporalNexusClient",
"TemporalNexusOperationHandler",
"TemporalOperationResult",
"temporal_operation",
)
201 changes: 185 additions & 16 deletions temporalio/nexus/_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from collections.abc import Awaitable, Callable
from typing import (
TypeVar,
overload,
)

Expand All @@ -13,26 +12,37 @@
StartOperationContext,
)

from ._operation_context import WorkflowRunOperationContext
from ._operation_handlers import WorkflowRunOperationHandler
from temporalio.nexus._temporal_client import (
TemporalNexusClient,
TemporalOperationResult,
)
from temporalio.types import NexusServiceType

from ._operation_context import (
TemporalNexusStartOperationContext,
WorkflowRunOperationContext,
)
from ._operation_handlers import (
WorkflowRunOperationHandler,
_TemporalNexusOperationHandler,
)
from ._token import WorkflowHandle
from ._util import (
get_callable_name,
get_temporal_operation_start_method_input_and_output_type_annotations,
get_workflow_run_start_method_input_and_output_type_annotations,
set_operation_factory,
)

ServiceHandlerT = TypeVar("ServiceHandlerT")


@overload
def workflow_run_operation(
start: Callable[
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
[NexusServiceType, WorkflowRunOperationContext, InputT],
Awaitable[WorkflowHandle[OutputT]],
],
) -> Callable[
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
[NexusServiceType, WorkflowRunOperationContext, InputT],
Awaitable[WorkflowHandle[OutputT]],
]: ...

Expand All @@ -44,12 +54,12 @@ def workflow_run_operation(
) -> Callable[
[
Callable[
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
[NexusServiceType, WorkflowRunOperationContext, InputT],
Awaitable[WorkflowHandle[OutputT]],
]
],
Callable[
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
[NexusServiceType, WorkflowRunOperationContext, InputT],
Awaitable[WorkflowHandle[OutputT]],
],
]: ...
Expand All @@ -59,26 +69,26 @@ def workflow_run_operation(
start: None
| (
Callable[
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
[NexusServiceType, WorkflowRunOperationContext, InputT],
Awaitable[WorkflowHandle[OutputT]],
]
) = None,
*,
name: str | None = None,
) -> (
Callable[
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
[NexusServiceType, WorkflowRunOperationContext, InputT],
Awaitable[WorkflowHandle[OutputT]],
]
| Callable[
[
Callable[
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
[NexusServiceType, WorkflowRunOperationContext, InputT],
Awaitable[WorkflowHandle[OutputT]],
]
],
Callable[
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
[NexusServiceType, WorkflowRunOperationContext, InputT],
Awaitable[WorkflowHandle[OutputT]],
],
]
Expand All @@ -87,11 +97,11 @@ def workflow_run_operation(

def decorator(
start: Callable[
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
[NexusServiceType, WorkflowRunOperationContext, InputT],
Awaitable[WorkflowHandle[OutputT]],
],
) -> Callable[
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
[NexusServiceType, WorkflowRunOperationContext, InputT],
Awaitable[WorkflowHandle[OutputT]],
]:
(
Expand All @@ -100,7 +110,7 @@ def decorator(
) = get_workflow_run_start_method_input_and_output_type_annotations(start)

def operation_handler_factory(
self: ServiceHandlerT,
self: NexusServiceType,
) -> OperationHandler[InputT, OutputT]:
async def _start(
ctx: StartOperationContext, input: InputT
Expand Down Expand Up @@ -130,3 +140,162 @@ async def _start(
return decorator

return decorator(start)


@overload
def temporal_operation(
start: Callable[
[
NexusServiceType,
TemporalNexusStartOperationContext,
TemporalNexusClient,
InputT,
],
Awaitable[TemporalOperationResult[OutputT]],
],
) -> Callable[
[NexusServiceType, TemporalNexusStartOperationContext, TemporalNexusClient, InputT],
Awaitable[TemporalOperationResult[OutputT]],
]: ...


@overload
def temporal_operation(
*,
name: str | None = None,
) -> Callable[
[
Callable[
[
NexusServiceType,
TemporalNexusStartOperationContext,
TemporalNexusClient,
InputT,
],
Awaitable[TemporalOperationResult[OutputT]],
]
],
Callable[
[
NexusServiceType,
TemporalNexusStartOperationContext,
TemporalNexusClient,
InputT,
],
Awaitable[TemporalOperationResult[OutputT]],
],
]: ...


def temporal_operation(
start: None
| (
Callable[
[
NexusServiceType,
TemporalNexusStartOperationContext,
TemporalNexusClient,
InputT,
],
Awaitable[TemporalOperationResult[OutputT]],
]
) = None,
*,
name: str | None = None,
) -> (
Callable[
[
NexusServiceType,
TemporalNexusStartOperationContext,
TemporalNexusClient,
InputT,
],
Awaitable[TemporalOperationResult[OutputT]],
]
| Callable[
[
Callable[
[
NexusServiceType,
TemporalNexusStartOperationContext,
TemporalNexusClient,
InputT,
],
Awaitable[TemporalOperationResult[OutputT]],
]
],
Callable[
[
NexusServiceType,
TemporalNexusStartOperationContext,
TemporalNexusClient,
InputT,
],
Awaitable[TemporalOperationResult[OutputT]],
],
]
):
"""Decorator marking a method as the start method for an operation that interacts with Temporal.

.. warning::
This API is experimental and unstable.
"""

def decorator(
start: Callable[
[
NexusServiceType,
TemporalNexusStartOperationContext,
TemporalNexusClient,
InputT,
],
Awaitable[TemporalOperationResult[OutputT]],
],
) -> Callable[
[
NexusServiceType,
TemporalNexusStartOperationContext,
TemporalNexusClient,
InputT,
],
Awaitable[TemporalOperationResult[OutputT]],
]:
(
input_type,
output_type,
) = get_temporal_operation_start_method_input_and_output_type_annotations(start)

def operation_handler_factory(
self: NexusServiceType,
) -> OperationHandler[InputT, OutputT]:
async def _start(
ctx: TemporalNexusStartOperationContext,
client: TemporalNexusClient,
input: InputT,
) -> TemporalOperationResult[OutputT]:
return await start(
self,
ctx,
client,
input,
)

_start.__doc__ = start.__doc__
return _TemporalNexusOperationHandler(_start)

method_name = get_callable_name(start)
op = nexusrpc.Operation(
name=name or method_name,
input_type=input_type,
output_type=output_type,
)
op.method_name = method_name
nexusrpc.set_operation(operation_handler_factory, op)

set_operation_factory(start, operation_handler_factory)
return start

if start is None:
return decorator

return decorator(start)
Loading