diff --git a/README.md b/README.md index f26c7b837..223a1f113 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,9 @@ informal introduction to the features and their implementation. - [Data Conversion](#data-conversion) - [Pydantic Support](#pydantic-support) - [Custom Type Data Conversion](#custom-type-data-conversion) + - [External Storage](#external-storage) + - [Driver Selection](#driver-selection) + - [Custom Drivers](#custom-drivers) - [Workers](#workers) - [Workflows](#workflows) - [Definition](#definition) @@ -309,8 +312,9 @@ other_ns_client = Client(**config) Data converters are used to convert raw Temporal payloads to/from actual Python types. A custom data converter of type `temporalio.converter.DataConverter` can be set via the `data_converter` parameter of the `Client` constructor. Data -converters are a combination of payload converters, payload codecs, and failure converters. Payload converters convert -Python values to/from serialized bytes. Payload codecs convert bytes to bytes (e.g. for compression or encryption). +converters are a combination of payload converters, external storage, payload codecs, and failure converters. Payload +converters convert Python values to/from serialized bytes. External payload storage optionally stores and retrieves payloads +to/from external storage services using drivers. Payload codecs convert bytes to bytes (e.g. for compression or encryption). Failure converters convert exceptions to/from serialized failures. The default data converter supports converting multiple types including: @@ -455,6 +459,130 @@ my_data_converter = dataclasses.replace( Now `IPv4Address` can be used in type hints including collections, optionals, etc. +##### External Storage + +⚠️ **External storage support is currently at an experimental release stage.** ⚠️ + +External storage allows large payloads to be offloaded to an external storage service (such as Amazon S3) rather than stored inline in workflow history. This is useful when workflows or activities work with data that would otherwise exceed Temporal's payload size limits. + +External storage is configured via the `external_storage` parameter on `DataConverter`. It should be configured on the `Client` both for clients of your workflow as well as on the worker -- anywhere large payloads may be uploaded or downloaded. + +A `StorageDriver` handles uploading and downloading payloads. Temporal provides built-in drivers for common storage solutions, or you may customize one. Here's an example using our provided `InMemoryTestDriver`. + +```python +import dataclasses +from temporalio.client import Client +from temporalio.converter import DataConverter +from temporalio.converter import ExternalStorage + +driver = InMemoryTestDriver() + +client = await Client.connect( + "localhost:7233", + data_converter=dataclasses.replace( + DataConverter.default, + external_storage=ExternalStorage(drivers=[driver]), + ), +) +``` + +Some things to note about external storage: + +* Only payloads that meet or exceed `ExternalStorage.payload_size_threshold` (default 256 KiB) are offloaded. Smaller payloads are stored inline as normal. +* External storage applies transparently to all payloads, whether they are workflow inputs/outputs, activity inputs/outputs, signal inputs, query outputs, update inputs/outputs, or failure details. +* The `DataConverter`'s `payload_codec` (if configured) is applied to the payload *before* it is handed to the storage driver, so the driver always stores encoded bytes. The reference payload written to workflow history is not encoded by the `DataConverter` codec. +* Setting `ExternalStorage.payload_size_threshold` to `None` causes every payload to be considered for external storage regardless of size. + +###### Driver Selection + +When multiple storage backends are needed, list all drivers in `ExternalStorage.drivers` and provide a `driver_selector` to control which driver stores new payloads. Any driver in the list not chosen for storing is still available for retrieval, which is useful when migrating between storage backends. + +```python +from temporalio.converter import ExternalStorage + +options = ExternalStorage( + drivers=[hot_driver, cold_driver], + driver_selector=lambda context, payload: ( + hot_driver if payload.ByteSize() < 5 * 1024 * 1024 else cold_driver + ), +) +``` + +For more complex selection logic, use a plain callable that reads from the `StorageDriverStoreContext`: + +```python +import temporalio.converter +from temporalio.api.common.v1 import Payload + +def feature_flag_is_on(workflow_id: str | None) -> bool: + """Check whether external storage is enabled for this workflow via a feature flag service.""" + return workflow_id is not None and len(workflow_id) % 2 == 0 + +def feature_flag_selector( + context: temporalio.converter.StorageDriverStoreContext, _payload: Payload +) -> temporalio.converter.StorageDriver | None: + workflow_id = None + if isinstance(context.serialization_context, temporalio.converter.WorkflowSerializationContext): + workflow_id = context.serialization_context.workflow_id + elif isinstance(context.serialization_context, temporalio.converter.ActivitySerializationContext): + workflow_id = context.serialization_context.workflow_id + return my_driver if feature_flag_is_on(workflow_id) else None + +options = ExternalStorage( + drivers=[my_driver], + driver_selector=feature_flag_selector, +) +``` + +Some things to note about driver selection: + +* A `driver_selector` is required when more than one driver is registered. With a single driver, `driver_selector` may be omitted and that driver is used for all store operations. +* Returning `None` from a selector leaves the payload stored inline in workflow history rather than offloading it. +* The driver instance returned by the selector must be one of the instances registered in `ExternalStorage.drivers`. If it is not, an error is raised. + +###### Custom Drivers + +Implement `temporalio.converter.StorageDriver` to integrate with an external storage system: + +```python +from collections.abc import Sequence +from temporalio.converter import StorageDriver, StorageDriverClaim, StorageDriverRetrieveContext, StorageDriverStoreContext +from temporalio.api.common.v1 import Payload + +class MyDriver(StorageDriver): + def __init__(self, driver_name: str | None = None): + self._driver_name = driver_name or "my-org:driver:my-driver" + + def name(self) -> str: + return self._driver_name + + async def store( + self, context: StorageDriverStoreContext, payloads: Sequence[Payload] + ) -> list[StorageDriverClaim]: + claims = [] + for payload in payloads: + key = await my_storage.put(payload.SerializeToString()) + claims.append(StorageDriverClaim(data={"key": key})) + return claims + + async def retrieve( + self, context: StorageDriverRetrieveContext, claims: Sequence[StorageDriverClaim] + ) -> list[Payload]: + payloads = [] + for claim in claims: + data = await my_storage.get(claim.data["key"]) + p = Payload() + p.ParseFromString(data) + payloads.append(p) + return payloads +``` + +Some things to note about implementing a custom driver: + +* `StorageDriver.name()` must return a string that is unique among all drivers in `ExternalStorage.drivers`. This name is embedded in the reference payload stored in workflow history and used to look up the correct driver during retrieval — changing it after payloads have been stored will break retrieval. +* `StorageDriver.type()` is automatically implemented to return the name of the class. This can be overridden in subclasses but must remain consistent across all instances of the subclass. +* Implement `temporalio.converter.WithSerializationContext` on your driver to receive workflow or activity context (namespace, workflow ID, activity ID, etc.) at serialization time. + ### Workers Workers host workflows and/or activities. Here's how to run a worker: diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index c98afefca..c2e426d28 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -303,10 +303,9 @@ async def decode_activation( decode_headers: bool, ) -> None: """Decode all payloads in the activation.""" - if data_converter._decode_payload_has_effect: - await CommandAwarePayloadVisitor( - skip_search_attributes=True, skip_headers=not decode_headers - ).visit(_Visitor(data_converter._decode_payload_sequence), activation) + await CommandAwarePayloadVisitor( + skip_search_attributes=True, skip_headers=not decode_headers + ).visit(_Visitor(data_converter._decode_payload_sequence), activation) async def encode_completion( diff --git a/temporalio/converter/__init__.py b/temporalio/converter/__init__.py index d70bd6e76..2777e7e80 100644 --- a/temporalio/converter/__init__.py +++ b/temporalio/converter/__init__.py @@ -4,6 +4,14 @@ DataConverter, default, ) +from temporalio.converter._extstore import ( + ExternalStorage, + StorageDriver, + StorageDriverClaim, + StorageDriverRetrieveContext, + StorageDriverStoreContext, + StorageWarning, +) from temporalio.converter._failure_converter import ( DefaultFailureConverter, DefaultFailureConverterWithEncodedAttributes, @@ -44,6 +52,12 @@ __all__ = [ "ActivitySerializationContext", + "ExternalStorage", + "StorageDriver", + "StorageDriverClaim", + "StorageDriverRetrieveContext", + "StorageDriverStoreContext", + "StorageWarning", "AdvancedJSONEncoder", "BinaryNullPayloadConverter", "BinaryPlainPayloadConverter", diff --git a/temporalio/converter/_data_converter.py b/temporalio/converter/_data_converter.py index e9ac33158..9c2163774 100644 --- a/temporalio/converter/_data_converter.py +++ b/temporalio/converter/_data_converter.py @@ -14,6 +14,11 @@ import temporalio.api.common.v1 import temporalio.api.failure.v1 import temporalio.common +from temporalio.converter._extstore import ( + _REFERENCE_ENCODING, + ExternalStorage, + StorageWarning, +) from temporalio.converter._failure_converter import ( FailureConverter, ) @@ -72,6 +77,13 @@ class DataConverter(WithSerializationContext): payload_limits: PayloadLimitsConfig = PayloadLimitsConfig() """Settings for payload size limits.""" + external_storage: ExternalStorage | None = None + """Options for external storage. If None, external storage is disabled. + + .. warning:: + This API is experimental. + """ + default: ClassVar[DataConverter] """Singleton default data converter.""" @@ -158,18 +170,22 @@ def with_context(self, context: SerializationContext) -> Self: payload_converter = self.payload_converter payload_codec = self.payload_codec failure_converter = self.failure_converter + external_storage = self.external_storage if isinstance(payload_converter, WithSerializationContext): payload_converter = payload_converter.with_context(context) if isinstance(payload_codec, WithSerializationContext): payload_codec = payload_codec.with_context(context) if isinstance(failure_converter, WithSerializationContext): failure_converter = failure_converter.with_context(context) + if isinstance(external_storage, WithSerializationContext): + external_storage = external_storage.with_context(context) if all( new is orig for new, orig in [ (payload_converter, self.payload_converter), (payload_codec, self.payload_codec), (failure_converter, self.failure_converter), + (external_storage, self.external_storage), ] ): return self @@ -177,6 +193,7 @@ def with_context(self, context: SerializationContext) -> Self: object.__setattr__(cloned, "payload_converter", payload_converter) object.__setattr__(cloned, "payload_codec", payload_codec) object.__setattr__(cloned, "failure_converter", failure_converter) + object.__setattr__(cloned, "external_storage", external_storage) return cloned def _with_payload_error_limits( @@ -238,12 +255,16 @@ async def _encode_payload( ) -> temporalio.api.common.v1.Payload: if self.payload_codec: payload = (await self.payload_codec.encode([payload]))[0] + if self.external_storage: + payload = await self.external_storage._store_payload(payload) self._validate_payload_limits([payload]) return payload async def _encode_payloads(self, payloads: temporalio.api.common.v1.Payloads): if self.payload_codec: await self.payload_codec.encode_wrapper(payloads) + if self.external_storage: + await self.external_storage._store_payloads(payloads) self._validate_payload_limits(payloads.payloads) async def _encode_payload_sequence( @@ -252,32 +273,63 @@ async def _encode_payload_sequence( encoded_payloads = list(payloads) if self.payload_codec: encoded_payloads = await self.payload_codec.encode(encoded_payloads) + if self.external_storage: + encoded_payloads = await self.external_storage._store_payload_sequence( + encoded_payloads + ) self._validate_payload_limits(encoded_payloads) return encoded_payloads async def _decode_payload( self, payload: temporalio.api.common.v1.Payload ) -> temporalio.api.common.v1.Payload: + if self.external_storage: + payload = await self.external_storage._retrieve_payload(payload) if self.payload_codec: payload = (await self.payload_codec.decode([payload]))[0] return payload async def _decode_payloads(self, payloads: temporalio.api.common.v1.Payloads): + if self.external_storage: + await self.external_storage._retrieve_payloads(payloads) + else: + if any( + p.metadata.get("encoding") == _REFERENCE_ENCODING + for p in payloads.payloads + ): + warnings.warn( + "[TMPRL1105] Detected externally stored payload(s) but external storage is not configured.", + StorageWarning, + ) if self.payload_codec: await self.payload_codec.decode_wrapper(payloads) async def _decode_payload_sequence( self, payloads: Sequence[temporalio.api.common.v1.Payload] ) -> list[temporalio.api.common.v1.Payload]: - if not self.payload_codec: - return list(payloads) - return await self.payload_codec.decode(payloads) + decoded_payloads = list(payloads) + if self.external_storage: + decoded_payloads = await self.external_storage._retrieve_payload_sequence( + decoded_payloads + ) + else: + if any( + p.metadata.get("encoding") == _REFERENCE_ENCODING + for p in decoded_payloads + ): + warnings.warn( + "[TMPRL1105] Detected externally stored payload(s) but external storage is not configured.", + StorageWarning, + ) + if self.payload_codec: + decoded_payloads = await self.payload_codec.decode(decoded_payloads) + return decoded_payloads # Temporary shortcircuit detection while the _decode_* methods may no-op if # a payload codec is not configured. Remove once those paths have more to them. @property def _decode_payload_has_effect(self) -> bool: - return self.payload_codec is not None + return self.payload_codec is not None or self.external_storage is not None def _validate_payload_limits( self, diff --git a/temporalio/converter/_extstore.py b/temporalio/converter/_extstore.py new file mode 100644 index 000000000..614c8ac10 --- /dev/null +++ b/temporalio/converter/_extstore.py @@ -0,0 +1,433 @@ +"""External payload storage support for offloading payloads to external storage +systems. +""" + +from __future__ import annotations + +import asyncio +import dataclasses +from abc import ABC, abstractmethod +from collections.abc import Callable, Coroutine, Mapping, Sequence +from dataclasses import dataclass +from typing import Any, ClassVar, TypeVar + +from typing_extensions import Self + +from temporalio.api.common.v1 import Payload, Payloads +from temporalio.converter._payload_converter import JSONPlainPayloadConverter +from temporalio.converter._serialization_context import ( + SerializationContext, + WithSerializationContext, +) + +_T = TypeVar("_T") + +_REFERENCE_ENCODING = b"json/external-storage-reference" + + +async def _gather_cancel_on_error( + coros: Sequence[Coroutine[Any, Any, _T]], +) -> list[_T]: + """Run coroutines concurrently; cancel all remaining tasks if any one fails.""" + tasks = [asyncio.create_task(c) for c in coros] + try: + return await asyncio.gather(*tasks) + except BaseException: + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise + + +@dataclass(frozen=True) +class StorageDriverClaim: + """A driver-defined reference to an externally-stored payload that can be used to + retrieve it. + + .. warning:: + This API is experimental. + """ + + claim_data: Mapping[str, str] + """Driver-defined data for identifying and retrieving an externally stored + payload. + """ + + +@dataclass(frozen=True) +class StorageDriverStoreContext: + """Context passed to :meth:`StorageDriver.store` and ``driver_selector`` calls. + + .. warning:: + This API is experimental. + """ + + serialization_context: SerializationContext | None = None + """The serialization context active when this store operation was initiated, + or ``None`` if no context has been set. + """ + + +@dataclass(frozen=True) +class StorageDriverRetrieveContext: + """Context passed to :meth:`StorageDriver.retrieve` calls. + + .. warning:: + This API is experimental. + """ + + +class StorageDriver(ABC): + """Base driver for storing and retrieve payloads from external storage systems. + + .. warning:: + This API is experimental. + """ + + @abstractmethod + def name(self) -> str: + """Returns the name of this driver instance. A driver may allow + its name to be parameterized at construction time so that multiple + instances of the same driver class can coexist in + :attr:`ExternalStorage.drivers` with distinct names. + """ + raise NotImplementedError + + def type(self) -> str: + """Returns the type of the storage driver. This string should be + the same across all instantiations of the same driver class. This + allows the equivalent driver implementation in different languages + to be named the same. + + Defaults to the class name. Subclasses may override this to return a + stable, language-agnostic identifier. + """ + return type(self).__name__ + + @abstractmethod + async def store( + self, + context: StorageDriverStoreContext, + payloads: Sequence[Payload], + ) -> list[StorageDriverClaim]: + """Stores payloads in external storage and returns a + :class:`StorageDriverClaim` for each one. The returned list must be the + same length as ``payloads``. + """ + raise NotImplementedError + + @abstractmethod + async def retrieve( + self, + context: StorageDriverRetrieveContext, + claims: Sequence[StorageDriverClaim], + ) -> list[Payload]: + """Retrieves payloads from external storage for the given + :class:`StorageDriverClaim` list. The returned list must be the same + length as ``claims``. + """ + raise NotImplementedError + + +class StorageWarning(RuntimeWarning): + """Warning for external storage issues. + + .. warning:: + This API is experimental. + """ + + +@dataclass(frozen=True) +class _StorageReference: + driver_name: str + driver_claim: StorageDriverClaim + + +@dataclass(frozen=True) +class ExternalStorage(WithSerializationContext): + """Configuration for external storage behavior. + + .. warning:: + This API is experimental. + """ + + drivers: Sequence[StorageDriver] + """Drivers available for storing and retrieving payloads. At least one + driver must be provided. If more than one driver is registered, + :attr:`driver_selector` must also be set. + + Drivers in this list are looked up by :meth:`StorageDriver.name` during + retrieval, so each driver must have a unique name. + """ + + driver_selector: ( + Callable[[StorageDriverStoreContext, Payload], StorageDriver | None] | None + ) = None + """Controls which driver stores a given payload. A callable that returns the + driver instance to use, or ``None`` to leave the payload stored inline. + The returned driver must be one of the instances registered in + :attr:`drivers`. + + Required when more than one driver is registered. When ``None`` and only + one driver is registered, that driver is used for all store operations. + """ + + payload_size_threshold: int | None = 256 * 1024 + """Minimum payload size in bytes before external storage is considered. + Defaults to 256 KiB. Set to ``None`` to consider every payload for + external storage regardless of size. + """ + + _driver_map: dict[str, StorageDriver] = dataclasses.field( + init=False, repr=False, compare=False + ) + """Name-keyed index of :attr:`drivers`, built at construction time. Used + for retrieval lookups. + """ + + _context: SerializationContext | None = dataclasses.field( + init=False, default=None, repr=False, compare=False + ) + + _claim_converter: ClassVar[JSONPlainPayloadConverter] = JSONPlainPayloadConverter( + encoding=_REFERENCE_ENCODING.decode() + ) + + def __post_init__(self) -> None: + """Validate drivers and build the internal name-keyed driver map. + + Raises :exc:`ValueError` if no drivers are provided, if more than one + driver is registered without a :attr:`driver_selector`, or if any two + drivers share the same name. + """ + if not self.drivers: + raise ValueError( + "ExternalStorage.drivers must contain at least one driver." + ) + if len(self.drivers) > 1 and self.driver_selector is None: + raise ValueError( + "ExternalStorage.driver_selector must be specified if multiple drivers are registered." + ) + driver_map: dict[str, StorageDriver] = {} + for driver in self.drivers: + name = driver.name() + if name in driver_map: + raise ValueError( + f"ExternalStorage.drivers contains multiple drivers with name '{name}'. " + "Each driver must have a unique name." + ) + driver_map[name] = driver + object.__setattr__(self, "_driver_map", driver_map) + + def with_context(self, context: SerializationContext) -> Self: + """Return a copy of these options with the serialization context applied.""" + result = dataclasses.replace(self) + object.__setattr__(result, "_context", context) + return result + + def _select_driver( + self, context: StorageDriverStoreContext, payload: Payload + ) -> StorageDriver | None: + """Returns the driver to use for this payload, or None to pass through.""" + if ( + self.payload_size_threshold is not None + and payload.ByteSize() < self.payload_size_threshold + ): + return None + selector = self.driver_selector + if selector is None: + return self.drivers[0] if self.drivers else None + driver = selector(context, payload) + if driver is None: + return None + registered = self._driver_map.get(driver.name()) + if registered is not driver: + raise ValueError( + f"Driver '{driver.name()}' returned by driver_selector is not registered in ExternalStorage.drivers" + ) + return driver + + def _get_driver_by_name(self, name: str) -> StorageDriver: + """Looks up a driver by name, raising :class:`ValueError` if not found.""" + driver = self._driver_map.get(name) + if driver is None: + raise ValueError(f"No driver found with name '{name}'") + return driver + + async def _store_payload(self, payload: Payload) -> Payload: + context = StorageDriverStoreContext(serialization_context=self._context) + + driver = self._select_driver(context, payload) + if driver is None: + return payload + + claims = await driver.store(context, [payload]) + + self._validate_claim_length(claims, expected=1, driver=driver) + + reference = _StorageReference( + driver_name=driver.name(), + driver_claim=claims[0], + ) + reference_payload = self._claim_converter.to_payload(reference) + if reference_payload is None: + raise ValueError( + f"Failed to serialize storage reference for driver '{driver.name()}'" + ) + reference_payload.external_payloads.add().size_bytes = payload.ByteSize() + return reference_payload + + async def _store_payloads(self, payloads: Payloads): + stored_payloads = await self._store_payload_sequence(payloads.payloads) + for i, payload in enumerate(stored_payloads): + payloads.payloads[i].CopyFrom(payload) + + async def _store_payload_sequence( + self, + payloads: Sequence[Payload], + ) -> list[Payload]: + if len(payloads) == 1: + return [await self._store_payload(payloads[0])] + + results = list(payloads) + context = StorageDriverStoreContext(serialization_context=self._context) + + to_store: list[tuple[int, Payload, StorageDriver]] = [] + for index, payload in enumerate(payloads): + driver = self._select_driver(context, payload) + if driver is None: + continue + to_store.append((index, payload, driver)) + + if not to_store: + return results + + driver_groups: dict[StorageDriver, list[tuple[int, Payload]]] = {} + for orig_index, payload, driver in to_store: + driver_groups.setdefault(driver, []).append((orig_index, payload)) + + driver_group_list = list(driver_groups.items()) + + all_claims = await _gather_cancel_on_error( + [ + driver.store(context, [p for _, p in indexed_payloads]) + for driver, indexed_payloads in driver_group_list + ] + ) + + for (driver, indexed_payloads), claims in zip(driver_group_list, all_claims): + indices = [idx for idx, _ in indexed_payloads] + sizes = [p.ByteSize() for _, p in indexed_payloads] + + self._validate_claim_length(claims, expected=len(indices), driver=driver) + + for i, claim in enumerate(claims): + reference = _StorageReference( + driver_name=driver.name(), + driver_claim=claim, + ) + reference_payload = self._claim_converter.to_payload(reference) + if reference_payload is None: + raise ValueError( + f"Failed to serialize storage reference for driver '{driver.name()}'" + ) + reference_payload.external_payloads.add().size_bytes = sizes[i] + results[indices[i]] = reference_payload + + return results + + async def _retrieve_payload(self, payload: Payload) -> Payload: + if len(payload.external_payloads) == 0: + return payload + + reference = self._claim_converter.from_payload(payload, _StorageReference) + if not isinstance(reference, _StorageReference): + return payload + + driver = self._get_driver_by_name(reference.driver_name) + context = StorageDriverRetrieveContext() + + stored_payloads = await driver.retrieve(context, [reference.driver_claim]) + + self._validate_payload_length(stored_payloads, expected=1, driver=driver) + + return stored_payloads[0] + + async def _retrieve_payloads(self, payloads: Payloads): + stored_payloads = await self._retrieve_payload_sequence(payloads.payloads) + for i, payload in enumerate(stored_payloads): + payloads.payloads[i].CopyFrom(payload) + + async def _retrieve_payload_sequence( + self, + payloads: Sequence[Payload], + ) -> list[Payload]: + results = list(payloads) + + if len(payloads) == 1: + return [await self._retrieve_payload(payloads[0])] + + driver_claims: dict[StorageDriver, list[tuple[int, StorageDriverClaim]]] = {} + for index, payload in enumerate(payloads): + if len(payload.external_payloads) == 0: + continue + + reference = self._claim_converter.from_payload(payload, _StorageReference) + if not isinstance(reference, _StorageReference): + continue + + driver = self._get_driver_by_name(reference.driver_name) + driver_claims.setdefault(driver, []).append((index, reference.driver_claim)) + + if not driver_claims: + return results + + context = StorageDriverRetrieveContext() + stored_by_index: dict[int, Payload] = {} + + driver_claim_list = list(driver_claims.items()) + + all_stored = await _gather_cancel_on_error( + [ + driver.retrieve(context, [claim for _, claim in indexed_claims]) + for driver, indexed_claims in driver_claim_list + ] + ) + + for (driver, indexed_claims), stored_payloads in zip( + driver_claim_list, all_stored + ): + indices = [idx for idx, _ in indexed_claims] + + self._validate_payload_length( + stored_payloads, + expected=len(indexed_claims), + driver=driver, + ) + + for idx, stored_payload in zip(indices, stored_payloads): + stored_by_index[idx] = stored_payload + + retrieve_indices = sorted(stored_by_index.keys()) + stored_list = [stored_by_index[idx] for idx in retrieve_indices] + + for i, retrieved_payload in enumerate(stored_list): + results[retrieve_indices[i]] = retrieved_payload + + return results + + def _validate_claim_length( + self, claims: Sequence[StorageDriverClaim], expected: int, driver: StorageDriver + ) -> None: + if len(claims) != expected: + raise ValueError( + f"Driver '{driver.name()}' returned {len(claims)} claims, expected {expected}", + ) + + def _validate_payload_length( + self, payloads: Sequence[Payload], expected: int, driver: StorageDriver + ) -> None: + if len(payloads) != expected: + raise ValueError( + f"Driver '{driver.name()}' returned {len(payloads)} payloads, expected {expected}", + ) diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 9c4d0ec17..7b67734d9 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -629,7 +629,7 @@ async def _execute_activity( else None, ) - if self._encode_headers and data_converter._decode_payload_has_effect: + if self._encode_headers: for payload in start.header_fields.values(): payload.CopyFrom(await data_converter._decode_payload(payload)) diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py index 2f8e7560f..b305bd3e0 100644 --- a/temporalio/worker/_workflow.py +++ b/temporalio/worker/_workflow.py @@ -342,21 +342,44 @@ async def _handle_activation( "Failed handling activation on workflow with run ID %s", act.run_id ) - completion.failed.failure.SetInParent() - try: - data_converter.failure_converter.to_failure( - err, - data_converter.payload_converter, - completion.failed.failure, - ) - except Exception as inner_err: - logger.exception( - "Failed converting activation exception on workflow with run ID %s", - act.run_id, - ) - completion.failed.failure.message = ( - f"Failed converting activation exception: {inner_err}" - ) + if ( + isinstance(err, temporalio.exceptions.ApplicationError) + and err.non_retryable + ): + # Fail the workflow execution terminally rather than failing the task + command = completion.successful.commands.add() + failure = command.fail_workflow_execution.failure + failure.SetInParent() + try: + data_converter.failure_converter.to_failure( + err, + data_converter.payload_converter, + failure, + ) + except Exception as inner_err: + logger.exception( + "Failed converting activation exception on workflow with run ID %s", + act.run_id, + ) + failure.message = ( + f"Failed converting activation exception: {inner_err}" + ) + else: + completion.failed.failure.SetInParent() + try: + data_converter.failure_converter.to_failure( + err, + data_converter.payload_converter, + completion.failed.failure, + ) + except Exception as inner_err: + logger.exception( + "Failed converting activation exception on workflow with run ID %s", + act.run_id, + ) + completion.failed.failure.message = ( + f"Failed converting activation exception: {inner_err}" + ) completion.run_id = act.run_id diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 1fa7e2eae..1bfa77c3c 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -1794,9 +1794,9 @@ def workflow_set_current_details(self, details: str): self._current_details = details def workflow_is_failure_exception(self, err: BaseException) -> bool: - # An exception is a failure instead of a task fail if it's already a - # failure error or if it is a timeout error or if it is an instance of - # any of the failure types in the worker or workflow-level setting + # An exception causes the workflow to fail (rather than the task) if it + # is already a failure error, a timeout error, or an instance of any of the + # failure exception types configured at the worker or workflow level. wf_failure_exception_types = self._defn.failure_exception_types if self._dynamic_failure_exception_types is not None: wf_failure_exception_types = self._dynamic_failure_exception_types diff --git a/tests/test_extstore.py b/tests/test_extstore.py new file mode 100644 index 000000000..8a8a8b6d6 --- /dev/null +++ b/tests/test_extstore.py @@ -0,0 +1,677 @@ +"""Tests for external storage functionality.""" + +import asyncio +from collections.abc import Sequence + +import pytest + +from temporalio.api.common.v1 import Payload +from temporalio.converter import ( + DataConverter, + ExternalStorage, + JSONPlainPayloadConverter, + PayloadCodec, + StorageDriver, + StorageDriverClaim, + StorageDriverRetrieveContext, + StorageDriverStoreContext, +) +from temporalio.converter._extstore import _StorageReference +from temporalio.exceptions import ApplicationError + + +class InMemoryTestDriver(StorageDriver): + """In-memory storage driver for testing.""" + + def __init__( + self, + driver_name: str = "test-driver", + ): + self._driver_name = driver_name + self._storage: dict[str, bytes] = {} + self._store_calls = 0 + self._retrieve_calls = 0 + + def name(self) -> str: + return self._driver_name + + async def store( + self, + context: StorageDriverStoreContext, + payloads: Sequence[Payload], + ) -> list[StorageDriverClaim]: + self._store_calls += 1 + start_index = len(self._storage) + + entries = [ + (f"payload-{start_index + i}", payload.SerializeToString()) + for i, payload in enumerate(payloads) + ] + self._storage.update(entries) + + return [StorageDriverClaim(claim_data={"key": key}) for key, _ in entries] + + async def retrieve( + self, + context: StorageDriverRetrieveContext, + claims: Sequence[StorageDriverClaim], + ) -> list[Payload]: + self._retrieve_calls += 1 + + def parse_claim( + claim: StorageDriverClaim, + ) -> Payload: + key = claim.claim_data["key"] + if key not in self._storage: + raise ApplicationError( + f"Payload not found for key '{key}'", non_retryable=True + ) + payload = Payload() + payload.ParseFromString(self._storage[key]) + return payload + + return [parse_claim(claim) for claim in claims] + + +class TestDataConverterExternalStorage: + """Tests for DataConverter with external storage.""" + + async def test_extstore_encode_decode(self): + """Test that large payloads are stored externally.""" + driver = InMemoryTestDriver() + + # Configure with 100-byte threshold + converter = DataConverter( + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=100, + ) + ) + + # Small value should not be externalized + small_value = "small" + encoded_small = await converter.encode([small_value]) + assert len(encoded_small) == 1 + assert not encoded_small[0].external_payloads # Not externalized + assert driver._store_calls == 0 + + # Large value should be externalized + large_value = "x" * 200 + encoded_large = await converter.encode([large_value]) + assert len(encoded_large) == 1 + assert len(encoded_large[0].external_payloads) > 0 # Externalized + assert driver._store_calls == 1 + + # Decode large value + decoded = await converter.decode(encoded_large, [str]) + assert len(decoded) == 1 + assert decoded[0] == large_value + assert driver._retrieve_calls == 1 + + async def test_extstore_reference_structure(self): + """Test that external storage creates proper reference structure.""" + converter = DataConverter( + external_storage=ExternalStorage( + drivers=[InMemoryTestDriver("test-driver")], + payload_size_threshold=50, + ) + ) + + # Create large payload + large_value = "x" * 100 + encoded = await converter.encode([large_value]) + + # Verify reference structure + reference_payload = encoded[0] + assert len(reference_payload.external_payloads) > 0 + + # The payload should contain a serialized _ExternalStorageReference + # Deserialize it to verify structure using the same encoding + claim_converter = JSONPlainPayloadConverter( + encoding="json/external-storage-reference" + ) + reference = claim_converter.from_payload(reference_payload, _StorageReference) + + assert isinstance(reference, _StorageReference) + assert "test-driver" == reference.driver_name + assert isinstance(reference.driver_claim, StorageDriverClaim) + assert "key" in reference.driver_claim.claim_data + + async def test_extstore_composite_conditional(self): + """Test using multiple drivers based on size.""" + hot_driver = InMemoryTestDriver("hot-storage") + cold_driver = InMemoryTestDriver("cold-storage") + + options = ExternalStorage( + drivers=[hot_driver, cold_driver], + driver_selector=lambda context, payload: hot_driver + if payload.ByteSize() < 500 + else cold_driver, + payload_size_threshold=100, + ) + converter = DataConverter(external_storage=options) + + # Small payload (not externalized) + small = "x" * 50 + encoded_small = await converter.encode([small]) + assert not encoded_small[0].external_payloads + assert hot_driver._store_calls == 0 + assert cold_driver._store_calls == 0 + + # Medium payload (hot storage) + medium = "x" * 200 + encoded_medium = await converter.encode([medium]) + assert len(encoded_medium[0].external_payloads) > 0 + assert hot_driver._store_calls == 1 + assert cold_driver._store_calls == 0 + + # Large payload (cold storage) + large = "x" * 2000 + encoded_large = await converter.encode([large]) + assert len(encoded_large[0].external_payloads) > 0 + assert hot_driver._store_calls == 1 # Unchanged + assert cold_driver._store_calls == 1 + + # Verify retrieval from correct drivers + decoded_medium = await converter.decode(encoded_medium, [str]) + assert decoded_medium[0] == medium + assert hot_driver._retrieve_calls == 1 + + decoded_large = await converter.decode(encoded_large, [str]) + assert decoded_large[0] == large + assert cold_driver._retrieve_calls == 1 + + +class TestDriverError: + """Tests for ValueError raised when a driver violates its contract.""" + + async def test_encode_wrong_claim_count_raises_runtime_error(self): + """store() returning fewer claims than payloads must raise ValueError.""" + + class _NoClaimsDriver(InMemoryTestDriver): + async def store( + self, context: StorageDriverStoreContext, payloads: Sequence[Payload] + ) -> list[StorageDriverClaim]: + return [] + + driver = _NoClaimsDriver() + converter = DataConverter( + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=10, + ) + ) + with pytest.raises( + ValueError, + match=f"Driver '{driver.name()}' returned 0 claims, expected 1", + ): + await converter.encode(["x" * 200]) + + async def test_decode_wrong_payload_count_raises_runtime_error(self): + """retrieve() returning fewer payloads than claims must raise ValueError.""" + good_converter = DataConverter( + external_storage=ExternalStorage( + drivers=[InMemoryTestDriver()], + payload_size_threshold=10, + ) + ) + encoded = await good_converter.encode(["x" * 200]) + + class _NoPayloadsDriver(InMemoryTestDriver): + async def retrieve( + self, + context: StorageDriverRetrieveContext, + claims: Sequence[StorageDriverClaim], + ) -> list[Payload]: + return [] + + driver = _NoPayloadsDriver() + bad_converter = DataConverter( + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=10, + ) + ) + with pytest.raises( + ValueError, + match=f"Driver '{driver.name()}' returned 0 payloads, expected 1", + ): + await bad_converter.decode(encoded, [str]) + + async def test_store_cancels_in_flight_driver_on_error(self): + """When one driver raises during concurrent store, other in-flight drivers are cancelled.""" + store_cancelled = asyncio.Event() + + class _SleepingStoreDriver(InMemoryTestDriver): + def __init__(self): + super().__init__("sleeping") + + async def store( + self, + context: StorageDriverStoreContext, + payloads: Sequence[Payload], + ) -> list[StorageDriverClaim]: + try: + await asyncio.sleep(float("inf")) + except asyncio.CancelledError: + store_cancelled.set() + raise + return [] # unreachable + + class _FailingStoreDriver(InMemoryTestDriver): + def __init__(self): + super().__init__("failing") + + async def store( + self, + context: StorageDriverStoreContext, + payloads: Sequence[Payload], + ) -> list[StorageDriverClaim]: + raise ValueError( + "failed to store payloads because remote service is unavailable" + ) + + drivers = [_SleepingStoreDriver(), _FailingStoreDriver()] + drivers_iter = iter(drivers) + converter = DataConverter( + external_storage=ExternalStorage( + drivers=drivers, + driver_selector=lambda ctx, p: next(drivers_iter), + payload_size_threshold=None, + ) + ) + + with pytest.raises( + ValueError, + match="^failed to store payloads because remote service is unavailable$", + ): + await converter.encode(["payload_a", "payload_b"]) + + assert store_cancelled.is_set() + + async def test_retrieve_cancels_in_flight_driver_on_error(self): + """When one driver raises during concurrent retrieve, other in-flight drivers are cancelled.""" + retrieve_cancelled = asyncio.Event() + + class _SleepingRetrieveDriver(InMemoryTestDriver): + def __init__(self): + super().__init__("sleeping") + + async def retrieve( + self, + context: StorageDriverRetrieveContext, + claims: Sequence[StorageDriverClaim], + ) -> list[Payload]: + try: + await asyncio.sleep(float("inf")) + except asyncio.CancelledError: + retrieve_cancelled.set() + raise + return [] # unreachable + + class _FailingRetrieveDriver(InMemoryTestDriver): + def __init__(self): + super().__init__("failing") + + async def retrieve( + self, + context: StorageDriverRetrieveContext, + claims: Sequence[StorageDriverClaim], + ) -> list[Payload]: + raise ValueError( + "failed to retrieve a payload because the object key does not exist" + ) + + drivers: list[StorageDriver] = [ + _SleepingRetrieveDriver(), + _FailingRetrieveDriver(), + ] + drivers_iter = iter(drivers) + converter = DataConverter( + external_storage=ExternalStorage( + drivers=drivers, + driver_selector=lambda ctx, p: next(drivers_iter), + payload_size_threshold=None, + ) + ) + encoded = await converter.encode(["payload_a", "payload_b"]) + + with pytest.raises( + ValueError, + match="^failed to retrieve a payload because the object key does not exist$", + ): + await converter.decode(encoded, [str, str]) + + assert retrieve_cancelled.is_set() + + +class RecordingPayloadCodec(PayloadCodec): + """Codec that wraps each payload under a recognisable ``encoding`` label. + + Encode sets ``metadata["encoding"]`` to ``encoding_label`` and stores the + serialised inner payload as ``data``. Decode reverses that. The call + counters let tests assert exactly how many payloads each codec processed. + """ + + def __init__(self, encoding_label: str) -> None: + self._encoding_label = encoding_label.encode() + self.encoded_count = 0 + self.decoded_count = 0 + + async def encode(self, payloads: Sequence[Payload]) -> list[Payload]: + self.encoded_count += len(payloads) + results = [] + for p in payloads: + wrapped = Payload() + wrapped.metadata["encoding"] = self._encoding_label + wrapped.data = p.SerializeToString() + results.append(wrapped) + return results + + async def decode(self, payloads: Sequence[Payload]) -> list[Payload]: + self.decoded_count += len(payloads) + results = [] + for p in payloads: + inner = Payload() + inner.ParseFromString(p.data) + results.append(inner) + return results + + +class TestPayloadCodecWithExternalStorage: + """Tests for interaction between DataConverter.payload_codec and external storage.""" + + async def test_dc_payload_codec_encodes_stored_bytes(self): + """DataConverter.payload_codec encodes the bytes handed to the driver + for storage. The reference payload written to workflow history is NOT + encoded by the DataConverter codec.""" + driver = InMemoryTestDriver() + dc_codec = RecordingPayloadCodec("binary/dc-encoded") + + converter = DataConverter( + payload_codec=dc_codec, + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=50, + ), + ) + + large_value = "x" * 200 + encoded = await converter.encode([large_value]) + assert len(encoded) == 1 + assert driver._store_calls == 1 + + # The reference payload written to history must NOT carry the dc_codec label. + assert dc_codec.encoded_count == 1 + assert encoded[0].metadata.get("encoding") != b"binary/dc-encoded" + + # The bytes given to the driver must carry the dc_codec label. + stored_payload = Payload() + stored_payload.ParseFromString(next(iter(driver._storage.values()))) + assert stored_payload.metadata.get("encoding") == b"binary/dc-encoded" + + # Round-trip must recover the original value. + decoded = await converter.decode(encoded, [str]) + assert decoded[0] == large_value + assert dc_codec.decoded_count == 1 + assert driver._retrieve_calls == 1 + + async def test_dc_payload_codec_does_not_encode_reference_payload(self): + """The reference payload stored in workflow history is NOT encoded by + DataConverter.payload_codec – encoding is applied to the stored bytes + instead.""" + driver = InMemoryTestDriver() + dc_codec = RecordingPayloadCodec("binary/dc-encoded") + + converter = DataConverter( + payload_codec=dc_codec, + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=50, + ), + ) + + large_value = "x" * 200 + encoded = await converter.encode([large_value]) + assert len(encoded) == 1 + assert driver._store_calls == 1 + + # Reference payload in history is NOT encoded by DataConverter.payload_codec. + assert dc_codec.encoded_count == 1 + assert encoded[0].metadata.get("encoding") != b"binary/dc-encoded" + + # Stored bytes ARE encoded by DataConverter.payload_codec. + stored_payload = Payload() + stored_payload.ParseFromString(next(iter(driver._storage.values()))) + assert stored_payload.metadata.get("encoding") == b"binary/dc-encoded" + + # Round-trip. + decoded = await converter.decode(encoded, [str]) + assert decoded[0] == large_value + assert dc_codec.decoded_count == 1 + assert driver._retrieve_calls == 1 + + +class TestMultiDriver: + """Tests for ExternalStorage with multiple drivers.""" + + async def test_selector_always_first_driver_handles_all_stores(self): + """A selector that always picks the first driver routes all store + operations there. The second driver is never called for store.""" + first = InMemoryTestDriver("driver-first") + second = InMemoryTestDriver("driver-second") + + converter = DataConverter( + external_storage=ExternalStorage( + drivers=[first, second], + driver_selector=lambda _ctx, _p: first, + payload_size_threshold=50, + ) + ) + + large = "x" * 200 + encoded = await converter.encode([large]) + + assert first._store_calls == 1 + assert second._store_calls == 0 + + # The reference in history names the first driver. + ref = JSONPlainPayloadConverter( + encoding="json/external-storage-reference" + ).from_payload(encoded[0], _StorageReference) + assert ref.driver_name == "driver-first" + + # Retrieval also goes to the first driver. + decoded = await converter.decode(encoded, [str]) + assert decoded[0] == large + assert first._retrieve_calls == 1 + assert second._retrieve_calls == 0 + + async def test_no_selector_second_driver_is_retrieve_only(self): + """A driver that is second in the list acts as a retrieve-only driver. + References are resolved by name, not by position, so a payload stored + by driver-b is retrieved correctly even when driver-a is listed first.""" + driver_a = InMemoryTestDriver("driver-a") + driver_b = InMemoryTestDriver("driver-b") + + # Store with driver-b as the sole driver. + store_converter = DataConverter( + external_storage=ExternalStorage( + drivers=[driver_b], + payload_size_threshold=50, + ) + ) + large = "y" * 200 + encoded = await store_converter.encode([large]) + + # Retrieve with driver-a listed first, driver-b second. + # The "driver-b" name in the reference must route to driver-b. + retrieve_converter = DataConverter( + external_storage=ExternalStorage( + drivers=[driver_a, driver_b], + driver_selector=lambda _ctx, _p: driver_a, + payload_size_threshold=50, + ) + ) + decoded = await retrieve_converter.decode(encoded, [str]) + assert decoded[0] == large + assert driver_a._retrieve_calls == 0 # never consulted + assert driver_b._retrieve_calls == 1 + + async def test_selector_routes_payloads_to_different_drivers_in_single_batch(self): + """When a selector routes different payloads to different drivers, a + single encode([v1, v2, ...]) call batches payloads per driver so each + driver receives exactly one store() call regardless of how many + payloads are routed to it.""" + driver_a = InMemoryTestDriver("driver-a") + driver_b = InMemoryTestDriver("driver-b") + + # Route payloads that serialise to < 500 bytes to driver_a, larger ones + # to driver_b. + def selector(_ctx: object, payload: Payload) -> StorageDriver: + return driver_a if payload.ByteSize() < 500 else driver_b + + converter = DataConverter( + external_storage=ExternalStorage( + drivers=[driver_a, driver_b], + driver_selector=selector, + payload_size_threshold=50, + ) + ) + + small_ext = "a" * 100 # above threshold, serialises well below 500 B + large_ext = "b" * 1000 # serialises above 500 B + + # Encode both values in a single call — they should be batched per driver. + encoded = await converter.encode([small_ext, large_ext]) + assert driver_a._store_calls == 1 # one batched call, not two individual ones + assert driver_b._store_calls == 1 + + # Full round-trip. + decoded = await converter.decode(encoded, [str, str]) + assert decoded == [small_ext, large_ext] + assert driver_a._retrieve_calls == 1 + assert driver_b._retrieve_calls == 1 + + async def test_selector_returning_none_keeps_payload_inline(self): + """A selector that returns None for a payload leaves it stored inline + in workflow history rather than offloading it to any driver, even when + the payload exceeds the size threshold.""" + driver = InMemoryTestDriver("driver-a") + + converter = DataConverter( + external_storage=ExternalStorage( + drivers=[driver], + driver_selector=lambda _ctx, _payload: None, + payload_size_threshold=50, + ) + ) + + large = "x" * 200 + encoded = await converter.encode([large]) + + assert driver._store_calls == 0 + assert len(encoded[0].external_payloads) == 0 # payload is inline + + decoded = await converter.decode(encoded, [str]) + assert decoded[0] == large + assert driver._retrieve_calls == 0 + + async def test_selector_returns_unregistered_driver_raises(self): + """A selector that returns a driver instance not present in + ExternalStorage.drivers raises ValueError during encode.""" + registered = InMemoryTestDriver("registered") + unregistered = InMemoryTestDriver("unregistered") + + converter = DataConverter( + external_storage=ExternalStorage( + drivers=[registered], + driver_selector=lambda _ctx, _payload: unregistered, + payload_size_threshold=50, + ) + ) + + with pytest.raises(ValueError): + await converter.encode(["x" * 200]) + + async def test_selector_dispatches_drivers_concurrently(self): + started_a = asyncio.Event() + started_b = asyncio.Event() + + class BarrierDriver(InMemoryTestDriver): + def __init__( + self, name: str, my_event: asyncio.Event, their_event: asyncio.Event + ): + super().__init__(name) + self._my_event = my_event + self._their_event = their_event + + async def store( + self, + context: StorageDriverStoreContext, + payloads: Sequence[Payload], + ) -> list[StorageDriverClaim]: + self._my_event.set() + await asyncio.wait_for(self._their_event.wait(), timeout=2.0) + return await super().store(context, payloads) + + driver_a = BarrierDriver("driver-a", started_a, started_b) + driver_b = BarrierDriver("driver-b", started_b, started_a) + + def selector(_ctx: object, payload: Payload) -> StorageDriver: + return driver_a if payload.ByteSize() < 500 else driver_b + + converter = DataConverter( + external_storage=ExternalStorage( + drivers=[driver_a, driver_b], + driver_selector=selector, + payload_size_threshold=None, + ) + ) + + small_ext = "a" * 100 # routes to driver-a + large_ext = "b" * 1000 # routes to driver-b + + # This will deadlock (and timeout) if the two store() calls are not + # dispatched concurrently. + encoded = await asyncio.wait_for( + converter.encode([small_ext, large_ext]), timeout=5.0 + ) + + decoded = await converter.decode(encoded, [str, str]) + assert decoded == [small_ext, large_ext] + + def test_multiple_drivers_without_selector_raises(self): + """Registering more than one driver without a driver_selector raises + ValueError immediately when constructing ExternalStorage.""" + first = InMemoryTestDriver("driver-a") + second = InMemoryTestDriver("driver-b") + + with pytest.raises( + ValueError, + match=r"^ExternalStorage\.driver_selector must be specified if multiple drivers are registered\.$", + ): + ExternalStorage( + drivers=[first, second], + payload_size_threshold=50, + ) + + def test_duplicate_driver_names_raises(self): + """Registering two drivers with identical names raises ValueError immediately + when constructing ExternalStorage.""" + first = InMemoryTestDriver("dup-name") + duplicate = InMemoryTestDriver("dup-name") + + with pytest.raises( + ValueError, + match=r"^ExternalStorage\.drivers contains multiple drivers with name 'dup-name'\. Each driver must have a unique name\.$", + ): + ExternalStorage( + drivers=[first, duplicate], + driver_selector=lambda _ctx, _p: first, + payload_size_threshold=50, + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/worker/test_extstore.py b/tests/worker/test_extstore.py new file mode 100644 index 000000000..921e8a3f1 --- /dev/null +++ b/tests/worker/test_extstore.py @@ -0,0 +1,554 @@ +import dataclasses +import uuid +from collections.abc import Sequence +from dataclasses import dataclass +from datetime import timedelta + +import pytest + +import temporalio +import temporalio.converter +from temporalio import activity, workflow +from temporalio.api.common.v1 import Payload +from temporalio.client import Client, WorkflowFailureError, WorkflowHandle +from temporalio.common import RetryPolicy +from temporalio.converter import ( + ExternalStorage, + StorageDriverClaim, + StorageDriverRetrieveContext, + StorageDriverStoreContext, + StorageWarning, +) +from temporalio.exceptions import ActivityError, ApplicationError +from temporalio.testing._workflow import WorkflowEnvironment +from temporalio.worker import Replayer +from tests.helpers import assert_task_fail_eventually, new_worker +from tests.test_extstore import InMemoryTestDriver + + +@dataclass(frozen=True) +class ExtStoreActivityInput: + input_data: str + output_size: int + pass + + +# --------------------------------------------------------------------------- +# Chained-activity scenario +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class ProcessDataInput: + """Input for the first activity: generate a large result.""" + + size: int + + +@dataclass(frozen=True) +class SummarizeInput: + """Input for the second activity: receives the large result from the first.""" + + data: str + + +@activity.defn +async def process_data(input: ProcessDataInput) -> str: + """Produces a large string result that will be stored externally.""" + return "x" * input.size + + +@activity.defn +async def summarize(input: SummarizeInput) -> str: + """Receives the large result and returns a short summary.""" + return f"received {len(input.data)} bytes" + + +@workflow.defn +class ChainedExtStoreWorkflow: + """Workflow that passes a large activity result directly into a second activity. + + This mirrors a common customer pattern: activity A produces a large payload + (e.g. a fetched document or ML inference result) which is too big to store + inline in workflow history, and is then consumed by activity B. External + storage should transparently offload the payload between the two steps + without any special handling in the workflow code. + """ + + @workflow.run + async def run(self, payload_size: int) -> str: + large_result = await workflow.execute_activity( + process_data, + ProcessDataInput(size=payload_size), + schedule_to_close_timeout=timedelta(seconds=10), + ) + return await workflow.execute_activity( + summarize, + SummarizeInput(data=large_result), + schedule_to_close_timeout=timedelta(seconds=10), + ) + + +@activity.defn +async def ext_store_activity( + input: ExtStoreActivityInput, +) -> str: + return "ao" * int(input.output_size / 2) + + +@dataclass(frozen=True) +class ExtStoreWorkflowInput: + input_data: str + activity_input_size: int + activity_output_size: int + output_size: int + max_activity_attempts: int | None = None + + +@workflow.defn +class ExtStoreWorkflow: + @workflow.run + async def run(self, input: ExtStoreWorkflowInput) -> str: + retry_policy = ( + RetryPolicy(maximum_attempts=input.max_activity_attempts) + if input.max_activity_attempts is not None + else None + ) + await workflow.execute_activity( + ext_store_activity, + ExtStoreActivityInput( + input_data="ai" * int(input.activity_input_size / 2), + output_size=input.activity_output_size, + ), + schedule_to_close_timeout=timedelta(seconds=3), + retry_policy=retry_policy, + ) + return "wo" * int(input.output_size / 2) + + +class BadTestDriver(InMemoryTestDriver): + def __init__( + self, + driver_name: str = "bad-driver", + no_store: bool = False, + no_retrieve: bool = False, + raise_payload_not_found: bool = False, + ): + super().__init__(driver_name) + self._no_store = no_store + self._no_retrieve = no_retrieve + self._raise_payload_not_found = raise_payload_not_found + + async def store( + self, + context: StorageDriverStoreContext, + payloads: Sequence[Payload], + ) -> list[StorageDriverClaim]: + if self._no_store: + return [] + return await super().store(context, payloads) + + async def retrieve( + self, + context: StorageDriverRetrieveContext, + claims: Sequence[StorageDriverClaim], + ) -> list[Payload]: + if self._no_retrieve: + return [] + if self._raise_payload_not_found: + raise ApplicationError( + "Payload not found because the bucket does not exist.", + type="BucketNotFoundError", + non_retryable=True, + ) + return await super().retrieve(context, claims) + + +async def test_extstore_activity_input_no_retrieve( + env: WorkflowEnvironment, +): + """When the driver's retrieve returns no payloads for an externalized + activity input, the activity fails and the workflow terminates with a + WorkflowFailureError wrapping an ActivityError.""" + driver = BadTestDriver(no_retrieve=True) + + client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=1024, + ), + ), + ) + + async with new_worker( + client, ExtStoreWorkflow, activities=[ext_store_activity] + ) as worker: + handle = await client.start_workflow( + ExtStoreWorkflow.run, + ExtStoreWorkflowInput( + input_data="workflow input", + activity_input_size=1000, + activity_output_size=10, + output_size=10, + max_activity_attempts=1, + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + with pytest.raises(WorkflowFailureError) as err: + await handle.result() + + assert isinstance(err.value.cause, ActivityError) + assert isinstance(err.value.cause.cause, ApplicationError) + assert err.value.cause.cause.message == "Failed decoding arguments" + + +async def test_extstore_activity_result_no_store( + env: WorkflowEnvironment, +): + """When the driver's store returns no claims for an activity result that + exceeds the size threshold, the activity fails to complete and the workflow + terminates with a WorkflowFailureError wrapping an ActivityError.""" + driver = BadTestDriver(no_store=True) + + client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=1024, + ), + ), + ) + + async with new_worker( + client, ExtStoreWorkflow, activities=[ext_store_activity] + ) as worker: + handle = await client.start_workflow( + ExtStoreWorkflow.run, + ExtStoreWorkflowInput( + input_data="workflow input", + activity_input_size=10, + activity_output_size=1000, + output_size=10, + max_activity_attempts=1, + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + with pytest.raises(WorkflowFailureError) as err: + await handle.result() + + assert isinstance(err.value.cause, ActivityError) + assert isinstance(err.value.cause.cause, ApplicationError) + assert ( + err.value.cause.cause.message + == "Driver 'bad-driver' returned 0 claims, expected 1" + ) + assert err.value.cause.cause.type == "ValueError" + + +async def test_extstore_worker_missing_driver( + env: WorkflowEnvironment, +): + """Validate that when a worker is provided a workflow history with + external storage references and the worker is not configured for external + storage, it will cause a workflow task failure. + """ + driver = InMemoryTestDriver() + + far_client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=1024, + ), + ), + ) + + worker_client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + ) + + async with new_worker( + worker_client, ExtStoreWorkflow, activities=[ext_store_activity] + ) as worker: + handle = await far_client.start_workflow( + ExtStoreWorkflow.run, + ExtStoreWorkflowInput( + input_data="wi" * 1024, + activity_input_size=10, + activity_output_size=10, + output_size=10, + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + await assert_task_fail_eventually(handle) + + +async def test_extstore_payload_not_found_fails_workflow( + env: WorkflowEnvironment, +): + """When a non-retryable ApplicationError is raised while retrieving workflow input, + the workflow must fail terminally (not retry as a task failure). + """ + client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=ExternalStorage( + drivers=[BadTestDriver(raise_payload_not_found=True)], + payload_size_threshold=1024, + ), + ), + ) + + async with new_worker( + client, ExtStoreWorkflow, activities=[ext_store_activity] + ) as worker: + handle = await client.start_workflow( + ExtStoreWorkflow.run, + ExtStoreWorkflowInput( + input_data="wi" * 512, # exceeds 1024-byte threshold + activity_input_size=10, + activity_output_size=10, + output_size=10, + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + + with pytest.raises(WorkflowFailureError) as exc_info: + await handle.result() + + assert isinstance(exc_info.value.cause, ApplicationError) + assert ( + exc_info.value.cause.message + == "Payload not found because the bucket does not exist." + ) + assert exc_info.value.cause.type == "BucketNotFoundError" + assert exc_info.value.cause.non_retryable is True + + +async def _run_extstore_workflow_and_fetch_history( + env: WorkflowEnvironment, + driver: InMemoryTestDriver, + *, + input_data: str, + activity_output_size: int = 10, +) -> WorkflowHandle: + """Helper: run ExtStoreWorkflow with the given driver and return its history handle.""" + extstore_client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=512, + ), + ), + ) + async with new_worker( + extstore_client, ExtStoreWorkflow, activities=[ext_store_activity] + ) as worker: + handle = await extstore_client.start_workflow( + ExtStoreWorkflow.run, + ExtStoreWorkflowInput( + input_data=input_data, + activity_input_size=10, + activity_output_size=activity_output_size, + output_size=10, + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + await handle.result() + return handle + + +async def test_replay_extstore_history_fails_without_extstore( + env: WorkflowEnvironment, +) -> None: + """A history with externalized workflow input fails to replay when the + Replayer has no external storage configured.""" + driver = InMemoryTestDriver() + handle = await _run_extstore_workflow_and_fetch_history( + env, + driver, + input_data="wi" * 512, # exceeds 512-byte threshold + ) + history = await handle.fetch_history() + + # Replay without external storage — the reference payload cannot be decoded. + # The middleware emits a StorageWarning when it encounters a reference payload + # with no driver configured. + with pytest.warns( + StorageWarning, + match=r"^\[TMPRL1105\] Detected externally stored payload\(s\) but external storage is not configured\.$", + ): + result = await Replayer(workflows=[ExtStoreWorkflow]).replay_workflow( + history, raise_on_replay_failure=False + ) + # Must be a task-failure RuntimeError, not a NondeterminismError — external + # storage decode failures are distinct from workflow code changes. + assert isinstance(result.replay_failure, RuntimeError) + assert not isinstance(result.replay_failure, workflow.NondeterminismError) + # The message is the full activation-completion failure string; the + # "Failed decoding arguments" text from _convert_payloads is embedded in it. + assert "Failed decoding arguments" in result.replay_failure.args[0] + + +async def test_replay_extstore_history_succeeds_with_correct_extstore( + env: WorkflowEnvironment, +) -> None: + """A history with externalized workflow input replays successfully when the + Replayer is configured with the same storage driver that holds the data.""" + driver = InMemoryTestDriver() + handle = await _run_extstore_workflow_and_fetch_history( + env, driver, input_data="wi" * 512 + ) + history = await handle.fetch_history() + + # Replay with the same populated driver — must succeed. + await Replayer( + workflows=[ExtStoreWorkflow], + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=512, + ), + ), + ).replay_workflow(history) + + +async def test_replay_extstore_history_fails_with_empty_driver( + env: WorkflowEnvironment, +) -> None: + """A history with external storage references fails to replay when the + Replayer has external storage configured but the driver holds no data + (simulates pointing at the wrong backend or a purged store).""" + driver = InMemoryTestDriver() + handle = await _run_extstore_workflow_and_fetch_history( + env, driver, input_data="wi" * 512 + ) + history = await handle.fetch_history() + + # Replay with a fresh empty driver — retrieval will fail. + result = await Replayer( + workflows=[ExtStoreWorkflow], + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=ExternalStorage( + drivers=[InMemoryTestDriver()], + payload_size_threshold=512, + ), + ), + ).replay_workflow(history, raise_on_replay_failure=False) + # InMemoryTestDriver raises ApplicationError for absent keys. + # ApplicationError is re-raised without wrapping, so it propagates + # through decode_activation (before the workflow task runs). The core SDK + # receives an activation failure, issues a FailWorkflow command, but the + # next history event is ActivityTaskScheduled — causing a NondeterminismError. + assert isinstance(result.replay_failure, workflow.NondeterminismError) + + +async def test_replay_extstore_activity_result_fails_without_extstore( + env: WorkflowEnvironment, +) -> None: + """A history where only the activity result was stored externally (the + workflow input is small enough to be inline) also fails to replay without + external storage — verifying that mid-workflow decode failures are caught.""" + driver = InMemoryTestDriver() + handle = await _run_extstore_workflow_and_fetch_history( + env, + driver, + input_data="small", # well under 512 bytes — stays inline + activity_output_size=2048, # 2 KB result — stored externally + ) + history = await handle.fetch_history() + + # Replay without external storage. The workflow input decodes fine, but + # when the ActivityTaskCompleted result is delivered back to the workflow + # coroutine it cannot be decoded. + with pytest.warns( + StorageWarning, + match=r"^\[TMPRL1105\] Detected externally stored payload\(s\) but external storage is not configured\.$", + ): + result = await Replayer(workflows=[ExtStoreWorkflow]).replay_workflow( + history, raise_on_replay_failure=False + ) + # Mid-workflow decode failure is still a task failure (RuntimeError), not + # nondeterminism. + assert isinstance(result.replay_failure, RuntimeError) + assert not isinstance(result.replay_failure, workflow.NondeterminismError) + # The message is the full activation-completion failure string; the + # "Failed decoding arguments" text from _convert_payloads is embedded in it. + assert "Failed decoding arguments" in result.replay_failure.args[0] + + +async def test_extstore_chained_activities( + env: WorkflowEnvironment, +) -> None: + """Large activity output is transparently offloaded and passed to a second activity. + + This is a representative customer scenario: activity A returns a payload that + exceeds the size threshold (e.g. a fetched document), external storage offloads + it so it never bloats workflow history, and activity B receives it as its input + without any special handling in the workflow code. + """ + driver = InMemoryTestDriver() + + client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=ExternalStorage( + drivers=[driver], + payload_size_threshold=1024, # 1 KB threshold + ), + ), + ) + + # process_data returns 10 KB — well above the 1 KB threshold. + payload_size = 10_000 + + async with new_worker( + client, + ChainedExtStoreWorkflow, + activities=[process_data, summarize], + ) as worker: + result = await client.execute_workflow( + ChainedExtStoreWorkflow.run, + payload_size, + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=30), + ) + + # The second activity received the full payload and summarized it correctly. + assert result == f"received {payload_size} bytes" + + # External storage was actually used: the large activity result and its + # re-use as the second activity's input should have triggered at least two + # round-trips (one store on completion, one retrieve on the next WFT). + assert driver._store_calls == 2 + assert driver._retrieve_calls == 2