diff --git a/.github/workflows/durabletask.yml b/.github/workflows/durabletask.yml index e7465ef..ce79a21 100644 --- a/.github/workflows/durabletask.yml +++ b/.github/workflows/durabletask.yml @@ -57,19 +57,7 @@ jobs: - name: Pytest unit tests working-directory: tests/durabletask run: | - pytest -m "not e2e and not dts" --verbose - # Sidecar for running e2e tests requires Go SDK - - name: Install Go SDK - uses: actions/setup-go@v5 - with: - go-version: 'stable' - # Install and run the durabletask-go sidecar for running e2e tests - - name: Pytest e2e tests - working-directory: tests/durabletask - run: | - go install github.com/microsoft/durabletask-go@main - durabletask-go --port 4001 & - pytest -m "e2e and not dts" --verbose + pytest -m "not dts" --verbose publish-release: if: startsWith(github.ref, 'refs/tags/v') # Only run if a matching tag is pushed diff --git a/CHANGELOG.md b/CHANGELOG.md index d4f27fd..3cc5e03 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +ADDED: + +- Added `durabletask.testing` module with `InMemoryOrchestrationBackend` for testing orchestrations without a sidecar process + FIXED: - Fix unbound variable in entity V1 processing diff --git a/Makefile b/Makefile index 5a05f33..c554da8 100644 --- a/Makefile +++ b/Makefile @@ -1,11 +1,8 @@ init: pip3 install -r requirements.txt -test-unit: - pytest -m "not e2e" --verbose - -test-e2e: - pytest -m e2e --verbose +test: + pytest --verbose install: python3 -m pip install . @@ -16,4 +13,4 @@ gen-proto: python3 -m grpc_tools.protoc --proto_path=. --python_out=. --pyi_out=. --grpc_python_out=. 
./durabletask/internal/orchestrator_service.proto rm durabletask/internal/*.proto -.PHONY: init test-unit test-e2e gen-proto install +.PHONY: init test gen-proto install diff --git a/docs/development.md b/docs/development.md index 3308316..dc0f88f 100644 --- a/docs/development.md +++ b/docs/development.md @@ -11,25 +11,10 @@ make gen-proto This will download the `orchestrator_service.proto` from the `microsoft/durabletask-protobuf` repo and compile it using `grpcio-tools`. The version of the source proto file that was downloaded can be found in the file `durabletask/internal/PROTO_SOURCE_COMMIT_HASH`. -### Running unit tests +### Running tests -Unit tests can be run using the following command from the project root. Unit tests _don't_ require a sidecar process to be running. +Tests can be run using the following command from the project root. ```sh -make test-unit -``` - -### Running E2E tests - -The E2E (end-to-end) tests require a sidecar process to be running. You can use the Durable Task test sidecar using the following `docker` command: - -```sh -go install github.com/microsoft/durabletask-go@main -durabletask-go --port 4001 -``` - -To run the E2E tests, run the following command from the project root: - -```sh -make test-e2e +make test ``` \ No newline at end of file diff --git a/durabletask/testing/README.md b/durabletask/testing/README.md new file mode 100644 index 0000000..0a6e985 --- /dev/null +++ b/durabletask/testing/README.md @@ -0,0 +1,293 @@ +# Durable Task Testing Utilities + +This package provides testing utilities for the Durable Task Python SDK, +including an in-memory backend that eliminates the need for external +dependencies during testing. + +## In-Memory Backend + +The `InMemoryOrchestrationBackend` is a lightweight, in-memory implementation +of the Durable Task backend that runs as a gRPC server. It's designed for +testing scenarios where you want to test orchestrations without requiring a +sidecar process or external storage. 
+ +### Features + +- **In-memory state storage**: All orchestration state is stored in memory +- **Full gRPC compatibility**: Implements the same gRPC interface as the production backend +- **Thread-safe**: Safe for concurrent access from multiple threads +- **Work item streaming**: Supports streaming work items to workers +- **Event handling**: Supports raising events, timers, and sub-orchestrations +- **Entity support**: Supports function-based and class-based entities +- **Lifecycle management**: Supports suspend, resume, terminate, and restart operations +- **State waiting**: Built-in support for waiting on orchestration state changes + +### Quick Start + +```python +import pytest +from durabletask.testing import create_test_backend +from durabletask.client import TaskHubGrpcClient, OrchestrationStatus +from durabletask.worker import TaskHubGrpcWorker + +@pytest.fixture +def backend(): + """Create an in-memory backend for testing.""" + backend = create_test_backend(port=50051) + yield backend + backend.stop() + backend.reset() + +def test_simple_orchestration(backend): + # Create client and worker + client = TaskHubGrpcClient(host_address="localhost:50051") + worker = TaskHubGrpcWorker(host_address="localhost:50051") + + # Define orchestrator and activity + def hello_orchestrator(ctx, _): + result = yield ctx.call_activity(say_hello, input="World") + return result + + def say_hello(ctx, name: str): + return f"Hello, {name}!" 
+ + # Register orchestrator and activity with the worker + worker.add_orchestrator(hello_orchestrator) + worker.add_activity(say_hello) + + # Start worker + worker.start() + + try: + # Schedule orchestration + instance_id = client.schedule_new_orchestration(hello_orchestrator) + + # Wait for completion + state = client.wait_for_orchestration_completion(instance_id, timeout=10) + + # Verify results + assert state.runtime_status == OrchestrationStatus.COMPLETED + assert state.serialized_output == '"Hello, World!"' + finally: + worker.stop() +``` + +### Advanced Usage + +#### Testing with Multiple Ports + +```python +import random +import pytest +from durabletask.testing import create_test_backend +from durabletask.client import TaskHubGrpcClient +from durabletask.worker import TaskHubGrpcWorker + +@pytest.fixture +def backend(): + # Use a random port to avoid conflicts + port = random.randint(50000, 60000) + backend = create_test_backend(port=port) + yield backend, port + backend.stop() + backend.reset() + +def test_orchestration(backend): + backend_instance, port = backend + client = TaskHubGrpcClient(host_address=f"localhost:{port}") + worker = TaskHubGrpcWorker(host_address=f"localhost:{port}") + # ... 
+``` + +#### Testing Event Handling + +```python +def test_external_events(backend): + client = TaskHubGrpcClient(host_address="localhost:50051") + worker = TaskHubGrpcWorker(host_address="localhost:50051") + + def wait_for_event_orchestrator(ctx, _): + event_data = yield ctx.wait_for_external_event("approval") + return event_data + + worker.add_orchestrator(wait_for_event_orchestrator) + worker.start() + + try: + instance_id = client.schedule_new_orchestration(wait_for_event_orchestrator) + + # Wait for orchestration to start + client.wait_for_orchestration_start(instance_id, timeout=5) + + # Raise an event + client.raise_orchestration_event(instance_id, "approval", data="approved") + + # Wait for completion + state = client.wait_for_orchestration_completion(instance_id, timeout=10) + + assert state.runtime_status == OrchestrationStatus.COMPLETED + assert state.serialized_output == '"approved"' + finally: + worker.stop() +``` + +#### Testing Sub-Orchestrations + +```python +def test_sub_orchestrations(backend): + client = TaskHubGrpcClient(host_address="localhost:50051") + worker = TaskHubGrpcWorker(host_address="localhost:50051") + + def parent_orchestrator(ctx, _): + result1 = yield ctx.call_sub_orchestrator(child_orchestrator, input=1) + result2 = yield ctx.call_sub_orchestrator(child_orchestrator, input=2) + return result1 + result2 + + def child_orchestrator(ctx, input: int): + return input * 2 + + worker.add_orchestrator(parent_orchestrator) + worker.add_orchestrator(child_orchestrator) + worker.start() + + try: + instance_id = client.schedule_new_orchestration(parent_orchestrator) + state = client.wait_for_orchestration_completion(instance_id, timeout=10) + + assert state.runtime_status == OrchestrationStatus.COMPLETED + assert state.serialized_output == "6" # (1*2) + (2*2) + finally: + worker.stop() +``` + +#### Testing Timers + +```python +def test_durable_timers(backend): + import time + from datetime import timedelta + + client = 
TaskHubGrpcClient(host_address="localhost:50051") + worker = TaskHubGrpcWorker(host_address="localhost:50051") + + def timer_orchestrator(ctx, _): + fire_at = ctx.current_utc_datetime + timedelta(seconds=1) + yield ctx.create_timer(fire_at) + return "timer_fired" + + worker.add_orchestrator(timer_orchestrator) + worker.start() + + try: + start_time = time.time() + instance_id = client.schedule_new_orchestration(timer_orchestrator) + state = client.wait_for_orchestration_completion(instance_id, timeout=10) + elapsed = time.time() - start_time + + assert state.runtime_status == OrchestrationStatus.COMPLETED + assert elapsed >= 1.0 # Timer should have waited at least 1 second + finally: + worker.stop() +``` + +#### Testing Termination + +```python +def test_orchestration_termination(backend): + client = TaskHubGrpcClient(host_address="localhost:50051") + worker = TaskHubGrpcWorker(host_address="localhost:50051") + + def long_running_orchestrator(ctx, _): + yield ctx.wait_for_external_event("never_happens") + return "completed" + + worker.add_orchestrator(long_running_orchestrator) + worker.start() + + try: + instance_id = client.schedule_new_orchestration(long_running_orchestrator) + + # Wait for it to start + client.wait_for_orchestration_start(instance_id, timeout=5) + + # Terminate it + client.terminate_orchestration(instance_id, output="terminated_by_test") + + # Verify termination + state = client.wait_for_orchestration_completion(instance_id, timeout=10) + + assert state.runtime_status == OrchestrationStatus.TERMINATED + finally: + worker.stop() +``` + +### Configuration Options + +The `InMemoryOrchestrationBackend` supports the following configuration options: + +- **port** (int): Port to listen on for gRPC connections (default: 50051) +- **max_history_size** (int): Maximum number of history events per orchestration (default: 10000) + +```python +backend = InMemoryOrchestrationBackend( + port=50051, + max_history_size=100000 # Support larger orchestrations +) 
+backend.start() +``` + +Or use the convenience factory, which starts the server automatically: + +```python +backend = create_test_backend(port=50051, max_history_size=10000) +``` + +### Thread Safety + +The in-memory backend is thread-safe and can be safely accessed from +multiple threads. All state mutations are protected by locks to ensure +consistency. + +### Limitations + +The in-memory backend is designed for testing and has some limitations compared to production backends: + +1. **No persistence**: All state is lost when the backend is stopped +2. **No distributed execution**: Runs in a single process +3. **No history streaming**: StreamInstanceHistory is not implemented +4. **No rewind**: RewindInstance is not implemented +5. **No recursive termination**: Recursive termination is not supported + +### Best Practices + +1. **Use fixtures**: Create pytest fixtures to manage backend lifecycle +2. **Reset between tests**: Call `backend.reset()` to clear state between tests +3. **Use random ports**: When running tests in parallel, use random port assignments +4. **Set appropriate timeouts**: Use reasonable timeout values in wait operations +5. 
**Clean up workers**: Always stop workers in finally blocks to prevent resource leaks + +### Troubleshooting + +#### Connection Errors + +If you see connection errors: + +- Ensure the backend is started before creating clients/workers +- Verify the port is not already in use +- Check that the host address matches the backend port + +#### Timeouts + +If tests time out: + +- Increase timeout values in `wait_for_orchestration_completion` +- Check that workers are started and processing work items +- Verify orchestrators and activities are registered correctly + +#### State Not Found + +If orchestration state is not found: + +- Ensure you're using the correct instance ID +- Verify the orchestration was successfully scheduled +- Check that the backend wasn't reset between operations diff --git a/durabletask/testing/__init__.py b/durabletask/testing/__init__.py new file mode 100644 index 0000000..891f09c --- /dev/null +++ b/durabletask/testing/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Testing utilities for the Durable Task Python SDK.""" + +from durabletask.testing.in_memory_backend import ( + InMemoryOrchestrationBackend, + create_test_backend, +) + +__all__ = [ + "InMemoryOrchestrationBackend", + "create_test_backend", +] diff --git a/durabletask/testing/in_memory_backend.py b/durabletask/testing/in_memory_backend.py new file mode 100644 index 0000000..b66aecf --- /dev/null +++ b/durabletask/testing/in_memory_backend.py @@ -0,0 +1,1556 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +In-memory backend for durable orchestrations suitable for testing. + +This backend stores all orchestration state in memory and processes +work items synchronously within the same process. It is designed for +unit testing and integration testing scenarios where a sidecar process +or external storage is not desired. 
+""" + +import logging +import threading +import time +import uuid +from collections import deque +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Callable, Optional + +import grpc +from concurrent import futures +from google.protobuf import empty_pb2, timestamp_pb2, wrappers_pb2 + +import durabletask.internal.orchestrator_service_pb2 as pb +import durabletask.internal.orchestrator_service_pb2_grpc as stubs +import durabletask.internal.helpers as helpers + + +@dataclass +class OrchestrationInstance: + """Internal orchestration instance state stored by the in-memory backend.""" + instance_id: str + name: str + status: pb.OrchestrationStatus + input: Optional[str] = None + output: Optional[str] = None + custom_status: Optional[str] = None + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + last_updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + failure_details: Optional[pb.TaskFailureDetails] = None + history: list[pb.HistoryEvent] = field(default_factory=list) + pending_events: list[pb.HistoryEvent] = field(default_factory=list) + dispatched_events: list[pb.HistoryEvent] = field(default_factory=list) + completion_token: int = 0 + tags: Optional[dict[str, str]] = None + + +@dataclass +class ActivityWorkItem: + """Activity work item that needs to be executed.""" + instance_id: str + name: str + task_id: int + input: Optional[str] + completion_token: int + + +@dataclass +class EntityState: + """Internal entity state stored by the in-memory backend.""" + instance_id: str + serialized_state: Optional[str] = None + last_modified_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + locked_by: Optional[str] = None + pending_operations: list[pb.HistoryEvent] = field(default_factory=list) + completion_token: int = 0 + + +@dataclass +class PendingLockRequest: + """Pending lock request from an orchestration.""" + critical_section_id: str + 
parent_instance_id: str + lock_set: list[str] + + +@dataclass +class EntityWorkItem: + """Entity work item that needs to be executed.""" + instance_id: str + entity_state: Optional[str] + operations: list[pb.HistoryEvent] + completion_token: int + + +@dataclass +class StateWaiter: + """Promise resolver for waiting on orchestration state changes.""" + predicate: Callable[[OrchestrationInstance], bool] + event: threading.Event = field(default_factory=threading.Event) + result: Optional[OrchestrationInstance] = None + + +class InMemoryOrchestrationBackend(stubs.TaskHubSidecarServiceServicer): + """ + In-memory backend for durable orchestrations suitable for testing. + + This backend stores all orchestration state in memory and processes + work items synchronously within the same process. It is designed for + unit testing and integration testing scenarios where a sidecar process + or external storage is not desired. + + Thread-safety: All state mutations are performed with locks to ensure + thread-safe operations. The backend uses queues to manage work items + for orchestrations and activities. + """ + + def __init__(self, max_history_size: int = 10000, port: int = 50051): + """ + Creates a new in-memory backend. 
+ + Args: + max_history_size: Maximum number of history events per orchestration (default 10000) + port: Port to listen on for gRPC connections (default 50051) + """ + self._lock = threading.RLock() + self._instances: dict[str, OrchestrationInstance] = {} + self._orchestration_queue: deque[str] = deque() + self._orchestration_queue_set: set[str] = set() + self._activity_queue: deque[ActivityWorkItem] = deque() + self._entities: dict[str, EntityState] = {} + self._entity_queue: deque[str] = deque() + self._entity_queue_set: set[str] = set() + self._entity_in_flight: set[str] = set() + self._pending_lock_requests: list[PendingLockRequest] = [] + self._orchestration_in_flight: set[str] = set() + self._state_waiters: dict[str, list[StateWaiter]] = {} + self._next_completion_token: int = 1 + self._max_history_size = max_history_size + self._port = port + self._server: Optional[grpc.Server] = None + self._logger = logging.getLogger(__name__) + self._shutdown_event = threading.Event() + self._work_available = threading.Event() + + def start(self) -> str: + """ + Starts the gRPC server on the configured port. + + Returns: + The address the server is listening on (e.g., "localhost:50051") + """ + self._shutdown_event.clear() + self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + stubs.add_TaskHubSidecarServiceServicer_to_server(self, self._server) + self._server.add_insecure_port(f'[::]:{self._port}') + self._server.start() + self._logger.info(f"In-memory backend started on port {self._port}") + return f"localhost:{self._port}" + + def stop(self, grace: Optional[float] = None): + """ + Stops the gRPC server. 
+ + Args: + grace: Grace period in seconds for graceful shutdown + """ + self._shutdown_event.set() + self._work_available.set() # Unblock GetWorkItems loops + if self._server: + stop_future = self._server.stop(grace) + stop_future.wait() + self._server = None + self._logger.info("In-memory backend stopped") + + def reset(self): + """Resets the backend, clearing all state.""" + with self._lock: + self._instances.clear() + self._orchestration_queue.clear() + self._orchestration_queue_set.clear() + self._activity_queue.clear() + self._entities.clear() + self._entity_queue.clear() + self._entity_queue_set.clear() + self._entity_in_flight.clear() + self._pending_lock_requests.clear() + self._orchestration_in_flight.clear() + for waiters in self._state_waiters.values(): + for waiter in waiters: + waiter.event.set() + self._state_waiters.clear() + self._shutdown_event.clear() + self._work_available.clear() + + # gRPC Service Methods + + def Hello(self, request, context): + """Sends a hello request to the sidecar service.""" + return empty_pb2.Empty() + + def StartInstance(self, request: pb.CreateInstanceRequest, context): + """Starts a new orchestration instance.""" + instance_id = request.instanceId if request.instanceId else uuid.uuid4().hex + + with self._lock: + if instance_id in self._instances: + existing = self._instances[instance_id] + policy = request.orchestrationIdReusePolicy + replaceable = list(policy.replaceableStatus) if policy else [] + + if replaceable: + # If the existing status is in the replaceable list, + # we can replace the instance + if existing.status in replaceable: + # Remove existing to allow re-creation + del self._instances[instance_id] + self._orchestration_queue_set.discard(instance_id) + else: + # Status not replaceable - reject + context.abort( + grpc.StatusCode.ALREADY_EXISTS, + f"Orchestration instance '{instance_id}' already exists " + f"with non-replaceable status") + return pb.CreateInstanceResponse(instanceId=instance_id) + else: + 
context.abort( + grpc.StatusCode.ALREADY_EXISTS, + f"Orchestration instance '{instance_id}' already exists") + return pb.CreateInstanceResponse(instanceId=instance_id) + + now = datetime.now(timezone.utc) + start_time = request.scheduledStartTimestamp.ToDatetime(tzinfo=timezone.utc) \ + if request.HasField("scheduledStartTimestamp") else now + + instance = OrchestrationInstance( + instance_id=instance_id, + name=request.name, + status=pb.ORCHESTRATION_STATUS_PENDING, + input=request.input.value if request.input else None, + created_at=now, + last_updated_at=now, + completion_token=self._next_completion_token, + tags=dict(request.tags) if request.tags else None, + ) + self._next_completion_token += 1 + + # Add initial events to start the orchestration + orchestrator_started = helpers.new_orchestrator_started_event(start_time) + execution_started = helpers.new_execution_started_event( + request.name, instance_id, + request.input.value if request.input else None, + dict(request.tags) if request.tags else None + ) + + instance.pending_events.append(orchestrator_started) + instance.pending_events.append(execution_started) + + self._instances[instance_id] = instance + self._enqueue_orchestration(instance_id) + + self._logger.info(f"Created orchestration instance '{instance_id}' for '{request.name}'") + + return pb.CreateInstanceResponse(instanceId=instance_id) + + def GetInstance(self, request: pb.GetInstanceRequest, context): + """Gets the status of an existing orchestration instance.""" + with self._lock: + instance = self._instances.get(request.instanceId) + if not instance: + return pb.GetInstanceResponse(exists=False) + + return self._build_instance_response(instance, request.getInputsAndOutputs) + + def WaitForInstanceStart(self, request: pb.GetInstanceRequest, context): + """Waits for an orchestration instance to reach a running or completion state.""" + def predicate(inst: OrchestrationInstance) -> bool: + return inst.status != pb.ORCHESTRATION_STATUS_PENDING + + 
instance = self._wait_for_state(request.instanceId, predicate, timeout=context.time_remaining()) + + if not instance: + return pb.GetInstanceResponse(exists=False) + + return self._build_instance_response(instance, request.getInputsAndOutputs) + + def WaitForInstanceCompletion(self, request: pb.GetInstanceRequest, context): + """Waits for an orchestration instance to reach a completion state.""" + instance = self._wait_for_state( + request.instanceId, + self._is_terminal_status_check, + timeout=context.time_remaining() + ) + + if not instance: + return pb.GetInstanceResponse(exists=False) + + return self._build_instance_response(instance, request.getInputsAndOutputs) + + def RaiseEvent(self, request: pb.RaiseEventRequest, context): + """Raises an event to a running orchestration instance.""" + with self._lock: + instance = self._instances.get(request.instanceId) + if not instance: + context.abort(grpc.StatusCode.NOT_FOUND, + f"Orchestration instance '{request.instanceId}' not found") + return pb.RaiseEventResponse() + + event = helpers.new_event_raised_event( + request.name, + request.input.value if request.input else None + ) + instance.pending_events.append(event) + instance.last_updated_at = datetime.now(timezone.utc) + self._enqueue_orchestration(instance.instance_id) + + self._logger.info(f"Raised event '{request.name}' for instance '{request.instanceId}'") + return pb.RaiseEventResponse() + + def TerminateInstance(self, request: pb.TerminateRequest, context): + """Terminates a running orchestration instance.""" + with self._lock: + self._terminate_instance_internal( + request.instanceId, + request.output.value if request.output else None, + request.recursive + ) + + return pb.TerminateResponse() + + def SuspendInstance(self, request: pb.SuspendRequest, context): + """Suspends a running orchestration instance.""" + with self._lock: + instance = self._instances.get(request.instanceId) + if not instance: + context.abort(grpc.StatusCode.NOT_FOUND, + 
f"Orchestration instance '{request.instanceId}' not found") + return pb.SuspendResponse() + + if instance.status == pb.ORCHESTRATION_STATUS_SUSPENDED: + return pb.SuspendResponse() + + event = helpers.new_suspend_event() + instance.pending_events.append(event) + instance.last_updated_at = datetime.now(timezone.utc) + self._enqueue_orchestration(instance.instance_id) + + self._logger.info(f"Suspended instance '{request.instanceId}'") + return pb.SuspendResponse() + + def ResumeInstance(self, request: pb.ResumeRequest, context): + """Resumes a suspended orchestration instance.""" + with self._lock: + instance = self._instances.get(request.instanceId) + if not instance: + context.abort(grpc.StatusCode.NOT_FOUND, + f"Orchestration instance '{request.instanceId}' not found") + return pb.ResumeResponse() + + event = helpers.new_resume_event() + instance.pending_events.append(event) + instance.last_updated_at = datetime.now(timezone.utc) + self._enqueue_orchestration(instance.instance_id) + + self._logger.info(f"Resumed instance '{request.instanceId}'") + return pb.ResumeResponse() + + def PurgeInstances(self, request: pb.PurgeInstancesRequest, context): + """Purges orchestration instances from the store.""" + purged_count = 0 + + with self._lock: + instance_id = request.instanceId + if instance_id: + # Single instance purge + instance = self._instances.get(instance_id) + if instance and self._is_terminal_status(instance.status): + del self._instances[instance_id] + self._state_waiters.pop(instance_id, None) + purged_count = 1 + elif request.HasField("purgeInstanceFilter"): + # Filter-based purge + pf = request.purgeInstanceFilter + to_purge = [] + for iid, inst in self._instances.items(): + if not self._is_terminal_status(inst.status): + continue + if pf.runtimeStatus and inst.status not in pf.runtimeStatus: + continue + if pf.HasField("createdTimeFrom") and inst.created_at < pf.createdTimeFrom.ToDatetime(timezone.utc): + continue + if pf.HasField("createdTimeTo") and 
inst.created_at >= pf.createdTimeTo.ToDatetime(timezone.utc): + continue + to_purge.append(iid) + for iid in to_purge: + del self._instances[iid] + self._state_waiters.pop(iid, None) + purged_count = len(to_purge) + + self._logger.info(f"Purged {purged_count} instance(s)") + return pb.PurgeInstancesResponse( + deletedInstanceCount=purged_count, + isComplete=wrappers_pb2.BoolValue(value=True), + ) + + def RestartInstance(self, request: pb.RestartInstanceRequest, context): + """Restarts a completed orchestration instance.""" + with self._lock: + instance = self._instances.get(request.instanceId) + if not instance: + context.abort( + grpc.StatusCode.NOT_FOUND, + f"Orchestration instance '{request.instanceId}' not found") + return pb.RestartInstanceResponse() + + if not self._is_terminal_status(instance.status): + context.abort( + grpc.StatusCode.FAILED_PRECONDITION, + f"Orchestration instance '{request.instanceId}' is not in a terminal state") + return pb.RestartInstanceResponse() + + name = instance.name + original_input = instance.input + + if request.restartWithNewInstanceId: + new_instance_id = uuid.uuid4().hex + else: + new_instance_id = request.instanceId + # Remove the old instance so we can recreate it + del self._instances[request.instanceId] + self._orchestration_queue_set.discard(request.instanceId) + self._state_waiters.pop(request.instanceId, None) + + self._create_instance_internal(new_instance_id, name, original_input) + + self._logger.info( + f"Restarted instance '{request.instanceId}' as '{new_instance_id}'") + return pb.RestartInstanceResponse(instanceId=new_instance_id) + + def GetWorkItems(self, request: pb.GetWorkItemsRequest, context): + """Streams work items to the worker (orchestration and activity work items).""" + self._logger.info("Worker connected and requesting work items") + + try: + while context.is_active() and not self._shutdown_event.is_set(): + work_item = None + + with self._lock: + # Check for orchestration work + while 
self._orchestration_queue: + instance_id = self._orchestration_queue.popleft() + self._orchestration_queue_set.discard(instance_id) + instance = self._instances.get(instance_id) + + if not instance or not instance.pending_events: + continue + + if instance_id in self._orchestration_in_flight: + # Already being processed — re-add to queue + if instance_id not in self._orchestration_queue_set: + self._orchestration_queue.append(instance_id) + self._orchestration_queue_set.add(instance_id) + break + + # Move pending events to dispatched_events + instance.dispatched_events = list(instance.pending_events) + instance.pending_events.clear() + + # Add OrchestratorStarted for re-dispatches so that + # ctx.current_utc_datetime advances correctly + if instance.history: + now = datetime.now(timezone.utc) + orch_started = helpers.new_orchestrator_started_event(now) + instance.dispatched_events.insert(0, orch_started) + + self._orchestration_in_flight.add(instance_id) + + # Create orchestrator work item + work_item = pb.WorkItem( + completionToken=str(instance.completion_token), + orchestratorRequest=pb.OrchestratorRequest( + instanceId=instance.instance_id, + pastEvents=list(instance.history), + newEvents=list(instance.dispatched_events), + ) + ) + break + + # Check for activity work + if not work_item and self._activity_queue: + activity = self._activity_queue.popleft() + work_item = pb.WorkItem( + completionToken=str(activity.completion_token), + activityRequest=pb.ActivityRequest( + name=activity.name, + taskId=activity.task_id, + input=wrappers_pb2.StringValue(value=activity.input) if activity.input else None, + orchestrationInstance=pb.OrchestrationInstance(instanceId=activity.instance_id) + ) + ) + + # Check for entity work + if not work_item: + while self._entity_queue: + entity_id = self._entity_queue.popleft() + self._entity_queue_set.discard(entity_id) + entity = self._entities.get(entity_id) + + if entity and entity.pending_operations: + # Skip if this entity is 
already being processed + if entity_id in self._entity_in_flight: + continue + + # Mark as in-flight to prevent duplicate dispatch + self._entity_in_flight.add(entity_id) + + # Drain all pending operations into a batch + operations = list(entity.pending_operations) + entity.pending_operations.clear() + + # Use V2 EntityRequest format so the worker + # can properly build operation_infos + work_item = pb.WorkItem( + completionToken=str(entity.completion_token), + entityRequestV2=pb.EntityRequest( + instanceId=entity.instance_id, + entityState=wrappers_pb2.StringValue( + value=entity.serialized_state + ) if entity.serialized_state else None, + operationRequests=operations, + ), + ) + break + + if work_item: + yield work_item + else: + # Wait for work to become available (with timeout for shutdown checks) + self._work_available.wait(timeout=0.1) + self._work_available.clear() + + except Exception: + self._logger.exception("Error in GetWorkItems stream") + + def CompleteOrchestratorTask(self, request: pb.OrchestratorResponse, context): + """Completes an orchestration execution with the given actions.""" + with self._lock: + instance = self._instances.get(request.instanceId) + if not instance: + self._logger.warning(f"Instance '{request.instanceId}' not found for completion") + self._orchestration_in_flight.discard(request.instanceId) + return pb.CompleteTaskResponse() + + if str(instance.completion_token) != request.completionToken: + self._logger.warning( + f"Stale completion for instance '{request.instanceId}' - ignoring" + ) + self._orchestration_in_flight.discard(request.instanceId) + return pb.CompleteTaskResponse() + + # Check history size limit + projected_size = len(instance.history) + len(instance.dispatched_events) + if projected_size > self._max_history_size: + self._orchestration_in_flight.discard(request.instanceId) + context.abort( + grpc.StatusCode.RESOURCE_EXHAUSTED, + f"Orchestration '{request.instanceId}' would exceed maximum history size" + ) + return 
def CompleteActivityTask(self, request: pb.ActivityResponse, context):
    """Record the outcome of an activity and wake the owning orchestration.

    The result is turned into a ``TaskCompleted`` or ``TaskFailed`` history
    event, appended to the orchestration's pending events, and the
    orchestration is re-enqueued so a worker replays it.

    Args:
        request: Activity response; ``instanceId`` names the owning
            orchestration and ``taskId`` the scheduled task being completed.
        context: gRPC servicer context (unused).

    Returns:
        An empty ``CompleteTaskResponse``. Responses for unknown instances
        are logged and dropped.
    """
    with self._lock:
        instance = self._instances.get(request.instanceId)
        if not instance:
            self._logger.warning(f"Instance '{request.instanceId}' not found for activity completion")
            return pb.CompleteTaskResponse()

        # Decide on field *presence*. Truthiness of the sub-message plus a
        # non-empty errorMessage would misreport an activity that failed with
        # an empty message (e.g. a bare ``raise Exception()``) as a success.
        if request.HasField("failureDetails"):
            event = pb.HistoryEvent(
                eventId=-1,
                timestamp=timestamp_pb2.Timestamp(),
                taskFailed=pb.TaskFailedEvent(
                    taskScheduledId=request.taskId,
                    failureDetails=request.failureDetails,
                ),
            )
        else:
            event = pb.HistoryEvent(
                eventId=-1,
                timestamp=timestamp_pb2.Timestamp(),
                taskCompleted=pb.TaskCompletedEvent(
                    taskScheduledId=request.taskId,
                    result=request.result,
                ),
            )

        instance.pending_events.append(event)
        instance.last_updated_at = datetime.now(timezone.utc)
        self._enqueue_orchestration(request.instanceId)

    return pb.CompleteTaskResponse()
+ timestamp=timestamp_pb2.Timestamp(), + entityOperationCompleted=pb.EntityOperationCompletedEvent( + requestId=op_info.requestId, + output=result.success.result, + ) + ) + elif result and result.HasField("failure"): + event = pb.HistoryEvent( + eventId=-1, + timestamp=timestamp_pb2.Timestamp(), + entityOperationFailed=pb.EntityOperationFailedEvent( + requestId=op_info.requestId, + failureDetails=result.failure.failureDetails, + ) + ) + else: + continue + + parent_instance.pending_events.append(event) + parent_instance.last_updated_at = datetime.now(timezone.utc) + self._enqueue_orchestration(parent_instance_id) + + # Process side-effect actions (signals to other entities, new orchestrations) + for action in request.actions: + if action.HasField("sendSignal"): + signal = action.sendSignal + self._signal_entity_internal( + signal.instanceId, signal.name, + signal.input.value if signal.input else None + ) + elif action.HasField("startNewOrchestration"): + start_orch = action.startNewOrchestration + orch_input = start_orch.input.value if start_orch.input else None + instance_id = start_orch.instanceId or uuid.uuid4().hex + try: + self._create_instance_internal( + instance_id, start_orch.name, orch_input + ) + except Exception: + self._logger.warning( + f"Failed to create orchestration '{instance_id}' from entity action" + ) + + # If the entity has more pending operations, re-enqueue + if entity.pending_operations: + self._enqueue_entity(entity.instance_id) + + return pb.CompleteTaskResponse() + + def SignalEntity(self, request: pb.SignalEntityRequest, context): + """Signals an entity, queueing an operation for processing.""" + with self._lock: + entity_id = request.instanceId + entity = self._entities.get(entity_id) + if not entity: + entity = EntityState( + instance_id=entity_id, + completion_token=self._next_completion_token, + ) + self._next_completion_token += 1 + self._entities[entity_id] = entity + + # Create a signaled operation event + event = pb.HistoryEvent( 
def GetEntity(self, request: pb.GetEntityRequest, context):
    """Return metadata for one entity, or ``exists=False`` if it is unknown.

    The serialized state is only attached when the caller set
    ``includeState`` and the entity actually has state; the lock holder is
    only attached when the entity is locked.
    """
    with self._lock:
        found = self._entities.get(request.instanceId)
        if not found:
            return pb.GetEntityResponse(exists=False)

        modified_at = timestamp_pb2.Timestamp()
        modified_at.FromDatetime(found.last_modified_at)

        lock_holder = None
        if found.locked_by:
            lock_holder = wrappers_pb2.StringValue(value=found.locked_by)

        state_payload = None
        if request.includeState and found.serialized_state:
            state_payload = wrappers_pb2.StringValue(value=found.serialized_state)

        return pb.GetEntityResponse(
            exists=True,
            entity=pb.EntityMetadata(
                instanceId=found.instance_id,
                lastModifiedTime=modified_at,
                backlogQueueSize=len(found.pending_operations),
                lockedBy=lock_holder,
                serializedState=state_payload,
            ),
        )
query.HasField("createdTimeTo") and instance.created_at >= query.createdTimeTo.ToDatetime(timezone.utc): + continue + # Filter by instance ID prefix + if query.HasField("instanceIdPrefix") and query.instanceIdPrefix.value: + if not instance.instance_id.startswith(query.instanceIdPrefix.value): + continue + matching.append(instance) + + # Sort by created time for deterministic pagination + matching.sort(key=lambda i: i.created_at) + + # Apply pagination + page_size = query.maxInstanceCount if query.maxInstanceCount > 0 else len(matching) + page = matching[start_index:start_index + page_size] + + states = [] + for inst in page: + created_ts = timestamp_pb2.Timestamp() + created_ts.FromDatetime(inst.created_at) + updated_ts = timestamp_pb2.Timestamp() + updated_ts.FromDatetime(inst.last_updated_at) + + include = query.fetchInputsAndOutputs + state = pb.OrchestrationState( + instanceId=inst.instance_id, + name=inst.name, + orchestrationStatus=inst.status, + createdTimestamp=created_ts, + lastUpdatedTimestamp=updated_ts, + input=wrappers_pb2.StringValue(value=inst.input) if include and inst.input else None, + output=wrappers_pb2.StringValue(value=inst.output) if include and inst.output else None, + customStatus=wrappers_pb2.StringValue( + value=inst.custom_status) if inst.custom_status else None, + failureDetails=inst.failure_details if inst.failure_details else None, + ) + states.append(state) + + # Compute continuation token + next_index = start_index + page_size + continuation_token = None + if next_index < len(matching): + continuation_token = wrappers_pb2.StringValue(value=str(next_index)) + + return pb.QueryInstancesResponse( + orchestrationState=states, + continuationToken=continuation_token, + ) + + def QueryEntities(self, request: pb.QueryEntitiesRequest, context): + """Query entities with filtering support.""" + with self._lock: + query = request.query + start_index = 0 + if query.HasField("continuationToken") and query.continuationToken.value: + try: + 
def CleanEntityStorage(self, request: pb.CleanEntityStorageRequest, context):
    """Remove empty entities and release orphaned entity locks.

    Args:
        request: ``removeEmptyEntities`` and ``releaseOrphanedLocks`` select
            which of the two clean-up passes run.
        context: gRPC servicer context (unused).

    Returns:
        A ``CleanEntityStorageResponse`` with the number of entities removed
        and the number of locks released.
    """
    empty_removed = 0
    locks_released = 0

    with self._lock:
        if request.removeEmptyEntities:
            # An entity may only be dropped when nothing still refers to it:
            # no state, no queued operations, no lock holder, and no batch
            # currently out with a worker. Deleting an in-flight entity would
            # make CompleteEntityTask unable to find it by completion token,
            # and deleting a locked one would silently discard the lock.
            to_remove = [
                eid for eid, ent in self._entities.items()
                if ent.serialized_state is None
                and not ent.pending_operations
                and not ent.locked_by
                and eid not in self._entity_in_flight
            ]
            for eid in to_remove:
                del self._entities[eid]
                self._entity_queue_set.discard(eid)
            empty_removed = len(to_remove)

        if request.releaseOrphanedLocks:
            # NOTE(review): _grant_lock stores the request's critical section
            # id in locked_by, while this test looks it up among orchestration
            # instance ids - confirm both share the same namespace.
            for ent in self._entities.values():
                if ent.locked_by and ent.locked_by not in self._instances:
                    ent.locked_by = None
                    locks_released += 1

    return pb.CleanEntityStorageResponse(
        emptyEntitiesRemoved=empty_removed,
        orphanedLocksReleased=locks_released,
    )
context): + """Abandons an entity work item.""" + return pb.AbandonEntityTaskResponse() + + # Internal helper methods + + def _enqueue_orchestration(self, instance_id: str): + """Enqueues an orchestration for processing.""" + if instance_id not in self._orchestration_queue_set: + self._orchestration_queue.append(instance_id) + self._orchestration_queue_set.add(instance_id) + self._work_available.set() + + def _is_terminal_status(self, status: pb.OrchestrationStatus) -> bool: + """Checks if a status is terminal.""" + return status in ( + pb.ORCHESTRATION_STATUS_COMPLETED, + pb.ORCHESTRATION_STATUS_FAILED, + pb.ORCHESTRATION_STATUS_TERMINATED + ) + + def _is_terminal_status_check(self, instance: OrchestrationInstance) -> bool: + """Predicate to check if instance is in terminal status.""" + return self._is_terminal_status(instance.status) + + def _create_instance_internal(self, instance_id: str, name: str, + encoded_input: Optional[str] = None): + """Creates a new instance directly in internal state (no gRPC context needed).""" + existing = self._instances.get(instance_id) + if existing: + if self._is_terminal_status(existing.status): + # Allow recreation of terminated instances (e.g., retry) + del self._instances[instance_id] + self._orchestration_queue_set.discard(instance_id) + else: + raise ValueError(f"Orchestration instance '{instance_id}' already exists") + + now = datetime.now(timezone.utc) + instance = OrchestrationInstance( + instance_id=instance_id, + name=name, + status=pb.ORCHESTRATION_STATUS_PENDING, + input=encoded_input, + created_at=now, + last_updated_at=now, + completion_token=self._next_completion_token, + ) + self._next_completion_token += 1 + + orchestrator_started = helpers.new_orchestrator_started_event(now) + execution_started = helpers.new_execution_started_event(name, instance_id, encoded_input) + instance.pending_events.append(orchestrator_started) + instance.pending_events.append(execution_started) + + self._instances[instance_id] = 
instance + self._enqueue_orchestration(instance_id) + + def _raise_event_internal(self, instance_id: str, event_name: str, + event_data: Optional[str] = None): + """Raises an event directly in internal state (no gRPC context needed).""" + instance = self._instances.get(instance_id) + if not instance: + raise ValueError(f"Orchestration instance '{instance_id}' not found") + + event = helpers.new_event_raised_event(event_name, event_data) + instance.pending_events.append(event) + instance.last_updated_at = datetime.now(timezone.utc) + self._enqueue_orchestration(instance.instance_id) + + def _terminate_instance_internal(self, instance_id: str, output: Optional[str], + recursive: bool = False): + """Internal method to terminate an instance.""" + if recursive: + self._logger.warning( + "Recursive termination is not supported in the in-memory backend") + + instance = self._instances.get(instance_id) + if not instance: + return + + if self._is_terminal_status(instance.status): + return # Already terminated + + event = helpers.new_terminated_event(encoded_output=output) + instance.pending_events.append(event) + instance.last_updated_at = datetime.now(timezone.utc) + self._enqueue_orchestration(instance.instance_id) + + self._logger.info(f"Terminated instance '{instance_id}'") + + def _build_instance_response(self, instance: OrchestrationInstance, + include_payloads: bool) -> pb.GetInstanceResponse: + """Builds a GetInstanceResponse from an instance.""" + created_ts = timestamp_pb2.Timestamp() + created_ts.FromDatetime(instance.created_at) + + updated_ts = timestamp_pb2.Timestamp() + updated_ts.FromDatetime(instance.last_updated_at) + + state = pb.OrchestrationState( + instanceId=instance.instance_id, + name=instance.name, + orchestrationStatus=instance.status, + createdTimestamp=created_ts, + lastUpdatedTimestamp=updated_ts, + input=wrappers_pb2.StringValue(value=instance.input) if include_payloads and instance.input else None, + 
output=wrappers_pb2.StringValue(value=instance.output) if include_payloads and instance.output else None, + customStatus=wrappers_pb2.StringValue(value=instance.custom_status) if instance.custom_status else None, + failureDetails=instance.failure_details if instance.failure_details else None, + ) + + return pb.GetInstanceResponse(exists=True, orchestrationState=state) + + def _wait_for_state(self, instance_id: str, + predicate: Callable[[OrchestrationInstance], bool], + timeout: Optional[float]) -> Optional[OrchestrationInstance]: + """Waits for an orchestration to reach a state matching the predicate.""" + with self._lock: + instance = self._instances.get(instance_id) + if instance and predicate(instance): + return instance + + waiter = StateWaiter(predicate=predicate) + if instance_id not in self._state_waiters: + self._state_waiters[instance_id] = [] + self._state_waiters[instance_id].append(waiter) + + # Wait outside the lock + wait_result = waiter.event.wait(timeout=timeout if timeout else 30.0) + + if wait_result: + return waiter.result + else: + # Timeout - remove waiter + with self._lock: + waiters = self._state_waiters.get(instance_id) + if waiters and waiter in waiters: + waiters.remove(waiter) + if not waiters: + self._state_waiters.pop(instance_id, None) + return None + + def _notify_waiters(self, instance_id: str): + """Notifies all waiters for an instance.""" + instance = self._instances.get(instance_id) + waiters = self._state_waiters.get(instance_id) + + if not waiters or not instance: + return + + # Find and notify matching waiters + matching_waiters = [w for w in waiters if w.predicate(instance)] + for waiter in matching_waiters: + waiter.result = instance + waiter.event.set() + + # Remove notified waiters + remaining = [w for w in waiters if w not in matching_waiters] + if remaining: + self._state_waiters[instance_id] = remaining + else: + self._state_waiters.pop(instance_id, None) + + def _process_action(self, instance: OrchestrationInstance, 
action: pb.OrchestratorAction): + """Processes an orchestrator action.""" + if action.HasField("completeOrchestration"): + self._process_complete_orchestration_action(instance, action.completeOrchestration) + elif action.HasField("scheduleTask"): + self._process_schedule_task_action(instance, action) + elif action.HasField("createTimer"): + self._process_create_timer_action(instance, action) + elif action.HasField("createSubOrchestration"): + self._process_create_sub_orchestration_action(instance, action) + elif action.HasField("sendEvent"): + self._process_send_event_action(action.sendEvent) + elif action.HasField("sendEntityMessage"): + self._process_send_entity_message_action(instance, action) + + def _process_complete_orchestration_action(self, instance: OrchestrationInstance, + complete_action: pb.CompleteOrchestrationAction): + """Processes a complete orchestration action.""" + status = complete_action.orchestrationStatus + instance.status = status + instance.output = complete_action.result.value if complete_action.result else None + instance.failure_details = complete_action.failureDetails if complete_action.failureDetails else None + + if status == pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW: + # Handle continue-as-new + new_input = complete_action.result.value if complete_action.result else None + carryover_events = list(complete_action.carryoverEvents) + + # Reset instance state + instance.history.clear() + instance.input = new_input + instance.output = None + instance.failure_details = None + instance.status = pb.ORCHESTRATION_STATUS_PENDING + + # Save any events that arrived during the in-flight dispatch so + # they can be appended AFTER the new execution started events. 
def _process_create_timer_action(self, instance: OrchestrationInstance,
                                 action: pb.OrchestratorAction):
    """Record a timer in history and arrange for it to fire later.

    A TimerCreated event is appended to the instance history, a pending
    instance is moved to RUNNING, and a daemon thread sleeps until the fire
    time before queueing the TimerFired event.
    """
    timer_id = action.id
    due = action.createTimer.fireAt.ToDatetime(tzinfo=timezone.utc)

    instance.history.append(helpers.new_timer_created_event(timer_id, due))

    if instance.status == pb.ORCHESTRATION_STATUS_PENDING:
        instance.status = pb.ORCHESTRATION_STATUS_RUNNING

    # Timers that are already due fire without sleeping.
    seconds_left = (due - datetime.now(timezone.utc)).total_seconds()
    delay = max(0, seconds_left)

    def _fire_when_due():
        time.sleep(delay)
        with self._lock:
            live = self._instances.get(instance.instance_id)
            # Skip instances that disappeared or finished while sleeping.
            if not live or self._is_terminal_status(live.status):
                return
            live.pending_events.append(helpers.new_timer_fired_event(timer_id, due))
            live.last_updated_at = datetime.now(timezone.utc)
            self._enqueue_orchestration(instance.instance_id)

    threading.Thread(target=_fire_when_due, daemon=True).start()
self._enqueue_orchestration(instance.instance_id) + + def _watch_sub_orchestration(self, parent_instance_id: str, sub_instance_id: str, task_id: int): + """Watches a sub-orchestration for completion and delivers the result to the parent.""" + def watch(): + # Wait for sub-orchestration to complete + sub_instance = self._wait_for_state( + sub_instance_id, + self._is_terminal_status_check, + timeout=None # No timeout + ) + + with self._lock: + parent_instance = self._instances.get(parent_instance_id) + + if not sub_instance or not parent_instance: + return + + # If parent already terminated, don't deliver the completion event + if self._is_terminal_status(parent_instance.status): + return + + # Deliver the sub-orchestration completion/failure event to parent + if sub_instance.status == pb.ORCHESTRATION_STATUS_COMPLETED: + event = helpers.new_sub_orchestration_completed_event(task_id, sub_instance.output) + else: + error_msg = (sub_instance.failure_details.errorMessage + if sub_instance.failure_details else "Sub-orchestration failed") + event = helpers.new_sub_orchestration_failed_event(task_id, Exception(error_msg)) + + parent_instance.pending_events.append(event) + parent_instance.last_updated_at = datetime.now(timezone.utc) + self._enqueue_orchestration(parent_instance_id) + + watcher_thread = threading.Thread(target=watch, daemon=True) + watcher_thread.start() + + def _process_send_event_action(self, send_event: pb.SendEventAction): + """Processes a send event action.""" + target_instance_id = send_event.instance.instanceId if send_event.instance else None + event_name = send_event.name + event_data = send_event.data.value if send_event.data else None + + if target_instance_id: + try: + self._raise_event_internal(target_instance_id, event_name, event_data) + except Exception: + # Target instance may not exist - ignore + pass + + def _process_send_entity_message_action(self, instance: OrchestrationInstance, + action: pb.OrchestratorAction): + """Processes a send 
entity message action from an orchestrator.""" + msg = action.sendEntityMessage + action_id = action.id + + if msg.HasField("entityOperationSignaled"): + signaled = msg.entityOperationSignaled + target_id = signaled.targetInstanceId.value if signaled.targetInstanceId else None + + # Add confirmation event to orchestration history + history_event = pb.HistoryEvent( + eventId=action_id, + timestamp=timestamp_pb2.Timestamp(), + entityOperationSignaled=signaled, + ) + instance.history.append(history_event) + + if target_id: + self._queue_entity_operation(target_id, pb.HistoryEvent( + eventId=-1, + timestamp=timestamp_pb2.Timestamp(), + entityOperationSignaled=signaled, + )) + + elif msg.HasField("entityOperationCalled"): + called = msg.entityOperationCalled + target_id = called.targetInstanceId.value if called.targetInstanceId else None + + # Add confirmation event to orchestration history + history_event = pb.HistoryEvent( + eventId=action_id, + timestamp=timestamp_pb2.Timestamp(), + entityOperationCalled=called, + ) + instance.history.append(history_event) + + # Mark instance as running + if instance.status == pb.ORCHESTRATION_STATUS_PENDING: + instance.status = pb.ORCHESTRATION_STATUS_RUNNING + + if target_id: + self._queue_entity_operation(target_id, pb.HistoryEvent( + eventId=-1, + timestamp=timestamp_pb2.Timestamp(), + entityOperationCalled=called, + )) + + elif msg.HasField("entityLockRequested"): + lock_req = msg.entityLockRequested + parent_id = lock_req.parentInstanceId.value if lock_req.parentInstanceId else None + + # Add confirmation event to orchestration history + history_event = pb.HistoryEvent( + eventId=action_id, + timestamp=timestamp_pb2.Timestamp(), + entityLockRequested=lock_req, + ) + instance.history.append(history_event) + + # Mark instance as running + if instance.status == pb.ORCHESTRATION_STATUS_PENDING: + instance.status = pb.ORCHESTRATION_STATUS_RUNNING + + if parent_id: + lock_set = list(lock_req.lockSet) + pending = PendingLockRequest( + 
critical_section_id=lock_req.criticalSectionId, + parent_instance_id=parent_id, + lock_set=lock_set, + ) + self._try_grant_lock(pending) + + elif msg.HasField("entityUnlockSent"): + unlock = msg.entityUnlockSent + target_id = unlock.targetInstanceId.value if unlock.targetInstanceId else None + + # Add confirmation event to orchestration history + history_event = pb.HistoryEvent( + eventId=action_id, + timestamp=timestamp_pb2.Timestamp(), + entityUnlockSent=unlock, + ) + instance.history.append(history_event) + + if target_id: + entity = self._entities.get(target_id) + if entity and entity.locked_by == unlock.criticalSectionId: + entity.locked_by = None + + # Try to grant any pending lock requests + self._try_grant_pending_locks() + + def _can_grant_lock(self, pending: PendingLockRequest) -> bool: + """Checks if all entities in the lock set are available for locking.""" + for entity_id in pending.lock_set: + entity = self._entities.get(entity_id) + if entity and entity.locked_by is not None: + return False + return True + + def _grant_lock(self, pending: PendingLockRequest): + """Grants a lock to entities and notifies the parent orchestration. + + Assumes lock availability has already been verified via _can_grant_lock. 
+ """ + for entity_id in pending.lock_set: + entity = self._entities.get(entity_id) + if not entity: + entity = EntityState( + instance_id=entity_id, + completion_token=self._next_completion_token, + ) + self._next_completion_token += 1 + self._entities[entity_id] = entity + entity.locked_by = pending.critical_section_id + + parent = self._instances.get(pending.parent_instance_id) + if parent: + grant_event = pb.HistoryEvent( + eventId=-1, + timestamp=timestamp_pb2.Timestamp(), + entityLockGranted=pb.EntityLockGrantedEvent( + criticalSectionId=pending.critical_section_id, + ), + ) + parent.pending_events.append(grant_event) + parent.last_updated_at = datetime.now(timezone.utc) + self._enqueue_orchestration(pending.parent_instance_id) + + def _try_grant_lock(self, pending: PendingLockRequest) -> bool: + """Tries to grant a lock request. Returns True if granted, False if queued.""" + if not self._can_grant_lock(pending): + self._pending_lock_requests.append(pending) + return False + self._grant_lock(pending) + return True + + def _try_grant_pending_locks(self): + """Attempts to grant any pending lock requests that can now be fulfilled.""" + still_pending = [] + for pending in self._pending_lock_requests: + if self._can_grant_lock(pending): + self._grant_lock(pending) + else: + still_pending.append(pending) + self._pending_lock_requests = still_pending + + def _queue_entity_operation(self, entity_id: str, event: pb.HistoryEvent): + """Queues an operation event for an entity.""" + entity = self._entities.get(entity_id) + if not entity: + entity = EntityState( + instance_id=entity_id, + completion_token=self._next_completion_token, + ) + self._next_completion_token += 1 + self._entities[entity_id] = entity + + entity.pending_operations.append(event) + self._enqueue_entity(entity_id) + + def _signal_entity_internal(self, entity_id: str, operation: str, + input_value: Optional[str] = None): + """Internal method to signal an entity (from entity side-effect actions).""" + 
def create_test_backend(port: int = 50051, max_history_size: int = 10000) -> InMemoryOrchestrationBackend:
    """
    Factory function to create and start an in-memory backend for testing.

    The backend is constructed and ``start()`` is called on it before it is
    returned, so the caller is responsible for calling ``stop()`` when done.

    Args:
        port: Port to listen on for gRPC connections (default 50051)
        max_history_size: Maximum number of history events per orchestration (default 10000)

    Returns:
        A started InMemoryOrchestrationBackend instance

    Example:
        ```python
        import pytest
        from durabletask.testing.in_memory_backend import create_test_backend
        from durabletask.client import OrchestrationStatus, TaskHubGrpcClient
        from durabletask.worker import TaskHubGrpcWorker

        @pytest.fixture
        def backend():
            backend = create_test_backend(port=50051)
            yield backend
            backend.stop()
            backend.reset()

        def my_activity(ctx, input: str):
            return f"processed: {input}"

        def my_orchestrator(ctx, _):
            result = yield ctx.call_activity(my_activity, input="hello")
            return result

        def test_orchestration(backend):
            # Create client and worker connected to the test backend
            client = TaskHubGrpcClient(host_address="localhost:50051")
            worker = TaskHubGrpcWorker(host_address="localhost:50051")

            # Register orchestrators and activities
            worker.add_orchestrator(my_orchestrator)
            worker.add_activity(my_activity)

            # Start the worker
            worker.start()

            try:
                # Schedule and wait for orchestration
                instance_id = client.schedule_new_orchestration(my_orchestrator, input=None)
                state = client.wait_for_orchestration_completion(instance_id, timeout=10)

                assert state is not None
                assert state.runtime_status == OrchestrationStatus.COMPLETED
                # Add more assertions...
            finally:
                worker.stop()
        ```
    """
    backend = InMemoryOrchestrationBackend(max_history_size=max_history_size, port=port)
    backend.start()
    return backend
that adds punctuation.""" + return f"{text}!" + + # Register orchestrators and activities + worker.add_orchestrator(greet_orchestrator) + worker.add_activity(get_greeting) + worker.add_activity(add_punctuation) + + # Start the worker + print("Starting worker...") + worker.start() + + try: + # Schedule an orchestration + print("\nScheduling orchestration...") + instance_id = client.schedule_new_orchestration( + greet_orchestrator, + input="World" + ) + print(f"Orchestration scheduled with ID: {instance_id}") + + # Wait for completion + print("Waiting for orchestration to complete...") + state = client.wait_for_orchestration_completion(instance_id, timeout=10) + + # Display results + print("\n" + "=" * 50) + print("Orchestration Results:") + print("=" * 50) + if state is None: + print("\n✗ Orchestration state is None (timed out?)") + else: + print(f"Instance ID: {state.instance_id}") + print(f"Name: {state.name}") + print(f"Status: {state.runtime_status}") + print(f"Output: {state.serialized_output}") + print(f"Created At: {state.created_at}") + print(f"Last Updated: {state.last_updated_at}") + print("=" * 50) + + # Verify the result + if state.runtime_status == OrchestrationStatus.COMPLETED: + print("\n✓ Orchestration completed successfully!") + else: + print(f"\n✗ Orchestration did not complete successfully: {state.runtime_status}") + + finally: + print("\nStopping worker...") + worker.stop() + + finally: + print("Stopping backend...") + backend.stop() + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index ec8a511..be5d8dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,3 @@ include = ["durabletask", "durabletask.*"] [tool.pytest.ini_options] minversion = "6.0" testpaths = ["tests"] -markers = [ - "e2e: mark a test as an end-to-end test that requires a running sidecar" -] diff --git a/tests/durabletask-azuremanaged/entities/test_dts_function_based_entities_e2e.py 
def test_get_entity_not_found():
    """Looking up an entity that was never created must yield ``None``.

    Queries the scheduler for the ``counter`` entity keyed ``nonexistent``
    (which this test never signals) with ``include_state=True``.
    """
    missing_id = entities.EntityInstanceId("counter", "nonexistent")
    scheduler_client = DurableTaskSchedulerClient(
        host_address=endpoint,
        secure_channel=True,
        taskhub=taskhub_name,
        token_credential=None,
    )

    assert scheduler_client.get_entity(missing_id, include_state=True) is None
def test_client_signal_class_entity_and_custom_name():
    """Test signaling a class-based entity with a custom registration name from the client.

    ``EmptyEntity`` is registered under the name ``EntityNameCustom``; the
    client signals its ``do_nothing`` operation and the test passes once the
    handler has flipped the ``invoked`` flag.
    """
    invoked = False

    class EmptyEntity(entities.DurableEntity):
        def do_nothing(self, _):
            nonlocal invoked  # don't do this in a real app!
            invoked = True

    with worker.TaskHubGrpcWorker(host_address=HOST) as w:
        w.add_entity(EmptyEntity, name="EntityNameCustom")
        w.start()

        c = client.TaskHubGrpcClient(host_address=HOST)
        entity_id = entities.EntityInstanceId("EntityNameCustom", "testEntity")
        c.signal_entity(entity_id, "do_nothing")

        # A signal returns no result to wait on, so poll for its side effect
        # with an upper bound instead of sleeping for a fixed interval: the
        # test finishes as soon as the handler ran and tolerates slow hosts.
        deadline = time.monotonic() + 10
        while not invoked and time.monotonic() < deadline:
            time.sleep(0.05)

    assert invoked
def test_orchestration_signal_class_entity_and_custom_name():
    """Test signaling a class-based entity with a custom name from an orchestration.

    The orchestrator signals ``do_nothing`` on the entity registered as
    ``EntityNameCustom`` and returns without yielding; the test then checks
    both the orchestration's final state and that the entity handler ran.
    """
    invoked = False

    class EmptyEntity(entities.DurableEntity):
        def do_nothing(self, _):
            nonlocal invoked  # don't do this in a real app!
            invoked = True

    def empty_orchestrator(ctx: task.OrchestrationContext, _):
        entity_id = entities.EntityInstanceId("EntityNameCustom", "testEntity")
        ctx.signal_entity(entity_id, "do_nothing")

    with worker.TaskHubGrpcWorker(host_address=HOST) as w:
        w.add_orchestrator(empty_orchestrator)
        w.add_entity(EmptyEntity, name="EntityNameCustom")
        w.start()

        c = client.TaskHubGrpcClient(host_address=HOST)
        instance_id = c.schedule_new_orchestration(empty_orchestrator)
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)

        # The orchestration does not wait for the signal it sent, so poll
        # (bounded) for the handler's side effect rather than sleeping for a
        # fixed interval.
        deadline = time.monotonic() + 10
        while not invoked and time.monotonic() < deadline:
            time.sleep(0.05)

    assert invoked
    assert state is not None
    assert state.name == task.get_name(empty_orchestrator)
    assert state.instance_id == instance_id
    assert state.failure_details is None
    assert state.runtime_status == client.OrchestrationStatus.COMPLETED
    assert state.serialized_input is None
    assert state.serialized_output is None
    assert state.serialized_custom_status is None
invoked # don't do this in a real app! + invoked = True + + def empty_orchestrator(ctx: task.OrchestrationContext, _): + entity_id = entities.EntityInstanceId("EmptyEntity", "testEntity") + yield ctx.call_entity(entity_id, "do_nothing") + + with worker.TaskHubGrpcWorker(host_address=HOST) as w: + w.add_orchestrator(empty_orchestrator) + w.add_entity(EmptyEntity) + w.start() + + c = client.TaskHubGrpcClient(host_address=HOST) + id = c.schedule_new_orchestration(empty_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) + + assert invoked + assert state is not None + assert state.name == task.get_name(empty_orchestrator) + assert state.instance_id == id + assert state.failure_details is None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_input is None + assert state.serialized_output is None + assert state.serialized_custom_status is None diff --git a/tests/durabletask/entities/test_entity_failure_handling.py b/tests/durabletask/entities/test_entity_failure_handling.py new file mode 100644 index 0000000..2db398b --- /dev/null +++ b/tests/durabletask/entities/test_entity_failure_handling.py @@ -0,0 +1,183 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +E2E tests for entity failure handling using the in-memory backend. 
def test_class_entity_unhandled_failure_fails():
    """An exception escaping a class entity operation must fail the caller.

    The orchestrator calls ``FailingEntity.fail`` without catching anything;
    the orchestration is expected to end FAILED with a ``TaskFailedError``
    whose message carries the entity's error text.
    """
    class FailingEntity(entities.DurableEntity):
        def fail(self, _):
            raise ValueError("Something went wrong!")

    def test_orchestrator(ctx: task.OrchestrationContext, _):
        target = entities.EntityInstanceId("FailingEntity", "testEntity")
        yield ctx.call_entity(target, "fail")

    with worker.TaskHubGrpcWorker(host_address=HOST) as hub_worker:
        hub_worker.add_orchestrator(test_orchestrator)
        hub_worker.add_entity(FailingEntity)
        hub_worker.start()

        hub_client = client.TaskHubGrpcClient(host_address=HOST)
        instance_id = hub_client.schedule_new_orchestration(test_orchestrator)
        result = hub_client.wait_for_orchestration_completion(instance_id, timeout=30)

    assert result is not None
    assert result.name == task.get_name(test_orchestrator)
    assert result.instance_id == instance_id

    failure = result.failure_details
    assert failure is not None
    assert failure.error_type == "TaskFailedError"
    assert "Something went wrong!" in failure.message
    assert result.runtime_status == client.OrchestrationStatus.FAILED
def test_class_entity_handled_failure_succeeds():
    """A class entity failure caught by the orchestrator must not fail it.

    The orchestrator catches ``TaskFailedError`` from the entity call and
    returns the error message, so the orchestration is expected to end
    COMPLETED with that message as its JSON output.
    """
    class FailingEntity(entities.DurableEntity):
        def fail(self, _):
            raise ValueError("Something went wrong!")

    def test_orchestrator(ctx: task.OrchestrationContext, _):
        target = entities.EntityInstanceId("FailingEntity", "testEntity")
        try:
            yield ctx.call_entity(target, "fail")
        except task.TaskFailedError as err:
            return err.details.message

    with worker.TaskHubGrpcWorker(host_address=HOST) as hub_worker:
        hub_worker.add_orchestrator(test_orchestrator)
        hub_worker.add_entity(FailingEntity)
        hub_worker.start()

        hub_client = client.TaskHubGrpcClient(host_address=HOST)
        instance_id = hub_client.schedule_new_orchestration(test_orchestrator)
        result = hub_client.wait_for_orchestration_completion(instance_id, timeout=30)

    assert result is not None
    assert result.name == task.get_name(test_orchestrator)
    assert result.instance_id == instance_id
    assert result.failure_details is None

    raw_output = result.serialized_output
    assert raw_output is not None
    assert "Something went wrong!" in json.loads(raw_output)
    assert result.runtime_status == client.OrchestrationStatus.COMPLETED
in output + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + + +def test_entity_failure_unlocks_entity(): + """Test that an entity failure properly unlocks the entity for subsequent operations.""" + def failing_entity(ctx: entities.EntityContext, _): + raise ValueError("Something went wrong!") + + def test_orchestrator(ctx: task.OrchestrationContext, _): + exception_count = 0 + entity_id = entities.EntityInstanceId("failing_entity", "testEntity") + with (yield ctx.lock_entities([entity_id])): + try: + yield ctx.call_entity(entity_id, "fail") + except task.TaskFailedError: + exception_count += 1 + try: + yield ctx.call_entity(entity_id, "fail") + except task.TaskFailedError: + exception_count += 1 + return exception_count + + with worker.TaskHubGrpcWorker(host_address=HOST) as w: + w.add_orchestrator(test_orchestrator) + w.add_entity(failing_entity) + w.start() + + c = client.TaskHubGrpcClient(host_address=HOST) + id = c.schedule_new_orchestration(test_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) + + assert state is not None + assert state.name == task.get_name(test_orchestrator) + assert state.instance_id == id + assert state.failure_details is None + + assert state.serialized_output is not None + output = json.loads(state.serialized_output) + assert output == 2 + assert state.runtime_status == client.OrchestrationStatus.COMPLETED diff --git a/tests/durabletask/entities/test_function_based_entities_e2e.py b/tests/durabletask/entities/test_function_based_entities_e2e.py new file mode 100644 index 0000000..e19a6c4 --- /dev/null +++ b/tests/durabletask/entities/test_function_based_entities_e2e.py @@ -0,0 +1,357 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +E2E tests for function-based durable entities using the in-memory backend. 
def test_client_signal_entity_and_custom_name():
    """Test signaling a function-based entity with a custom registration name from the client.

    ``empty_entity`` is registered under the name ``EntityNameCustom``; the
    client signals ``do_nothing`` and the test passes once the entity
    function has flipped the ``invoked`` flag.
    """
    invoked = False

    def empty_entity(ctx: entities.EntityContext, _):
        nonlocal invoked  # don't do this in a real app!
        if ctx.operation == "do_nothing":
            invoked = True

    with worker.TaskHubGrpcWorker(host_address=HOST) as w:
        w.add_entity(empty_entity, name="EntityNameCustom")
        w.start()

        c = client.TaskHubGrpcClient(host_address=HOST)
        entity_id = entities.EntityInstanceId("EntityNameCustom", "testEntity")
        c.signal_entity(entity_id, "do_nothing")

        # A signal returns no result to wait on, so poll for its side effect
        # with an upper bound instead of sleeping for a fixed interval: the
        # test finishes as soon as the entity ran and tolerates slow hosts.
        deadline = time.monotonic() + 10
        while not invoked and time.monotonic() < deadline:
            time.sleep(0.05)

    assert invoked
def test_orchestration_signal_entity_and_custom_name():
    """Test signaling a function-based entity with a custom name from an orchestration.

    The orchestrator signals ``do_nothing`` on an ``EntityNameCustom`` entity
    keyed by its own instance ID and returns without yielding; the test then
    checks the orchestration's final state and that the entity function ran.
    """
    invoked = False

    def empty_entity(ctx: entities.EntityContext, _):
        if ctx.operation == "do_nothing":
            nonlocal invoked  # don't do this in a real app!
            invoked = True

    def empty_orchestrator(ctx: task.OrchestrationContext, _):
        entity_id = entities.EntityInstanceId(
            "EntityNameCustom", f"{ctx.instance_id}_testEntity")
        ctx.signal_entity(entity_id, "do_nothing")

    with worker.TaskHubGrpcWorker(host_address=HOST) as w:
        w.add_orchestrator(empty_orchestrator)
        w.add_entity(empty_entity, name="EntityNameCustom")
        w.start()

        c = client.TaskHubGrpcClient(host_address=HOST)
        instance_id = c.schedule_new_orchestration(empty_orchestrator)
        state = c.wait_for_orchestration_completion(instance_id, timeout=30)

        # The orchestration does not wait for the signal it sent, so poll
        # (bounded) for the entity's side effect rather than sleeping for a
        # fixed interval.
        deadline = time.monotonic() + 10
        while not invoked and time.monotonic() < deadline:
            time.sleep(0.05)

    assert invoked
    assert state is not None
    assert state.name == task.get_name(empty_orchestrator)
    assert state.instance_id == instance_id
    assert state.failure_details is None
    assert state.runtime_status == client.OrchestrationStatus.COMPLETED
    assert state.serialized_input is None
    assert state.serialized_output is None
    assert state.serialized_custom_status is None
def test_orchestration_call_entity_with_lock():
    """Call a function-based entity while holding its lock, twice in a row.

    Each orchestration locks an entity keyed by its own instance ID, calls
    ``do_nothing`` on it inside the lock, and completes. A second run checks
    that the worker still processes lock/call cycles after the first one.
    """
    invoked = False

    def empty_entity(ctx: entities.EntityContext, _):
        nonlocal invoked  # don't do this in a real app!
        if ctx.operation == "do_nothing":
            invoked = True

    def empty_orchestrator(ctx: task.OrchestrationContext, _):
        locked_id = entities.EntityInstanceId(
            "empty_entity", f"{ctx.instance_id}_testEntity")
        with (yield ctx.lock_entities([locked_id])):
            yield ctx.call_entity(locked_id, "do_nothing")

    with worker.TaskHubGrpcWorker(host_address=HOST) as hub_worker:
        hub_worker.add_orchestrator(empty_orchestrator)
        hub_worker.add_entity(empty_entity)
        hub_worker.start()

        hub_client = client.TaskHubGrpcClient(host_address=HOST)
        first_id = hub_client.schedule_new_orchestration(empty_orchestrator)
        first_state = hub_client.wait_for_orchestration_completion(first_id, timeout=30)

        # Run a second orchestration to confirm lock/call handling still
        # works after a previous lock and unlock.
        second_id = hub_client.schedule_new_orchestration(empty_orchestrator)
        second_state = hub_client.wait_for_orchestration_completion(second_id, timeout=30)

    assert invoked
    completed_runs = ((first_id, first_state), (second_id, second_state))
    for expected_id, final_state in completed_runs:
        assert final_state is not None
        assert final_state.name == task.get_name(empty_orchestrator)
        assert final_state.instance_id == expected_id
        assert final_state.failure_details is None
        assert final_state.runtime_status == client.OrchestrationStatus.COMPLETED
        assert final_state.serialized_input is None
        assert final_state.serialized_output is None
        assert final_state.serialized_custom_status is None
+ invoked = True + elif ctx.operation == "signal_other": + entity_id = entities.EntityInstanceId( + "empty_entity", + ctx.entity_id.key.replace("_testEntity", "_otherEntity")) + ctx.signal_entity(entity_id, "do_nothing") + + def empty_orchestrator(ctx: task.OrchestrationContext, _): + entity_id = entities.EntityInstanceId( + "empty_entity", f"{ctx.instance_id}_testEntity") + yield ctx.call_entity(entity_id, "signal_other") + + with worker.TaskHubGrpcWorker(host_address=HOST) as w: + w.add_orchestrator(empty_orchestrator) + w.add_entity(empty_entity) + w.start() + + c = client.TaskHubGrpcClient(host_address=HOST) + id = c.schedule_new_orchestration(empty_orchestrator) + state = c.wait_for_orchestration_completion(id, timeout=30) + time.sleep(2) # wait for the entity-to-entity signal to be processed + + assert invoked + assert state is not None + assert state.name == task.get_name(empty_orchestrator) + assert state.instance_id == id + assert state.failure_details is None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_input is None + assert state.serialized_output is None + assert state.serialized_custom_status is None + + +def test_entity_starts_orchestration(): + """Test that an entity can start a new orchestration.""" + invoked = False + + def empty_entity(ctx: entities.EntityContext, _): + if ctx.operation == "start_orchestration": + ctx.schedule_new_orchestration("empty_orchestrator") + + def empty_orchestrator(ctx: task.OrchestrationContext, _): + nonlocal invoked # don't do this in a real app! 
def test_entity_locking_behavior():
    """Check the API restrictions that apply while an entity lock is held.

    Inside the lock the orchestrator must not be able to signal the locked
    entity, nor start a second call to it while one is outstanding; both
    attempts are expected to raise. The single permitted call is then
    awaited and the orchestration is expected to complete normally.
    """
    def empty_entity(ctx: entities.EntityContext, _):
        """Entity that performs no work for any operation."""

    def empty_orchestrator(ctx: task.OrchestrationContext, _):
        locked_id = entities.EntityInstanceId(
            "empty_entity", f"{ctx.instance_id}_testEntity")
        with (yield ctx.lock_entities([locked_id])):
            # Cannot signal entities that have been locked
            with pytest.raises(Exception):
                ctx.signal_entity(locked_id, "do_nothing")
            pending_call = ctx.call_entity(locked_id, "do_nothing")
            # Cannot call entities that have been locked and already called
            with pytest.raises(Exception):
                ctx.call_entity(locked_id, "do_nothing")
            yield pending_call

    with worker.TaskHubGrpcWorker(host_address=HOST) as hub_worker:
        hub_worker.add_orchestrator(empty_orchestrator)
        hub_worker.add_entity(empty_entity)
        hub_worker.start()

        hub_client = client.TaskHubGrpcClient(host_address=HOST)
        instance_id = hub_client.schedule_new_orchestration(empty_orchestrator)
        result = hub_client.wait_for_orchestration_completion(instance_id, timeout=30)

    assert result is not None
    assert result.name == task.get_name(empty_orchestrator)
    assert result.instance_id == instance_id
    assert result.failure_details is None
    assert result.runtime_status == client.OrchestrationStatus.COMPLETED
    assert result.serialized_input is None
    assert result.serialized_output is None
    assert result.serialized_custom_status is None
empty_entity(ctx: entities.EntityContext, _): + nonlocal invoke_count # don't do this in a real app! + invoke_count += 1 + + def empty_orchestrator(ctx: task.OrchestrationContext, _): + entity_id = entities.EntityInstanceId( + "empty_entity", f"{ctx.instance_id}_testEntity") + with (yield ctx.lock_entities([entity_id])): + yield ctx.call_entity(entity_id, "do_nothing") + raise Exception("Simulated exception") + + with worker.TaskHubGrpcWorker(host_address=HOST) as w: + w.add_orchestrator(empty_orchestrator) + w.add_entity(empty_entity) + w.start() + + c = client.TaskHubGrpcClient(host_address=HOST) + time.sleep(2) # wait for initial setup + id = c.schedule_new_orchestration(empty_orchestrator) + c.wait_for_orchestration_completion(id, timeout=30) + id = c.schedule_new_orchestration(empty_orchestrator) + c.wait_for_orchestration_completion(id, timeout=30) + + assert invoke_count == 2 + + +def test_entity_unlocks_when_user_mishandles_lock(): + """Test that entities are unlocked when the user yields lock but doesn't use context manager.""" + invoke_count = 0 + + def empty_entity(ctx: entities.EntityContext, _): + nonlocal invoke_count # don't do this in a real app! 
def test_get_entity_not_found():
    """Looking up an entity that was never created must yield ``None``.

    Asks the backend at ``HOST`` for the ``counter`` entity keyed
    ``nonexistent`` (which this test never signals) with
    ``include_state=True``.
    """
    missing_id = entities.EntityInstanceId("counter", "nonexistent")
    hub_client = client.TaskHubGrpcClient(host_address=HOST)

    assert hub_client.get_entity(missing_id, include_state=True) is None
def test_get_all_orchestration_states(backend):
    """List all orchestrations with and without inputs/outputs fetched.

    Runs one ``empty_orchestrator`` instance with input ``"Hello"``, then
    checks that it appears exactly once in both listings: without serialized
    input/output by default, and with them when the query sets
    ``fetch_inputs_and_outputs``.
    """
    hub_worker = TaskHubGrpcWorker(host_address=HOST)
    hub_client = TaskHubGrpcClient(host_address=HOST)

    hub_worker.add_orchestrator(empty_orchestrator)
    hub_worker.start()

    try:
        instance_id = hub_client.schedule_new_orchestration(empty_orchestrator, input="Hello")
        hub_client.wait_for_orchestration_completion(instance_id, timeout=30)

        summaries = hub_client.get_all_orchestration_states()
        detailed_query = client.OrchestrationQuery()
        detailed_query.fetch_inputs_and_outputs = True
        detailed = hub_client.get_all_orchestration_states(detailed_query)
        single = hub_client.get_orchestration_state(instance_id)
    finally:
        hub_worker.stop()

    assert single is not None
    assert single.instance_id == instance_id

    # (listing, expected serialized input, expected serialized output)
    expectations = (
        (summaries, None, None),
        (detailed, '"Hello"', '"Complete"'),
    )
    for listing, expected_input, expected_output in expectations:
        assert listing is not None
        matches = [entry for entry in listing if entry.instance_id == instance_id]
        assert len(matches) == 1
        entry = matches[0]
        assert entry.runtime_status == client.OrchestrationStatus.COMPLETED
        assert entry.serialized_input == expected_input
        assert entry.serialized_output == expected_output
        assert entry.failure_details is None
failed_orchestrations if o.instance_id == failed_id][0] + assert failed_orch.runtime_status == client.OrchestrationStatus.FAILED + assert failed_orch.failure_details is not None + + +def test_get_orchestration_state_by_time_range(backend): + worker = TaskHubGrpcWorker(host_address=HOST) + c = TaskHubGrpcClient(host_address=HOST) + + worker.add_orchestrator(empty_orchestrator) + worker.start() + + try: + # Get current time + before_creation = datetime.now(timezone.utc) - timedelta(seconds=5) + + # Schedule orchestration + id = c.schedule_new_orchestration(empty_orchestrator, input="TimeTest") + c.wait_for_orchestration_completion(id, timeout=30) + + after_creation = datetime.now(timezone.utc) + timedelta(seconds=5) + + # Query by time range + query = client.OrchestrationQuery( + created_time_from=before_creation, + created_time_to=after_creation, + fetch_inputs_and_outputs=True + ) + orchestrations_in_range = c.get_all_orchestration_states(query) + + # Query outside time range + query = client.OrchestrationQuery( + created_time_from=after_creation, + created_time_to=after_creation + timedelta(hours=1), + fetch_inputs_and_outputs=True + ) + orchestrations_outside_range = c.get_all_orchestration_states(query) + finally: + worker.stop() + + assert len([o for o in orchestrations_in_range if o.instance_id == id]) == 1 + assert len([o for o in orchestrations_outside_range if o.instance_id == id]) == 0 + + +def test_get_orchestration_state_pagination_succeeds(backend): + # Create a custom handler to capture log messages + log_records = [] + + class ListHandler(logging.Handler): + def emit(self, record): + log_records.append(record) + + handler = ListHandler() + + worker = TaskHubGrpcWorker(host_address=HOST) + c = TaskHubGrpcClient(host_address=HOST, log_handler=handler) + + worker.add_orchestrator(empty_orchestrator) + worker.start() + + try: + # Create at least 3 orchestrations to test the limit + ids = [] + for i in range(3): + id = 
c.schedule_new_orchestration(empty_orchestrator, input=f"Test{i}") + ids.append(id) + + # Wait for all to complete + for id in ids: + c.wait_for_orchestration_completion(id, timeout=30) + + # Query with max_instance_count=2 + query = client.OrchestrationQuery(max_instance_count=2) + orchestrations = c.get_all_orchestration_states(query) + finally: + worker.stop() + + # Should return more than 2 instances since we created at least 3 + assert len(orchestrations) > 2 + # Verify the pagination loop ran by checking for the continuation token log message + assert any("Received continuation token" in record.getMessage() for record in log_records), \ + "Expected pagination loop to execute with continuation token" + + +def test_purge_orchestration(backend): + worker = TaskHubGrpcWorker(host_address=HOST) + c = TaskHubGrpcClient(host_address=HOST) + + worker.add_orchestrator(empty_orchestrator) + worker.start() + + try: + # Schedule and complete orchestration + id = c.schedule_new_orchestration(empty_orchestrator, input="ToPurge") + c.wait_for_orchestration_completion(id, timeout=30) + + # Verify it exists + state_before = c.get_orchestration_state(id) + assert state_before is not None + + # Purge the orchestration + result = c.purge_orchestration(id, recursive=True) + + # Verify purge result + assert result.deleted_instance_count >= 1 + + # Verify it no longer exists + state_after = c.get_orchestration_state(id) + assert state_after is None + finally: + worker.stop() + + +def test_purge_orchestrations_by_status(backend): + worker = TaskHubGrpcWorker(host_address=HOST) + c = TaskHubGrpcClient(host_address=HOST) + + worker.add_orchestrator(failing_orchestrator) + worker.start() + + try: + # Schedule and let it fail + failed_id = c.schedule_new_orchestration(failing_orchestrator) + try: + c.wait_for_orchestration_completion(failed_id, timeout=30) + except client.OrchestrationFailedError: + pass # Expected failure + + # Verify it exists and is failed + state_before = 
c.get_orchestration_state(failed_id) + assert state_before is not None + assert state_before.runtime_status == client.OrchestrationStatus.FAILED + + # Purge failed orchestrations + result = c.purge_orchestrations_by( + runtime_status=[client.OrchestrationStatus.FAILED], + recursive=True + ) + + # Verify purge result + assert result.deleted_instance_count >= 1 + + # Verify the failed orchestration no longer exists + state_after = c.get_orchestration_state(failed_id) + assert state_after is None + finally: + worker.stop() + + +def test_purge_orchestrations_by_time_range(backend): + worker = TaskHubGrpcWorker(host_address=HOST) + c = TaskHubGrpcClient(host_address=HOST) + + worker.add_orchestrator(empty_orchestrator) + worker.start() + + try: + # Get current time + before_creation = datetime.now(timezone.utc) - timedelta(seconds=5) + + # Schedule orchestration + id = c.schedule_new_orchestration(empty_orchestrator, input="ToPurgeByTime") + c.wait_for_orchestration_completion(id, timeout=30) + + after_creation = datetime.now(timezone.utc) + timedelta(seconds=5) + + # Verify it exists + state_before = c.get_orchestration_state(id) + assert state_before is not None + + # Purge by time range + result = c.purge_orchestrations_by( + created_time_from=before_creation, + created_time_to=after_creation, + runtime_status=[client.OrchestrationStatus.COMPLETED], + recursive=True + ) + + # Verify purge result + assert result.deleted_instance_count >= 1 + + # Verify it no longer exists + state_after = c.get_orchestration_state(id) + assert state_after is None + finally: + worker.stop() + + +def test_get_all_entities(backend): + counter_value = 0 + + def counter_entity(ctx: entities.EntityContext, input): + nonlocal counter_value + if ctx.operation == "add": + counter_value += input + ctx.set_state(counter_value) + elif ctx.operation == "get": + return ctx.get_state(int, 0) + + worker = TaskHubGrpcWorker(host_address=HOST) + c = TaskHubGrpcClient(host_address=HOST) + + 
worker.add_entity(counter_entity) + worker.start() + + try: + # Create entity + entity_id = entities.EntityInstanceId("counter_entity", "testCounter1") + c.signal_entity(entity_id, "add", input=5) + time.sleep(3) # Wait for signal to be processed + + # Get all entities without state + query = client.EntityQuery(include_state=False) + all_entities = c.get_all_entities(query) + assert len([e for e in all_entities if e.id == entity_id]) == 1 + entity_without_state = [e for e in all_entities if e.id == entity_id][0] + assert entity_without_state.get_state(int) is None + + # Get all entities with state + query = client.EntityQuery(include_state=True) + all_entities_with_state = c.get_all_entities(query) + assert len([e for e in all_entities_with_state if e.id == entity_id]) == 1 + entity_with_state = [e for e in all_entities_with_state if e.id == entity_id][0] + assert entity_with_state.get_state(int) == 5 + finally: + worker.stop() + + +def test_get_entities_by_instance_id_prefix(backend): + def counter_entity(ctx: entities.EntityContext, input): + if ctx.operation == "set": + ctx.set_state(input) + + worker = TaskHubGrpcWorker(host_address=HOST) + c = TaskHubGrpcClient(host_address=HOST) + + worker.add_entity(counter_entity) + worker.start() + + try: + # Create entities with different prefixes + entity_id_1 = entities.EntityInstanceId("counter_entity", "prefix1_counter") + entity_id_2 = entities.EntityInstanceId("counter_entity", "prefix2_counter") + + c.signal_entity(entity_id_1, "set", input=10) + c.signal_entity(entity_id_2, "set", input=20) + time.sleep(3) # Wait for signals to be processed + + # Query by prefix + query = client.EntityQuery( + instance_id_starts_with="@counter_entity@prefix1", + include_state=True + ) + entities_prefix1 = c.get_all_entities(query) + + query = client.EntityQuery( + instance_id_starts_with="@counter_entity@prefix2", + include_state=True + ) + entities_prefix2 = c.get_all_entities(query) + finally: + worker.stop() + + assert len([e 
for e in entities_prefix1 if e.id == entity_id_1]) == 1 + assert len([e for e in entities_prefix1 if e.id == entity_id_2]) == 0 + + assert len([e for e in entities_prefix2 if e.id == entity_id_2]) == 1 + assert len([e for e in entities_prefix2 if e.id == entity_id_1]) == 0 + + +def test_get_entities_by_time_range(backend): + def simple_entity(ctx: entities.EntityContext, input): + if ctx.operation == "set": + ctx.set_state(input) + + worker = TaskHubGrpcWorker(host_address=HOST) + c = TaskHubGrpcClient(host_address=HOST) + + worker.add_entity(simple_entity) + worker.start() + + try: + # Get current time + before_creation = datetime.now(timezone.utc) - timedelta(seconds=5) -# def test_get_all_orchestration_states(): -# # Start a worker, which will connect to the sidecar in a background thread -# with worker.TaskHubGrpcWorker() as w: -# w.add_orchestrator(empty_orchestrator) -# w.start() + # Create entity + entity_id = entities.EntityInstanceId("simple_entity", "timeTestEntity") + c.signal_entity(entity_id, "set", input="test_value") + time.sleep(3) # Wait for signal to be processed -# c = client.TaskHubGrpcClient() -# id = c.schedule_new_orchestration(empty_orchestrator, input="Hello") -# c.wait_for_orchestration_completion(id, timeout=30) + after_creation = datetime.now(timezone.utc) + timedelta(seconds=5) -# with pytest.raises(_InactiveRpcError) as exec_info: -# c.get_all_orchestration_states() -# assert "unimplemented" in str(exec_info.value) + # Query by time range + query = client.EntityQuery( + last_modified_from=before_creation, + last_modified_to=after_creation, + include_state=True + ) + entities_in_range = c.get_all_entities(query) + # Query outside time range + query = client.EntityQuery( + last_modified_from=after_creation, + last_modified_to=after_creation + timedelta(hours=1) + ) + entities_outside_range = c.get_all_entities(query) + finally: + worker.stop() -# def test_get_all_entities(): -# # Start a worker, which will connect to the sidecar in a 
background thread -# with worker.TaskHubGrpcWorker() as w: -# w.add_orchestrator(empty_orchestrator) -# w.start() + assert len([e for e in entities_in_range if e.id == entity_id]) == 1 + assert len([e for e in entities_outside_range if e.id == entity_id]) == 0 -# c = client.TaskHubGrpcClient() -# with pytest.raises(_InactiveRpcError) as exec_info: -# c.get_all_entities() -# assert "method QueryEntities not implemented" in str(exec_info.value) +def test_clean_entity_storage(backend): + class EmptyEntity(entities.DurableEntity): + pass -# def test_clean_entity_storage(): -# # Start a worker, which will connect to the sidecar in a background thread -# with worker.TaskHubGrpcWorker() as w: -# w.add_orchestrator(empty_orchestrator) -# w.start() + worker = TaskHubGrpcWorker(host_address=HOST) + c = TaskHubGrpcClient(host_address=HOST) -# c = client.TaskHubGrpcClient() -# with pytest.raises(_InactiveRpcError) as exec_info: -# c.clean_entity_storage() -# assert "method CleanEntityStorage not implemented" in str(exec_info.value) + worker.add_entity(EmptyEntity) + worker.start() + try: + # Create an entity and then delete its state to make it empty + entity_id = entities.EntityInstanceId("EmptyEntity", "toClean") + c.signal_entity(entity_id, "delete") + time.sleep(3) # Wait for signal to be processed -# def test_purge_orchestrations_by_status(): -# with worker.TaskHubGrpcWorker() as w: -# w.add_orchestrator(failing_orchestrator) -# w.start() + # Clean entity storage + result = c.clean_entity_storage( + remove_empty_entities=True, + release_orphaned_locks=True + ) + finally: + worker.stop() -# c = client.TaskHubGrpcClient() -# with pytest.raises(_InactiveRpcError) as exec_info: -# c.purge_orchestrations_by( -# runtime_status=[client.OrchestrationStatus.FAILED], -# recursive=True -# ) -# # sic - error returned from sidecar -# assert "multi-instance purge is not unimplemented" in str(exec_info.value) + # Verify clean result + assert result.empty_entities_removed >= 0 + assert 
result.orphaned_locks_released >= 0 diff --git a/tests/durabletask/test_orchestration_e2e.py b/tests/durabletask/test_orchestration_e2e.py index 92a20d9..a6f670c 100644 --- a/tests/durabletask/test_orchestration_e2e.py +++ b/tests/durabletask/test_orchestration_e2e.py @@ -10,13 +10,18 @@ import pytest from durabletask import client, task, worker +from durabletask.testing import create_test_backend -from grpc._channel import _InactiveRpcError +HOST = "localhost:50054" -# NOTE: These tests assume a sidecar process is running. Example command: -# go install github.com/microsoft/durabletask-go@main -# durabletask-go --port 4001 -pytestmark = pytest.mark.e2e + +@pytest.fixture(autouse=True) +def backend(): + """Create an in-memory backend for testing.""" + b = create_test_backend(port=50054) + yield b + b.stop() + b.reset() def test_empty_orchestration(): @@ -27,12 +32,11 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): nonlocal invoked # don't do this in a real app! invoked = True - # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(host_address=HOST) as w: w.add_orchestrator(empty_orchestrator) w.start() - c = client.TaskHubGrpcClient() + c = client.TaskHubGrpcClient(host_address=HOST) id = c.schedule_new_orchestration(empty_orchestrator, tags={'Tagged': 'true'}) state = c.wait_for_orchestration_completion(id, timeout=30) @@ -60,13 +64,12 @@ def sequence(ctx: task.OrchestrationContext, start_val: int): numbers.append(current) return numbers - # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(host_address=HOST) as w: w.add_orchestrator(sequence) w.add_activity(plus_one) w.start() - task_hub_client = client.TaskHubGrpcClient() + task_hub_client = client.TaskHubGrpcClient(host_address=HOST) id = task_hub_client.schedule_new_orchestration(sequence, input=1, 
tags={'Orchestration': 'Sequence'}) state = task_hub_client.wait_for_orchestration_completion( id, timeout=30) @@ -105,14 +108,13 @@ def orchestrator(ctx: task.OrchestrationContext, input: int): return error_msg - # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(host_address=HOST) as w: w.add_orchestrator(orchestrator) w.add_activity(throw) w.add_activity(increment_counter) w.start() - task_hub_client = client.TaskHubGrpcClient() + task_hub_client = client.TaskHubGrpcClient(host_address=HOST) id = task_hub_client.schedule_new_orchestration(orchestrator, input=1) state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) @@ -148,14 +150,13 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int): # Wait for all sub-orchestrations to complete yield task.when_all(tasks) - # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(host_address=HOST) as w: w.add_activity(increment) w.add_orchestrator(orchestrator_child) w.add_orchestrator(parent_orchestrator) w.start() - task_hub_client = client.TaskHubGrpcClient() + task_hub_client = client.TaskHubGrpcClient(host_address=HOST) id = task_hub_client.schedule_new_orchestration(parent_orchestrator, input=10) state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) @@ -175,13 +176,12 @@ def orchestrator_child(ctx: task.OrchestrationContext, _): def parent_orchestrator(ctx: task.OrchestrationContext, _): yield ctx.call_sub_orchestrator("orchestrator_child") - # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(host_address=HOST) as w: w.add_orchestrator(orchestrator_child) w.add_orchestrator(parent_orchestrator) w.start() - task_hub_client = client.TaskHubGrpcClient() + task_hub_client = 
client.TaskHubGrpcClient(host_address=HOST) id = task_hub_client.schedule_new_orchestration(parent_orchestrator, input=None) state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) @@ -198,13 +198,12 @@ def orchestrator(ctx: task.OrchestrationContext, _): c = yield ctx.wait_for_external_event('C') return [a, b, c] - # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(host_address=HOST) as w: w.add_orchestrator(orchestrator) w.start() # Start the orchestration and immediately raise events to it. - task_hub_client = client.TaskHubGrpcClient() + task_hub_client = client.TaskHubGrpcClient(host_address=HOST) id = task_hub_client.schedule_new_orchestration(orchestrator) task_hub_client.raise_orchestration_event(id, 'A', data='a') task_hub_client.raise_orchestration_event(id, 'B', data='b') @@ -227,13 +226,12 @@ def orchestrator(ctx: task.OrchestrationContext, _): else: return "timed out" - # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(host_address=HOST) as w: w.add_orchestrator(orchestrator) w.start() # Start the orchestration and immediately raise events to it. 
- task_hub_client = client.TaskHubGrpcClient() + task_hub_client = client.TaskHubGrpcClient(host_address=HOST) id = task_hub_client.schedule_new_orchestration(orchestrator) if raise_event: task_hub_client.raise_orchestration_event(id, 'Approval') @@ -252,19 +250,20 @@ def orchestrator(ctx: task.OrchestrationContext, _): result = yield ctx.wait_for_external_event("my_event") return result - # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(host_address=HOST) as w: w.add_orchestrator(orchestrator) w.start() - task_hub_client = client.TaskHubGrpcClient() + task_hub_client = client.TaskHubGrpcClient(host_address=HOST) id = task_hub_client.schedule_new_orchestration(orchestrator) state = task_hub_client.wait_for_orchestration_start(id, timeout=30) assert state is not None # Suspend the orchestration and wait for it to go into the SUSPENDED state task_hub_client.suspend_orchestration(id) + deadline = time.time() + 10 while state.runtime_status == client.OrchestrationStatus.RUNNING: + assert time.time() < deadline, "Timed out waiting for SUSPENDED status" time.sleep(0.1) state = task_hub_client.get_orchestration_state(id) assert state is not None @@ -275,7 +274,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): try: state = task_hub_client.wait_for_orchestration_completion(id, timeout=3) assert False, "Orchestration should not have completed" - except (TimeoutError, _InactiveRpcError): + except TimeoutError: pass # Resume the orchestration and wait for it to complete @@ -291,12 +290,11 @@ def orchestrator(ctx: task.OrchestrationContext, _): result = yield ctx.wait_for_external_event("my_event") return result - # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(host_address=HOST) as w: w.add_orchestrator(orchestrator) w.start() - task_hub_client = client.TaskHubGrpcClient() + 
task_hub_client = client.TaskHubGrpcClient(host_address=HOST) id = task_hub_client.schedule_new_orchestration(orchestrator) state = task_hub_client.wait_for_orchestration_start(id, timeout=30) assert state is not None @@ -318,13 +316,12 @@ def child(ctx: task.OrchestrationContext, _): result = yield ctx.wait_for_external_event("my_event") return result - # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(host_address=HOST) as w: w.add_orchestrator(root) w.add_orchestrator(child) w.start() - task_hub_client = client.TaskHubGrpcClient() + task_hub_client = client.TaskHubGrpcClient(host_address=HOST) id = task_hub_client.schedule_new_orchestration(root) state = task_hub_client.wait_for_orchestration_start(id, timeout=30) assert state is not None @@ -346,7 +343,6 @@ def child(ctx: task.OrchestrationContext, _): assert state is None -@pytest.mark.skip(reason="durabletask-go does not yet support RestartInstance") def test_restart_with_same_instance_id(): def orchestrator(ctx: task.OrchestrationContext, _): result = yield ctx.call_activity(say_hello, input="World") @@ -356,12 +352,12 @@ def say_hello(ctx: task.ActivityContext, input: str): return f"Hello, {input}!" 
# Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(host_address=HOST) as w: w.add_orchestrator(orchestrator) w.add_activity(say_hello) w.start() - task_hub_client = client.TaskHubGrpcClient() + task_hub_client = client.TaskHubGrpcClient(host_address=HOST) id = task_hub_client.schedule_new_orchestration(orchestrator) state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None @@ -378,7 +374,6 @@ def say_hello(ctx: task.ActivityContext, input: str): assert state.serialized_output == json.dumps("Hello, World!") -@pytest.mark.skip(reason="durabletask-go does not yet support RestartInstance") def test_restart_with_new_instance_id(): def orchestrator(ctx: task.OrchestrationContext, _): result = yield ctx.call_activity(say_hello, input="World") @@ -388,12 +383,12 @@ def say_hello(ctx: task.ActivityContext, input: str): return f"Hello, {input}!" # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(host_address=HOST) as w: w.add_orchestrator(orchestrator) w.add_activity(say_hello) w.start() - task_hub_client = client.TaskHubGrpcClient() + task_hub_client = client.TaskHubGrpcClient(host_address=HOST) id = task_hub_client.schedule_new_orchestration(orchestrator) state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None @@ -424,12 +419,11 @@ def orchestrator(ctx: task.OrchestrationContext, input: int): else: return all_results - # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(host_address=HOST) as w: w.add_orchestrator(orchestrator) w.start() - task_hub_client = client.TaskHubGrpcClient() + task_hub_client = client.TaskHubGrpcClient(host_address=HOST) id = task_hub_client.schedule_new_orchestration(orchestrator, 
input=0) task_hub_client.raise_orchestration_event(id, "my_event", data=1) task_hub_client.raise_orchestration_event(id, "my_event", data=2) @@ -445,9 +439,6 @@ def orchestrator(ctx: task.OrchestrationContext, input: int): assert all_results == [1, 2, 3, 4, 5] -# NOTE: This test fails when running against durabletask-go with sqlite because the sqlite backend does not yet -# support orchestration ID reuse. This gap is being tracked here: -# https://github.com/microsoft/durabletask-go/issues/42 def test_retry_policies(): # This test verifies that the retry policies are working as expected. # It does this by creating an orchestration that calls a sub-orchestrator, @@ -483,13 +474,13 @@ def throw_activity_with_retry(ctx: task.ActivityContext, _): throw_activity_counter += 1 raise RuntimeError("Kah-BOOOOM!!!") - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(host_address=HOST) as w: w.add_orchestrator(parent_orchestrator_with_retry) w.add_orchestrator(child_orchestrator_with_retry) w.add_activity(throw_activity_with_retry) w.start() - task_hub_client = client.TaskHubGrpcClient() + task_hub_client = client.TaskHubGrpcClient(host_address=HOST) id = task_hub_client.schedule_new_orchestration(parent_orchestrator_with_retry) state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None @@ -524,12 +515,12 @@ def throw_activity(ctx: task.ActivityContext, _): throw_activity_counter += 1 raise RuntimeError("Kah-BOOOOM!!!") - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(host_address=HOST) as w: w.add_orchestrator(mock_orchestrator) w.add_activity(throw_activity) w.start() - task_hub_client = client.TaskHubGrpcClient() + task_hub_client = client.TaskHubGrpcClient(host_address=HOST) id = task_hub_client.schedule_new_orchestration(mock_orchestrator) state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None @@ -546,12 +537,11 @@ def test_custom_status(): def 
empty_orchestrator(ctx: task.OrchestrationContext, _): ctx.set_custom_status("foobaz") - # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(host_address=HOST) as w: w.add_orchestrator(empty_orchestrator) w.start() - c = client.TaskHubGrpcClient() + c = client.TaskHubGrpcClient(host_address=HOST) id = c.schedule_new_orchestration(empty_orchestrator) state = c.wait_for_orchestration_completion(id, timeout=30) @@ -577,13 +567,12 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): results.append(ctx.new_uuid()) return results - # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(host_address=HOST) as w: w.add_orchestrator(empty_orchestrator) w.add_activity(noop) w.start() - c = client.TaskHubGrpcClient() + c = client.TaskHubGrpcClient(host_address=HOST) id = c.schedule_new_orchestration(empty_orchestrator) state = c.wait_for_orchestration_completion(id, timeout=30) diff --git a/tests/durabletask/test_orchestration_versioning_e2e.py b/tests/durabletask/test_orchestration_versioning_e2e.py index 45dd2bd..6adce83 100644 --- a/tests/durabletask/test_orchestration_versioning_e2e.py +++ b/tests/durabletask/test_orchestration_versioning_e2e.py @@ -2,24 +2,25 @@ # Licensed under the MIT License. import json -import warnings import pytest from durabletask import client, task, worker +from durabletask.testing import create_test_backend -# NOTE: These tests assume a sidecar process is running. Example command: -# go install github.com/microsoft/durabletask-go@main -# durabletask-go --port 4001 -pytestmark = pytest.mark.e2e +HOST = "localhost:50055" -def test_versioned_orchestration_succeeds(): - warnings.warn("Skipping test_versioned_orchestration_succeeds. 
" - "Currently not passing as the sidecar does not support versioning yet") - return # Currently not passing as the sidecar does not support versioning yet - # Remove these lines to run the test after the sidecar is updated +@pytest.fixture(autouse=True) +def backend(): + """Create an in-memory backend for testing.""" + b = create_test_backend(port=50055) + yield b + b.stop() + b.reset() + +def test_versioned_orchestration_succeeds(): def plus_one(_: task.ActivityContext, input: int) -> int: return input + 1 @@ -31,8 +32,7 @@ def sequence(ctx: task.OrchestrationContext, start_val: int): numbers.append(current) return numbers - # Start a worker, which will connect to the sidecar in a background thread - with worker.TaskHubGrpcWorker() as w: + with worker.TaskHubGrpcWorker(host_address=HOST) as w: w.add_orchestrator(sequence) w.add_activity(plus_one) w.use_versioning(worker.VersioningOptions( @@ -43,7 +43,7 @@ def sequence(ctx: task.OrchestrationContext, start_val: int): )) w.start() - task_hub_client = client.TaskHubGrpcClient(default_version="1.0.0") + task_hub_client = client.TaskHubGrpcClient(host_address=HOST, default_version="1.0.0") id = task_hub_client.schedule_new_orchestration(sequence, input=1, tags={'Orchestration': 'Sequence'}, version="1.0.0") state = task_hub_client.wait_for_orchestration_completion( id, timeout=30)