diff --git a/temporalio/client/_impl.py b/temporalio/client/_impl.py
index af221865a..a44938015 100644
--- a/temporalio/client/_impl.py
+++ b/temporalio/client/_impl.py
@@ -10,6 +10,7 @@
     Callable,
     Mapping,
 )
+from contextvars import ContextVar
 from datetime import timedelta
 from typing import (
     TYPE_CHECKING,
@@ -26,6 +27,7 @@
 import temporalio.api.schedule.v1
 import temporalio.api.taskqueue.v1
 import temporalio.api.update.v1
+import temporalio.api.workflow.v1
 import temporalio.api.workflowservice.v1
 import temporalio.common
 import temporalio.converter
@@ -136,6 +138,16 @@
     from ._client import Client
 
 
+# Carries a per-call TimeSkippingConfig from WorkflowTimeSkipper's outbound
+# interceptor (temporalio/testing/_timeskipping.py) down to
+# _populate_start_workflow_execution_request, which copies it onto the outgoing
+# request. The interceptor sets it just before delegating start_workflow and
+# resets it in a finally block; the default None means "leave the request as is".
+_start_workflow_time_skipping_config: ContextVar[
+    temporalio.api.workflow.v1.TimeSkippingConfig | None
+] = ContextVar("_start_workflow_time_skipping_config", default=None)
+
+
 class _ClientImpl(OutboundInterceptor):  # pyright: ignore[reportUnusedClass]
     def __init__(self, client: Client) -> None:  # type: ignore
         # We are intentionally not calling the base class's __init__ here
@@ -340,6 +352,11 @@ async def _populate_start_workflow_execution_request(
             req.priority.CopyFrom(input.priority._to_proto())
         if input.versioning_override is not None:
             req.versioning_override.CopyFrom(input.versioning_override._to_proto())
+        # Non-None only while the time-skipping test interceptor has set the
+        # context variable around a start_workflow call.
+        ts_config = _start_workflow_time_skipping_config.get()
+        if ts_config is not None:
+            req.time_skipping_config.CopyFrom(ts_config)
 
     async def cancel_workflow(self, input: CancelWorkflowInput) -> None:
         await self._client.workflow_service.request_cancel_workflow_execution(
diff --git a/temporalio/testing/__init__.py b/temporalio/testing/__init__.py
index ac534b953..33f1c7b6e 100644
--- a/temporalio/testing/__init__.py
+++ b/temporalio/testing/__init__.py
@@ -1,9 +1,12 @@
 """Test framework for workflows and activities."""
 
from ._activity import ActivityEnvironment +from ._timeskipping import WorkflowTimeSkipper, WorkflowTimeSkippingConfig from ._workflow import WorkflowEnvironment __all__ = [ "ActivityEnvironment", "WorkflowEnvironment", + "WorkflowTimeSkipper", + "WorkflowTimeSkippingConfig", ] diff --git a/temporalio/testing/_timeskipping.py b/temporalio/testing/_timeskipping.py new file mode 100644 index 000000000..57ee22512 --- /dev/null +++ b/temporalio/testing/_timeskipping.py @@ -0,0 +1,226 @@ +"""Utilities for per-workflow time skipping in tests.""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import timedelta +from typing import Any + +import google.protobuf.field_mask_pb2 + +import temporalio.api.common.v1 +import temporalio.api.enums.v1.event_type_pb2 as _event_type +import temporalio.api.workflow.v1 +import temporalio.api.workflowservice.v1 +import temporalio.client +from temporalio.client._impl import _start_workflow_time_skipping_config + + +@dataclass(frozen=True) +class WorkflowTimeSkippingConfig: + """Per-workflow time skipping configuration.""" + + enabled: bool = True + """Whether time skipping is enabled for the workflow.""" + + max_skip_duration: timedelta | None = None + """Maximum total virtual time that can be skipped before time skipping + is automatically disabled.""" + + def _to_proto(self) -> temporalio.api.workflow.v1.TimeSkippingConfig: + proto = temporalio.api.workflow.v1.TimeSkippingConfig(enabled=self.enabled) + if self.max_skip_duration is not None: + proto.max_skipped_duration.FromTimedelta(self.max_skip_duration) + return proto + + +_TERMINAL_EVENT_TYPES = frozenset( + { + _event_type.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED, + _event_type.EVENT_TYPE_WORKFLOW_EXECUTION_FAILED, + _event_type.EVENT_TYPE_WORKFLOW_EXECUTION_TIMED_OUT, + _event_type.EVENT_TYPE_WORKFLOW_EXECUTION_TERMINATED, + _event_type.EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED, + _event_type.EVENT_TYPE_WORKFLOW_EXECUTION_CONTINUED_AS_NEW, + } 
+) + + +class WorkflowTimeSkipper: + """Testing utility for per-workflow time skipping. + + Creates a cloned client that automatically enables time skipping on every + workflow started through it. Once a workflow's configured bound is + reached, :py:meth:`wait_for_skip_duration_reached` blocks until the + transition occurs and :py:meth:`resume` re-enables skipping with an + optional new delta. + + Example:: + + ts = WorkflowTimeSkipper(env.client, + config=WorkflowTimeSkippingConfig(max_skip_duration=timedelta(hours=1))) + + handle = await ts.client.start_workflow( + MyWorkflow.run, id="wf-1", task_queue="tq", + ) + await ts.wait_for_skip_duration_reached(handle) + # inspect state, signal, etc. + await ts.resume(handle, delta=timedelta(hours=1)) + result = await handle.result() + + Works against any client the test suite hands in (local, self-hosted, or + cloud). TODO: cloud usage assumes the namespace has server-side time + skipping enabled (``frontend.TimeSkippingEnabled``); add a ``cloud`` + fixture mode alongside ``local`` / ``time-skipping`` in ``conftest.env`` + so the same tests can be pointed at a cloud namespace once that lands. + """ + + def __init__( + self, + client: temporalio.client.Client, + *, + config: WorkflowTimeSkippingConfig = WorkflowTimeSkippingConfig(), + ) -> None: + """Create a workflow time skipper. + + Args: + client: The client to wrap. A cloned client with a time-skipping + interceptor is created; the original is left untouched. + config: Initial bound. Defaults to no bound — time skipping runs + until the workflow completes. + """ + self._config = config + client_config = client.config() + client_config["interceptors"] = [ + *client_config["interceptors"], + _TimeSkippingConfigInterceptor(self), + ] + self._client = temporalio.client.Client(**client_config) + # Per-workflow max_skip_duration last set on the server, keyed by + # (workflow_id, run_id). 
+ self._bound_cache: dict[tuple[str, str], timedelta] = {} + + @property + def client(self) -> temporalio.client.Client: + """Client that enables time skipping on every started workflow.""" + return self._client + + @property + def config(self) -> WorkflowTimeSkippingConfig: + """Bound applied to future start_workflow calls.""" + return self._config + + @config.setter + def config(self, value: WorkflowTimeSkippingConfig) -> None: + self._config = value + + async def wait_for_skip_duration_reached( + self, + handle: temporalio.client.WorkflowHandle[Any, Any], + ) -> bool: + """Block until the workflow's configured skip duration is reached. + + Returns ``True`` once a time-skipping-disabled transition is observed. + Returns ``False`` if the workflow terminates before any bound is + reached. + """ + # TODO: Replace with a dedicated long-poll RPC once the server adds + # one for time-skipping transitions. The current path streams every + # history event since the workflow started, which is correct but not + # the most efficient if event volume is high. + async for event in handle.fetch_history_events(wait_new_event=True): + if ( + event.event_type + == _event_type.EVENT_TYPE_WORKFLOW_EXECUTION_TIME_SKIPPING_TRANSITIONED + ): + attrs = ( + event.workflow_execution_time_skipping_transitioned_event_attributes + ) + if attrs.disabled_after_bound: + return True + elif event.event_type in _TERMINAL_EVENT_TYPES: + return False + return False + + async def resume( + self, + handle: temporalio.client.WorkflowHandle[Any, Any], + delta: timedelta | None = None, + ) -> None: + """Re-enable time skipping after a bound was reached. + + With ``delta``, sets a new bound equal to (previously-set bound + + delta). Without ``delta``, resumes skipping with no bound — the + workflow auto-skips until completion. 
+ """ + proto = temporalio.api.workflow.v1.TimeSkippingConfig(enabled=True) + if delta is not None: + cache_key = (handle.id, handle.run_id or "") + if cache_key not in self._bound_cache: + if self._config.max_skip_duration is None: + raise ValueError( + "resume(delta=...) requires an initial bound to have been " + "configured on the WorkflowTimeSkipper, or call resume() " + "with no delta to resume unbounded." + ) + self._bound_cache[cache_key] = self._config.max_skip_duration + new_value = self._bound_cache[cache_key] + delta + proto.max_skipped_duration.FromTimedelta(new_value) + self._bound_cache[cache_key] = new_value + + await self._client.workflow_service.update_workflow_execution_options( + temporalio.api.workflowservice.v1.UpdateWorkflowExecutionOptionsRequest( + namespace=self._client.namespace, + workflow_execution=temporalio.api.common.v1.WorkflowExecution( + workflow_id=handle.id, + run_id=handle.run_id or "", + ), + workflow_execution_options=temporalio.api.workflow.v1.WorkflowExecutionOptions( + time_skipping_config=proto, + ), + update_mask=google.protobuf.field_mask_pb2.FieldMask( + paths=["time_skipping_config"], + ), + identity=self._client.identity, + ), + retry=True, + ) + + +class _TimeSkippingConfigInterceptor(temporalio.client.Interceptor): + def __init__(self, skipper: WorkflowTimeSkipper) -> None: + super().__init__() + self._skipper = skipper + + def intercept_client( + self, next: temporalio.client.OutboundInterceptor + ) -> temporalio.client.OutboundInterceptor: + return _TimeSkippingConfigOutbound(next, self._skipper) + + +class _TimeSkippingConfigOutbound(temporalio.client.OutboundInterceptor): + def __init__( + self, + next: temporalio.client.OutboundInterceptor, + skipper: WorkflowTimeSkipper, + ) -> None: + super().__init__(next) + self._skipper = skipper + + async def start_workflow( + self, input: temporalio.client.StartWorkflowInput + ) -> temporalio.client.WorkflowHandle[Any, Any]: + proto = self._skipper.config._to_proto() + 
token = _start_workflow_time_skipping_config.set(proto) + try: + handle = await super().start_workflow(input) + finally: + _start_workflow_time_skipping_config.reset(token) + # Seed the bound cache so future resume(delta=...) calls have a + # baseline to add to. Captures the config at start time, even if the + # user mutates self._skipper.config afterwards. + cfg = self._skipper.config + if cfg.max_skip_duration is not None: + cache_key = (handle.id, handle.run_id or "") + self._skipper._bound_cache[cache_key] = cfg.max_skip_duration + return handle diff --git a/tests/conftest.py b/tests/conftest.py index 2005dbe57..b1059dd78 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -134,6 +134,8 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]: "nexusoperation.enableStandalone=true", "--dynamic-config-value", 'system.system.refreshNexusEndpointsMinWait="0s"', + "--dynamic-config-value", + "frontend.TimeSkippingEnabled=true", ], dev_server_download_version=DEV_SERVER_DOWNLOAD_VERSION, ) diff --git a/tests/testing/test_timeskipping.py b/tests/testing/test_timeskipping.py new file mode 100644 index 000000000..f96e1986d --- /dev/null +++ b/tests/testing/test_timeskipping.py @@ -0,0 +1,149 @@ +"""Tests for per-workflow time skipping. + +Two usage patterns are exercised: + +- **Basic**: enable time skipping on a client to forward user timers until completion. +- **Interactive (complicated)**: forward the time by a duration before completion and when it is done, + notify the test to signal or update, and then resume skipping with another duration to forward or + no duration to let the time skipping only forward user timers until completion. +in either case, time skipping only happens when there is no in-flight work, so even if a duration is set +to forward but the workflow can just run to completion (for example a workflow that has no idle time). 
+
+The ``env`` fixture from ``conftest.py`` works unchanged for cloud — no
+separate cloud-only env fixture is needed.
+"""
+
+import asyncio
+import uuid
+from datetime import timedelta
+from time import monotonic
+
+import pytest
+
+from temporalio import workflow
+from temporalio.testing import (
+    WorkflowEnvironment,
+    WorkflowTimeSkipper,
+    WorkflowTimeSkippingConfig,
+)
+from tests.helpers import new_worker
+
+
+@workflow.defn
+class SingleTimerWorkflow:
+    @workflow.run
+    async def run(self) -> float:
+        """Sleep 1h of virtual time and return the elapsed virtual seconds."""
+        start = workflow.now()
+        await workflow.sleep(timedelta(hours=1))
+        return (workflow.now() - start).total_seconds()
+
+
+@workflow.defn
+class InteractionWorkflow:
+    """Completes after receiving two ``proceed`` signals; otherwise waits up to 10h."""
+
+    def __init__(self) -> None:
+        self.signals_received = 0  # incremented by each ``proceed`` signal
+
+    @workflow.run
+    async def run(self) -> str:
+        await workflow.wait_condition(
+            lambda: self.signals_received >= 2,
+            timeout=timedelta(hours=10),  # upper limit on the wait
+        )
+        return "done"
+
+    @workflow.signal
+    def proceed(self) -> None:
+        self.signals_received += 1
+
+    @workflow.query
+    def get_signal_count(self) -> int:
+        return self.signals_received
+
+
+async def test_pattern_basic(env: WorkflowEnvironment) -> None:
+    """Pattern 1: enable time skipping, let workflow run to completion."""
+    ts = WorkflowTimeSkipper(env.client)  # default config: enabled, no bound
+    async with new_worker(ts.client, SingleTimerWorkflow) as worker:
+        wall_start = monotonic()
+        result = await ts.client.execute_workflow(
+            SingleTimerWorkflow.run,
+            id=f"wf-{uuid.uuid4()}",
+            task_queue=worker.task_queue,
+        )
+        wall_elapsed = monotonic() - wall_start
+
+    # Virtual time advanced by ~1h even though wall time was just a few seconds.
+    assert result >= 3600, (
+        f"virtual elapsed was {result}s; expected >= 3600s (timer did not fire fully)"
+    )
+    # 1-hour timer should be auto-skipped in well under 3s of wall time.
+    assert wall_elapsed < 3, (
+        f"workflow took {wall_elapsed:.3f}s wall time; time skipping did not engage"
+    )
+
+
+async def test_pattern_basic_no_skipping_times_out(
+    env: WorkflowEnvironment,
+) -> None:
+    """Without time skipping, the 1h timer does not complete in 3s."""
+    async with new_worker(env.client, SingleTimerWorkflow) as worker:
+        with pytest.raises(asyncio.TimeoutError):
+            await asyncio.wait_for(
+                env.client.execute_workflow(
+                    SingleTimerWorkflow.run,
+                    id=f"wf-{uuid.uuid4()}",
+                    task_queue=worker.task_queue,
+                ),
+                timeout=3,  # wall-clock seconds
+            )
+
+
+async def test_pattern2_bounded_with_resume(env: WorkflowEnvironment) -> None:
+    """Pattern 2: skip 1h, signal, resume +1h, signal, workflow completes.
+
+    The workflow would otherwise sit on a 10-hour timer waiting for two
+    signals. Time skipping advances virtual time to each interaction point;
+    the test sends a signal once skipping pauses at each bound.
+
+    TODO: this requires a dev-server build that enforces
+    ``TimeSkippingConfig.max_skipped_duration`` and
+    emits ``WorkflowExecutionTimeSkippingTransitioned`` events. The currently
+    downloaded CLI does not — point ``conftest.env`` at a local build of the
+    CLI branch that ships the bound + transition-event API once it lands.
+    """
+    ts = WorkflowTimeSkipper(
+        env.client,
+        config=WorkflowTimeSkippingConfig(max_skip_duration=timedelta(hours=1)),
+    )
+    async with new_worker(ts.client, InteractionWorkflow) as worker:
+        wall_start = monotonic()
+        handle = await ts.client.start_workflow(
+            InteractionWorkflow.run,
+            id=f"wf-{uuid.uuid4()}",
+            task_queue=worker.task_queue,
+        )
+
+        # Skip 1h of virtual time; bound pauses skipping so we can interact.
+        assert await ts.wait_for_skip_duration_reached(handle), (
+            "expected first bound at 1h"
+        )
+        await handle.signal(InteractionWorkflow.proceed)
+        assert await handle.query(InteractionWorkflow.get_signal_count) == 1
+
+        # Skip another 1h, then send the second signal to release the workflow.
+        await ts.resume(handle, delta=timedelta(hours=1))
+        assert await ts.wait_for_skip_duration_reached(handle), (
+            "expected second bound at 2h total"
+        )
+        await handle.signal(InteractionWorkflow.proceed)
+
+        result = await handle.result()
+        wall_elapsed = monotonic() - wall_start
+
+    assert result == "done"
+    assert wall_elapsed < 60, (
+        f"workflow took {wall_elapsed:.1f}s wall time; expected fast finish"
+    )