diff --git a/src/orcapod/__init__.py b/src/orcapod/__init__.py index 85956537..8afbf1ef 100644 --- a/src/orcapod/__init__.py +++ b/src/orcapod/__init__.py @@ -2,6 +2,7 @@ FunctionPod, function_pod, ) +from .core.sources.arrow_table_source import ArrowTableSource from .pipeline import Pipeline # Subpackage re-exports for clean public API @@ -13,6 +14,7 @@ from . import types # noqa: F401 __all__ = [ + "ArrowTableSource", "FunctionPod", "function_pod", "Pipeline", diff --git a/src/orcapod/core/cached_function_pod.py b/src/orcapod/core/cached_function_pod.py index 76d2a27d..c6d9a00c 100644 --- a/src/orcapod/core/cached_function_pod.py +++ b/src/orcapod/core/cached_function_pod.py @@ -65,7 +65,11 @@ def record_path(self) -> tuple[str, ...]: return self._cache.record_path def process_packet( - self, tag: TagProtocol, packet: PacketProtocol + self, + tag: TagProtocol, + packet: PacketProtocol, + *, + logger: Any = None, ) -> tuple[TagProtocol, PacketProtocol | None]: """Process a packet with pod-level caching. @@ -77,6 +81,7 @@ def process_packet( Args: tag: The tag associated with the packet. packet: The input packet to process. + logger: Optional packet execution logger. Returns: A ``(tag, output_packet)`` tuple; output_packet is ``None`` @@ -84,10 +89,10 @@ def process_packet( """ cached = self._cache.lookup(packet) if cached is not None: - logger.info("Pod-level cache hit") + _logger.info("Pod-level cache hit") return tag, cached - tag, output = self._function_pod.process_packet(tag, packet) + tag, output = self._function_pod.process_packet(tag, packet, logger=logger) if output is not None: pf = self._function_pod.packet_function self._cache.store( @@ -100,7 +105,11 @@ def process_packet( return tag, output async def async_process_packet( - self, tag: TagProtocol, packet: PacketProtocol + self, + tag: TagProtocol, + packet: PacketProtocol, + *, + logger: Any = None, ) -> tuple[TagProtocol, PacketProtocol | None]: """Async counterpart of ``process_packet``. 
@@ -110,10 +119,12 @@ async def async_process_packet( """ cached = self._cache.lookup(packet) if cached is not None: - logger.info("Pod-level cache hit") + _logger.info("Pod-level cache hit") return tag, cached - tag, output = await self._function_pod.async_process_packet(tag, packet) + tag, output = await self._function_pod.async_process_packet( + tag, packet, logger=logger + ) if output is not None: pf = self._function_pod.packet_function self._cache.store( @@ -152,3 +163,7 @@ def process( input_stream=input_stream, label=label, ) + + +# Module-level logger alias to avoid conflict with `logger` kwarg +_logger = logger diff --git a/src/orcapod/core/executors/base.py b/src/orcapod/core/executors/base.py index 386cbe1a..513fce4b 100644 --- a/src/orcapod/core/executors/base.py +++ b/src/orcapod/core/executors/base.py @@ -7,6 +7,7 @@ if TYPE_CHECKING: from orcapod.protocols.core_protocols import PacketFunctionProtocol, PacketProtocol + from orcapod.protocols.observability_protocols import PacketExecutionLoggerProtocol class PacketFunctionExecutorBase(ABC): @@ -46,28 +47,33 @@ def supports(self, packet_function_type_id: str) -> bool: @abstractmethod def execute( self, - packet_function: PacketFunctionProtocol, - packet: PacketProtocol, - ) -> PacketProtocol | None: + packet_function: "PacketFunctionProtocol", + packet: "PacketProtocol", + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> "PacketProtocol | None": """Synchronously execute *packet_function* on *packet*. Implementations should call ``packet_function.direct_call(packet)`` to invoke the function's native computation, bypassing executor - routing. + routing. If a logger is provided, captured I/O is recorded to it. + On failure, the original exception is re-raised after recording. """ ... 
async def async_execute( self, - packet_function: PacketFunctionProtocol, - packet: PacketProtocol, - ) -> PacketProtocol | None: + packet_function: "PacketFunctionProtocol", + packet: "PacketProtocol", + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> "PacketProtocol | None": """Asynchronous counterpart of ``execute``. The default implementation delegates to ``execute`` synchronously. Subclasses should override for truly async execution. """ - return self.execute(packet_function, packet) + return self.execute(packet_function, packet, logger=logger) @property def supports_concurrent_execution(self) -> bool: @@ -96,16 +102,21 @@ def execute_callable( fn: Callable[..., Any], kwargs: dict[str, Any], executor_options: dict[str, Any] | None = None, + *, + logger: "PacketExecutionLoggerProtocol | None" = None, ) -> Any: - """Synchronously execute *fn* with *kwargs*. + """Synchronously execute *fn* with *kwargs*, returning the raw result. - Default implementation calls ``fn(**kwargs)`` in-process. - Subclasses should override for remote/distributed execution. + Default implementation calls ``fn(**kwargs)`` with no capture. + Exceptions propagate to the caller. Subclasses (e.g. + ``LocalExecutor``, ``RayExecutor``) override to add I/O capture + and logger recording. Args: fn: The Python callable to execute. kwargs: Keyword arguments to pass to *fn*. executor_options: Optional per-call options. + logger: Optional logger to record captured I/O. Returns: The raw return value of *fn*. @@ -117,14 +128,15 @@ async def async_execute_callable( fn: Callable[..., Any], kwargs: dict[str, Any], executor_options: dict[str, Any] | None = None, + *, + logger: "PacketExecutionLoggerProtocol | None" = None, ) -> Any: - """Asynchronously execute *fn* with *kwargs*. + """Asynchronously execute *fn* with *kwargs*, returning the raw result. Default implementation delegates to ``execute_callable`` - synchronously. Subclasses should override for truly async - execution. 
+ synchronously. Subclasses should override for truly async execution. """ - return self.execute_callable(fn, kwargs, executor_options) + return self.execute_callable(fn, kwargs, executor_options, logger=logger) def get_execution_data(self) -> dict[str, Any]: """Return metadata describing the execution environment. diff --git a/src/orcapod/core/executors/local.py b/src/orcapod/core/executors/local.py index fc97e477..ce13be5e 100644 --- a/src/orcapod/core/executors/local.py +++ b/src/orcapod/core/executors/local.py @@ -2,6 +2,7 @@ import asyncio import inspect +import traceback as _traceback_module from collections.abc import Callable from typing import TYPE_CHECKING, Any @@ -9,6 +10,7 @@ if TYPE_CHECKING: from orcapod.protocols.core_protocols import PacketFunctionProtocol, PacketProtocol + from orcapod.protocols.observability_protocols import PacketExecutionLoggerProtocol class LocalExecutor(PacketFunctionExecutorBase): @@ -27,16 +29,20 @@ def supported_function_type_ids(self) -> frozenset[str]: def execute( self, - packet_function: PacketFunctionProtocol, - packet: PacketProtocol, - ) -> PacketProtocol | None: + packet_function: "PacketFunctionProtocol", + packet: "PacketProtocol", + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> "PacketProtocol | None": return packet_function.direct_call(packet) async def async_execute( self, - packet_function: PacketFunctionProtocol, - packet: PacketProtocol, - ) -> PacketProtocol | None: + packet_function: "PacketFunctionProtocol", + packet: "PacketProtocol", + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> "PacketProtocol | None": return await packet_function.direct_async_call(packet) # -- PythonFunctionExecutorProtocol -- @@ -46,10 +52,28 @@ def execute_callable( fn: Callable[..., Any], kwargs: dict[str, Any], executor_options: dict[str, Any] | None = None, + *, + logger: "PacketExecutionLoggerProtocol | None" = None, ) -> Any: - if inspect.iscoroutinefunction(fn): - return 
self._run_async_sync(fn, kwargs) - return fn(**kwargs) + from orcapod.pipeline.logging_capture import LocalCaptureContext + + ctx = LocalCaptureContext() + with ctx: + try: + if inspect.iscoroutinefunction(fn): + raw_result = self._run_async_sync(fn, kwargs) + else: + raw_result = fn(**kwargs) + except Exception: + tb = _traceback_module.format_exc() + captured = ctx.get_captured(success=False, tb=tb) + if logger is not None: + logger.record(captured) + raise + captured = ctx.get_captured(success=True) + if logger is not None: + logger.record(captured) + return raw_result @staticmethod def _run_async_sync(fn: Callable[..., Any], kwargs: dict[str, Any]) -> Any: @@ -69,11 +93,36 @@ async def async_execute_callable( fn: Callable[..., Any], kwargs: dict[str, Any], executor_options: dict[str, Any] | None = None, + *, + logger: "PacketExecutionLoggerProtocol | None" = None, ) -> Any: - if inspect.iscoroutinefunction(fn): - return await fn(**kwargs) - loop = asyncio.get_running_loop() - return await loop.run_in_executor(None, lambda: fn(**kwargs)) + from orcapod.pipeline.logging_capture import LocalCaptureContext + + ctx = LocalCaptureContext() + with ctx: + try: + if inspect.iscoroutinefunction(fn): + raw_result = await fn(**kwargs) + else: + import contextvars + import functools + + loop = asyncio.get_running_loop() + task_ctx = contextvars.copy_context() + raw_result = await loop.run_in_executor( + None, + functools.partial(task_ctx.run, fn, **kwargs), + ) + except Exception: + tb = _traceback_module.format_exc() + captured = ctx.get_captured(success=False, tb=tb) + if logger is not None: + logger.record(captured) + raise + captured = ctx.get_captured(success=True) + if logger is not None: + logger.record(captured) + return raw_result def with_options(self, **opts: Any) -> "LocalExecutor": """Return a new ``LocalExecutor``. 
diff --git a/src/orcapod/core/executors/ray.py b/src/orcapod/core/executors/ray.py index 998c2193..89dd9755 100644 --- a/src/orcapod/core/executors/ray.py +++ b/src/orcapod/core/executors/ray.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from orcapod.core.packet_function import PythonPacketFunction from orcapod.protocols.core_protocols import PacketFunctionProtocol, PacketProtocol + from orcapod.protocols.observability_protocols import PacketExecutionLoggerProtocol class RayExecutor(PacketFunctionExecutorBase): @@ -77,9 +78,38 @@ def __init__( self._remote_opts.update(ray_remote_opts) def _ensure_ray_initialized(self) -> None: - """Initialize Ray if it has not been initialized yet.""" + """Initialize Ray if it has not been initialized yet. + + Also registers a cloudpickle dispatch for ``logging.Logger`` so that + user functions referencing loggers can be sent to Ray workers that + do not have orcapod installed. + + By default cloudpickle serializes Logger instances by value, which + traverses the parent chain to the root logger. After + ``install_capture_streams()`` the root logger has a + ``ContextVarLoggingHandler`` from ``orcapod``. Workers without + orcapod cannot deserialize that class. + + Registering loggers as ``(logging.getLogger, (name,))`` is the + correct semantic — loggers are name-keyed singletons — and produces + no orcapod dependency in the pickled bytes. 
+ """ + import logging import ray + try: + import cloudpickle + + def _pickle_logger(l: logging.Logger) -> tuple: + # Root logger has name "root" but must be fetched as "" + name = "" if isinstance(l, logging.RootLogger) else l.name + return logging.getLogger, (name,) + + cloudpickle.CloudPickler.dispatch[logging.Logger] = _pickle_logger + cloudpickle.CloudPickler.dispatch[logging.RootLogger] = _pickle_logger + except Exception: + pass # cloudpickle not available or API changed — best effort + if not ray.is_initialized(): if self._ray_address is not None: ray.init(address=self._ray_address) @@ -102,7 +132,7 @@ def _build_remote_opts(self) -> dict[str, Any]: return dict(self._remote_opts) def _as_python_packet_function( - self, packet_function: PacketFunctionProtocol + self, packet_function: "PacketFunctionProtocol" ) -> "PythonPacketFunction": """Return *packet_function* cast to ``PythonPacketFunction``, or raise. @@ -122,69 +152,188 @@ def _as_python_packet_function( def execute( self, - packet_function: PacketFunctionProtocol, - packet: PacketProtocol, - ) -> PacketProtocol | None: - import ray - + packet_function: "PacketFunctionProtocol", + packet: "PacketProtocol", + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> "PacketProtocol | None": pf = self._as_python_packet_function(packet_function) if not pf.is_active(): return None - self._ensure_ray_initialized() - - kwargs = packet.as_dict() - remote_fn = ray.remote(**self._build_remote_opts())(pf._function) - ref = remote_fn.remote(**kwargs) - raw_result = ray.get(ref) - return pf._build_output_packet(raw_result) + raw = self.execute_callable(pf._function, packet.as_dict(), logger=logger) + return pf._build_output_packet(raw) async def async_execute( self, - packet_function: PacketFunctionProtocol, - packet: PacketProtocol, - ) -> PacketProtocol | None: - import ray - + packet_function: "PacketFunctionProtocol", + packet: "PacketProtocol", + *, + logger: "PacketExecutionLoggerProtocol | None" = 
None, + ) -> "PacketProtocol | None": pf = self._as_python_packet_function(packet_function) if not pf.is_active(): return None - self._ensure_ray_initialized() - - kwargs = packet.as_dict() - remote_fn = ray.remote(**self._build_remote_opts())(pf._function) - ref = remote_fn.remote(**kwargs) - raw_result = await asyncio.wrap_future(ref.future()) - return pf._build_output_packet(raw_result) + raw = await self.async_execute_callable( + pf._function, packet.as_dict(), logger=logger + ) + return pf._build_output_packet(raw) # -- PythonFunctionExecutorProtocol -- + @staticmethod + def _make_capture_wrapper() -> Callable[..., Any]: + """Return an inline capture wrapper suitable for Ray remote execution. + + The wrapper is defined as a closure (not a module-level import) so that + cloudpickle serializes it by bytecode rather than by module reference. + This means the Ray cluster workers do **not** need ``orcapod`` installed + — only the standard library is required on the worker side. + + The wrapper returns a plain 6-tuple ``(raw_result, stdout, stderr, + python_logs, traceback_str, success)`` so no orcapod types cross the + Ray object store; the driver reconstructs :class:`CapturedLogs` from + the tuple. 
+ """ + def _capture(fn: Any, kwargs: dict) -> tuple: + import io + import logging + import os + import sys + import tempfile + import traceback as _tb + + stdout_tmp = tempfile.TemporaryFile() + stderr_tmp = tempfile.TemporaryFile() + orig_stdout_fd = os.dup(1) + orig_stderr_fd = os.dup(2) + orig_sys_stdout = sys.stdout + orig_sys_stderr = sys.stderr + sys_stdout_buf = io.StringIO() + sys_stderr_buf = io.StringIO() + log_records: list = [] + + fmt = logging.Formatter("%(levelname)s:%(name)s:%(message)s") + + class _H(logging.Handler): + def emit(self, record: logging.LogRecord) -> None: + log_records.append(fmt.format(record)) + + handler = _H() + root_logger = logging.getLogger() + orig_level = root_logger.level + root_logger.setLevel(logging.DEBUG) + root_logger.addHandler(handler) + + raw_result = None + success = True + tb_str = None + try: + sys.stdout.flush() + sys.stderr.flush() + os.dup2(stdout_tmp.fileno(), 1) + os.dup2(stderr_tmp.fileno(), 2) + sys.stdout = sys_stdout_buf + sys.stderr = sys_stderr_buf + try: + raw_result = fn(**kwargs) + except Exception: + success = False + tb_str = _tb.format_exc() + finally: + sys.stdout = orig_sys_stdout + sys.stderr = orig_sys_stderr + os.dup2(orig_stdout_fd, 1) + os.dup2(orig_stderr_fd, 2) + os.close(orig_stdout_fd) + os.close(orig_stderr_fd) + root_logger.removeHandler(handler) + root_logger.setLevel(orig_level) + stdout_tmp.seek(0) + stderr_tmp.seek(0) + cap_stdout = ( + stdout_tmp.read().decode("utf-8", errors="replace") + + sys_stdout_buf.getvalue() + ) + cap_stderr = ( + stderr_tmp.read().decode("utf-8", errors="replace") + + sys_stderr_buf.getvalue() + ) + stdout_tmp.close() + stderr_tmp.close() + + return raw_result, cap_stdout, cap_stderr, "\n".join(log_records), tb_str, success + + return _capture + def execute_callable( self, fn: Callable[..., Any], kwargs: dict[str, Any], executor_options: dict[str, Any] | None = None, + *, + logger: "PacketExecutionLoggerProtocol | None" = None, ) -> Any: + """Execute 
*fn* on the Ray cluster with fd-level I/O capture. + + The capture wrapper is serialized by bytecode (not module reference) so + the Ray cluster workers do not need ``orcapod`` installed. + """ import ray + from orcapod.pipeline.logging_capture import CapturedLogs + self._ensure_ray_initialized() - remote_fn = ray.remote(**self._build_remote_opts())(fn) - ref = remote_fn.remote(**kwargs) - return ray.get(ref) + wrapper = self._make_capture_wrapper() + wrapper.__name__ = fn.__name__ + wrapper.__qualname__ = fn.__qualname__ + remote_fn = ray.remote(**self._build_remote_opts())(wrapper) + ref = remote_fn.remote(fn, kwargs) + raw, stdout, stderr, python_logs, tb, success = ray.get(ref) + + captured = CapturedLogs( + stdout=stdout, stderr=stderr, python_logs=python_logs, + traceback=tb, success=success, + ) + if logger is not None: + logger.record(captured) + if not success: + raise RuntimeError(tb or "Ray worker execution failed") + return raw async def async_execute_callable( self, fn: Callable[..., Any], kwargs: dict[str, Any], executor_options: dict[str, Any] | None = None, + *, + logger: "PacketExecutionLoggerProtocol | None" = None, ) -> Any: + """Async counterpart of :meth:`execute_callable`.""" import ray + from orcapod.pipeline.logging_capture import CapturedLogs + self._ensure_ray_initialized() - remote_fn = ray.remote(**self._build_remote_opts())(fn) - ref = remote_fn.remote(**kwargs) - return await asyncio.wrap_future(ref.future()) + wrapper = self._make_capture_wrapper() + wrapper.__name__ = fn.__name__ + wrapper.__qualname__ = fn.__qualname__ + remote_fn = ray.remote(**self._build_remote_opts())(wrapper) + ref = remote_fn.remote(fn, kwargs) + raw, stdout, stderr, python_logs, tb, success = await asyncio.wrap_future( + ref.future() + ) + + captured = CapturedLogs( + stdout=stdout, stderr=stderr, python_logs=python_logs, + traceback=tb, success=success, + ) + if logger is not None: + logger.record(captured) + if not success: + raise RuntimeError(tb or "Ray 
worker execution failed") + return raw def with_options(self, **opts: Any) -> "RayExecutor": """Return a new ``RayExecutor`` with the given options merged in. diff --git a/src/orcapod/core/function_pod.py b/src/orcapod/core/function_pod.py index 7e78cddd..f68e83ba 100644 --- a/src/orcapod/core/function_pod.py +++ b/src/orcapod/core/function_pod.py @@ -142,25 +142,37 @@ def _validate_input_schema(self, input_schema: Schema) -> None: ) def process_packet( - self, tag: TagProtocol, packet: PacketProtocol + self, + tag: TagProtocol, + packet: PacketProtocol, + *, + logger: Any = None, ) -> tuple[TagProtocol, PacketProtocol | None]: """Process a single packet using the pod's packet function. Args: tag: The tag associated with the packet. packet: The input packet to process. + logger: Optional :class:`PacketExecutionLoggerProtocol` for + recording captured I/O. Returns: A ``(tag, output_packet)`` tuple; output_packet is ``None`` if the function filters the packet out. """ - return tag, self.packet_function.call(packet) + result = self.packet_function.call(packet, logger=logger) + return tag, result async def async_process_packet( - self, tag: TagProtocol, packet: PacketProtocol + self, + tag: TagProtocol, + packet: PacketProtocol, + *, + logger: Any = None, ) -> tuple[TagProtocol, PacketProtocol | None]: """Async counterpart of ``process_packet``.""" - return tag, await self.packet_function.async_call(packet) + result = await self.packet_function.async_call(packet, logger=logger) + return tag, result def handle_input_streams(self, *streams: StreamProtocol) -> StreamProtocol: """Handle multiple input streams by joining them if necessary. @@ -354,6 +366,9 @@ async def process_one(tag: TagProtocol, packet: PacketProtocol) -> None: tag, result_packet = await self.async_process_packet(tag, packet) if result_packet is not None: await output.send((tag, result_packet)) + except Exception: + # Swallow packet-level errors so remaining packets continue. 
+ logger.debug("Packet processing failed, skipping", exc_info=True) finally: if sem is not None: sem.release() diff --git a/src/orcapod/core/nodes/function_node.py b/src/orcapod/core/nodes/function_node.py index 2bb8a911..56539430 100644 --- a/src/orcapod/core/nodes/function_node.py +++ b/src/orcapod/core/nodes/function_node.py @@ -39,7 +39,7 @@ import polars as pl import pyarrow as pa - from orcapod.pipeline.observer import ExecutionObserver + from orcapod.protocols.observability_protocols import ExecutionObserverProtocol else: pa = LazyModule("pyarrow") pl = LazyModule("polars") @@ -489,26 +489,35 @@ def execute_packet( Returns: A ``(tag, output_packet)`` tuple. """ - return self._process_packet_internal(tag, packet) + tag_out, result = self._process_packet_internal(tag, packet) + return tag_out, result def execute( self, input_stream: StreamProtocol, *, - observer: "ExecutionObserver | None" = None, + observer: "ExecutionObserverProtocol | None" = None, + error_policy: str = "continue", ) -> list[tuple[TagProtocol, PacketProtocol]]: """Execute all packets from a stream: compute, persist, and cache. Args: input_stream: The input stream to process. observer: Optional execution observer for hooks. + error_policy: ``"continue"`` (default) skips failed packets; + ``"fail_fast"`` re-raises on the first failure. Returns: Materialized list of (tag, output_packet) pairs, excluding - None outputs. + None outputs and failed packets. 
""" - if observer is not None: - observer.on_node_start(self) + node_label = self.label or "unknown" + node_hash = self.pipeline_hash().to_string() if self._pipeline_database is not None else "" + + ctx_observer = observer.contextualize(node_hash, node_label) if observer is not None else None + + if ctx_observer is not None: + ctx_observer.on_node_start(node_label, node_hash) # Gather entry IDs and check cache upstream_entries = [ @@ -518,25 +527,47 @@ def execute( entry_ids = [eid for _, _, eid in upstream_entries] cached = self.get_cached_results(entry_ids=entry_ids) + pp = self.pipeline_path if self._pipeline_database is not None else () + output: list[tuple[TagProtocol, PacketProtocol]] = [] for tag, packet, entry_id in upstream_entries: - if observer is not None: - observer.on_packet_start(self, tag, packet) + if ctx_observer is not None: + ctx_observer.on_packet_start(node_label, tag, packet) if entry_id in cached: tag_out, result = cached[entry_id] - if observer is not None: - observer.on_packet_end(self, tag, packet, result, cached=True) + if ctx_observer is not None: + ctx_observer.on_packet_end( + node_label, tag, packet, result, cached=True + ) output.append((tag_out, result)) else: - tag_out, result = self._process_packet_internal(tag, packet) - if observer is not None: - observer.on_packet_end(self, tag, packet, result, cached=False) - if result is not None: - output.append((tag_out, result)) - - if observer is not None: - observer.on_node_end(self) + pkt_logger = ( + ctx_observer.create_packet_logger(tag, packet, pipeline_path=pp) + if ctx_observer is not None + else None + ) + try: + tag_out, result = self._process_packet_internal( + tag, packet, logger=pkt_logger + ) + except Exception as exc: + if ctx_observer is not None: + ctx_observer.on_packet_crash(node_label, tag, packet, exc) + if error_policy == "fail_fast": + if ctx_observer is not None: + ctx_observer.on_node_end(node_label, node_hash) + raise + else: + if ctx_observer is not None: + 
ctx_observer.on_packet_end( + node_label, tag, packet, result, cached=False + ) + if result is not None: + output.append((tag_out, result)) + + if ctx_observer is not None: + ctx_observer.on_node_end(node_label, node_hash) return output def _process_packet_internal( @@ -544,21 +575,28 @@ def _process_packet_internal( tag: TagProtocol, packet: PacketProtocol, cache_index: int | None = None, + *, + logger: Any = None, ) -> tuple[TagProtocol, PacketProtocol | None]: """Core compute + persist + cache. Used by ``execute_packet``, ``execute``, and ``iter_packets``. No input validation is performed — the caller guarantees correctness. + Exceptions propagate to the caller. + + Returns: + A ``(tag, output_packet)`` 2-tuple. Args: tag: The input tag. packet: The input packet. cache_index: Optional explicit index for the internal cache. When ``None``, auto-assigns at ``len(_cached_output_packets)``. + logger: Optional packet execution logger. """ if self._cached_function_pod is not None: tag_out, output_packet = self._cached_function_pod.process_packet( - tag, packet + tag, packet, logger=logger ) if output_packet is not None: @@ -574,7 +612,9 @@ def _process_packet_internal( computed=result_computed, ) else: - tag_out, output_packet = self._function_pod.process_packet(tag, packet) + tag_out, output_packet = self._function_pod.process_packet( + tag, packet, logger=logger + ) # Cache internally and invalidate derived caches idx = ( @@ -676,23 +716,30 @@ async def _async_process_packet_internal( tag: TagProtocol, packet: PacketProtocol, cache_index: int | None = None, + *, + logger: Any = None, ) -> tuple[TagProtocol, PacketProtocol | None]: """Async counterpart of ``_process_packet_internal``. Computes via async path, writes pipeline provenance, and caches - internally — no schema validation. + internally — no schema validation. Exceptions propagate. + + Returns: + A ``(tag, output_packet)`` 2-tuple. Args: tag: The input tag. packet: The input packet. 
cache_index: Optional explicit index for the internal cache. When ``None``, auto-assigns at ``len(_cached_output_packets)``. + logger: Optional packet execution logger. """ if self._cached_function_pod is not None: - ( - tag_out, - output_packet, - ) = await self._cached_function_pod.async_process_packet(tag, packet) + tag_out, output_packet = ( + await self._cached_function_pod.async_process_packet( + tag, packet, logger=logger + ) + ) if output_packet is not None: result_computed = bool( @@ -707,8 +754,10 @@ async def _async_process_packet_internal( computed=result_computed, ) else: - tag_out, output_packet = await self._function_pod.async_process_packet( - tag, packet + tag_out, output_packet = ( + await self._function_pod.async_process_packet( + tag, packet, logger=logger + ) ) # Cache internally and invalidate derived caches @@ -1172,7 +1221,7 @@ async def async_execute( input_channel: ReadableChannel[tuple[TagProtocol, PacketProtocol]], output: WritableChannel[tuple[TagProtocol, PacketProtocol]], *, - observer: "ExecutionObserver | None" = None, + observer: "ExecutionObserverProtocol | None" = None, ) -> None: """Streaming async execution for FunctionNode. @@ -1187,9 +1236,14 @@ async def async_execute( """ # TODO(PLT-930): Restore concurrency limiting (semaphore) via node-level config. # Currently all packets are processed sequentially in async_execute. 
+ node_label = self.label or "unknown" + node_hash = self.pipeline_hash().to_string() if self._pipeline_database is not None else "" + + ctx_observer = observer.contextualize(node_hash, node_label) if observer is not None else None + try: - if observer is not None: - observer.on_node_start(self) + if ctx_observer is not None: + ctx_observer.on_node_start(node_label, node_hash) if self._cached_function_pod is not None: # DB-backed async execution: @@ -1245,46 +1299,71 @@ async def async_execute( entry_id = self.compute_pipeline_entry_id(tag, packet) if entry_id in cached_by_entry_id: tag_out, result_packet = cached_by_entry_id[entry_id] - if observer is not None: - observer.on_packet_start(self, tag, packet) - observer.on_packet_end( - self, tag, packet, result_packet, cached=True + if ctx_observer is not None: + ctx_observer.on_packet_start(node_label, tag, packet) + ctx_observer.on_packet_end( + node_label, tag, packet, result_packet, cached=True ) await output.send((tag_out, result_packet)) else: - if observer is not None: - observer.on_packet_start(self, tag, packet) - ( - tag_out, - result_packet, - ) = await self._async_process_packet_internal(tag, packet) - if observer is not None: - observer.on_packet_end( - self, tag, packet, result_packet, cached=False - ) - if result_packet is not None: - await output.send((tag_out, result_packet)) + await self._async_execute_one_packet( + tag, packet, output, + ctx_observer=ctx_observer, + node_label=node_label, + node_hash=node_hash, + ) else: # Simple async execution without DB async for tag, packet in input_channel: - if observer is not None: - observer.on_packet_start(self, tag, packet) - ( - tag_out, - result_packet, - ) = await self._async_process_packet_internal(tag, packet) - if observer is not None: - observer.on_packet_end( - self, tag, packet, result_packet, cached=False - ) - if result_packet is not None: - await output.send((tag_out, result_packet)) + await self._async_execute_one_packet( + tag, packet, 
output, + ctx_observer=ctx_observer, + node_label=node_label, + node_hash=node_hash, + ) - if observer is not None: - observer.on_node_end(self) + if ctx_observer is not None: + ctx_observer.on_node_end(node_label, node_hash) finally: await output.close() + async def _async_execute_one_packet( + self, + tag: TagProtocol, + packet: PacketProtocol, + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + *, + ctx_observer: "ExecutionObserverProtocol | None" = None, + node_label: str = "unknown", + node_hash: str = "", + ) -> None: + """Process one non-cached packet in the async execute path.""" + pp = self.pipeline_path if self._pipeline_database is not None else () + + if ctx_observer is not None: + ctx_observer.on_packet_start(node_label, tag, packet) + + pkt_logger = ( + ctx_observer.create_packet_logger(tag, packet, pipeline_path=pp) + if ctx_observer is not None + else None + ) + + try: + tag_out, result_packet = await self._async_process_packet_internal( + tag, packet, logger=pkt_logger + ) + except Exception as exc: + if ctx_observer is not None: + ctx_observer.on_packet_crash(node_label, tag, packet, exc) + else: + if ctx_observer is not None: + ctx_observer.on_packet_end( + node_label, tag, packet, result_packet, cached=False + ) + if result_packet is not None: + await output.send((tag_out, result_packet)) + def __repr__(self) -> str: return ( f"{type(self).__name__}(packet_function={self._packet_function!r}, " diff --git a/src/orcapod/core/nodes/operator_node.py b/src/orcapod/core/nodes/operator_node.py index b5145d60..9286f09f 100644 --- a/src/orcapod/core/nodes/operator_node.py +++ b/src/orcapod/core/nodes/operator_node.py @@ -31,7 +31,7 @@ if TYPE_CHECKING: import pyarrow as pa - from orcapod.pipeline.observer import ExecutionObserver + from orcapod.protocols.observability_protocols import ExecutionObserverProtocol else: pa = LazyModule("pyarrow") @@ -445,15 +445,17 @@ def execute( Returns: Materialized list of (tag, packet) pairs. 
""" + node_label = self.label or "operator" + node_hash = "" if observer is not None: - observer.on_node_start(self) + observer.on_node_start(node_label, node_hash) # Check REPLAY cache first cached_output = self.get_cached_output() if cached_output is not None: output = list(cached_output.iter_packets()) if observer is not None: - observer.on_node_end(self) + observer.on_node_end(node_label, node_hash) return output # Compute @@ -481,7 +483,7 @@ def execute( self._store_output_stream(self._cached_output_stream) if observer is not None: - observer.on_node_end(self) + observer.on_node_end(node_label, node_hash) return output def _compute_and_store(self) -> None: @@ -658,21 +660,23 @@ async def async_execute( output: Writable channel for output (tag, packet) pairs. observer: Optional execution observer for hooks. """ + node_label = self.label or "operator" + node_hash = "" if self._pipeline_database is None: # Simple delegation without DB if observer is not None: - observer.on_node_start(self) + observer.on_node_start(node_label, node_hash) hashes = [s.pipeline_hash() for s in self._input_streams] await self._operator.async_execute( inputs, output, input_pipeline_hashes=hashes ) if observer is not None: - observer.on_node_end(self) + observer.on_node_end(node_label, node_hash) return try: if observer is not None: - observer.on_node_start(self) + observer.on_node_start(node_label, node_hash) if self._cache_mode == CacheMode.REPLAY: self._replay_from_cache() @@ -712,7 +716,7 @@ async def forward() -> None: self._update_modified_time() if observer is not None: - observer.on_node_end(self) + observer.on_node_end(node_label, node_hash) finally: await output.close() diff --git a/src/orcapod/core/nodes/source_node.py b/src/orcapod/core/nodes/source_node.py index df8d47b1..e0757c06 100644 --- a/src/orcapod/core/nodes/source_node.py +++ b/src/orcapod/core/nodes/source_node.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: import pyarrow as pa - from orcapod.pipeline.observer import 
ExecutionObserver + from orcapod.protocols.observability_protocols import ExecutionObserverProtocol class SourceNode(StreamBase): @@ -253,12 +253,14 @@ def execute( raise RuntimeError( "SourceNode in read-only mode has no stream data available" ) + node_label = self.label or "source" + node_hash = "" if observer is not None: - observer.on_node_start(self) + observer.on_node_start(node_label, node_hash) result = list(self.stream.iter_packets()) self._cached_results = result if observer is not None: - observer.on_node_end(self) + observer.on_node_end(node_label, node_hash) return result def run(self) -> None: @@ -280,12 +282,14 @@ async def async_execute( raise RuntimeError( "SourceNode in read-only mode has no stream data available" ) + node_label = self.label or "source" + node_hash = "" try: if observer is not None: - observer.on_node_start(self) + observer.on_node_start(node_label, node_hash) for tag, packet in self.stream.iter_packets(): await output.send((tag, packet)) if observer is not None: - observer.on_node_end(self) + observer.on_node_end(node_label, node_hash) finally: await output.close() diff --git a/src/orcapod/core/packet_function.py b/src/orcapod/core/packet_function.py index 5ee9afdb..03deccde 100644 --- a/src/orcapod/core/packet_function.py +++ b/src/orcapod/core/packet_function.py @@ -36,11 +36,13 @@ if TYPE_CHECKING: import pyarrow as pa import pyarrow.compute as pc + + from orcapod.protocols.observability_protocols import PacketExecutionLoggerProtocol else: pa = LazyModule("pyarrow") pc = LazyModule("pyarrow.compute") -logger = logging.getLogger(__name__) +logger = _logger = logging.getLogger(__name__) error_handling_options = Literal["raise", "ignore", "warn"] @@ -130,6 +132,8 @@ def __init_subclass__(cls, **kwargs: Any) -> None: cls._resolved_executor_protocol = args[0] return + _SKIP_AUTO_EXECUTOR = object() + def __init__( self, version: str = "v0.0", @@ -137,6 +141,7 @@ def __init__( data_context: str | DataContext | None = None, config: 
Config | None = None, executor: PacketFunctionExecutorProtocol | None = None, + _skip_auto_executor: bool = False, ): super().__init__(label=label, data_context=data_context, config=config) self._active = True @@ -162,6 +167,12 @@ def __init__( # *after* super().__init__(). if executor is not None: self.executor = executor + elif not _skip_auto_executor: + # Auto-assign LocalExecutor so all execution routes through + # the executor layer (ensuring capture is always available). + from orcapod.core.executors.local import LocalExecutor + + self.executor = LocalExecutor() def computed_label(self) -> str | None: """Return the canonical function name as the label if no explicit label is given.""" @@ -285,35 +296,50 @@ def set_executor(self, executor: PacketFunctionExecutorProtocol | None) -> None: # ==================== Execution ==================== - def call(self, packet: PacketProtocol) -> PacketProtocol | None: + def call( + self, + packet: PacketProtocol, + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> PacketProtocol | None: """Process a single packet, routing through the executor if one is set. Subclasses should override ``direct_call`` instead of this method. """ if self._executor is not None: - return self._executor.execute(self, packet) + return self._executor.execute(self, packet, logger=logger) return self.direct_call(packet) - async def async_call(self, packet: PacketProtocol) -> PacketProtocol | None: + async def async_call( + self, + packet: PacketProtocol, + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> PacketProtocol | None: """Asynchronously process a single packet, routing through the executor if set. Subclasses should override ``direct_async_call`` instead of this method. 
""" if self._executor is not None: - return await self._executor.async_execute(self, packet) + return await self._executor.async_execute(self, packet, logger=logger) return await self.direct_async_call(packet) @abstractmethod - def direct_call(self, packet: PacketProtocol) -> PacketProtocol | None: + def direct_call( + self, packet: PacketProtocol + ) -> PacketProtocol | None: """Execute the function's native computation on *packet*. This is the method executors invoke. It bypasses executor routing - and runs the computation directly. Subclasses must implement this. + and runs the computation directly. On user-function failure the + exception is re-raised. Subclasses must implement this. """ ... @abstractmethod - async def direct_async_call(self, packet: PacketProtocol) -> PacketProtocol | None: + async def direct_async_call( + self, packet: PacketProtocol + ) -> PacketProtocol | None: """Asynchronous counterpart of ``direct_call``.""" ... @@ -508,61 +534,92 @@ def _call_async_function_sync(self, packet: PacketProtocol) -> Any: _get_sync_executor().submit(lambda: asyncio.run(fn(**kwargs))).result() ) - def call(self, packet: PacketProtocol) -> PacketProtocol | None: + def call( + self, + packet: PacketProtocol, + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> PacketProtocol | None: """Process a single packet, routing through the executor if one is set. When an executor implementing ``PythonFunctionExecutorProtocol`` is - set, the raw callable and kwargs are handed to - ``execute_callable`` and the result is wrapped into an output packet. + set, the raw callable and kwargs are handed to ``execute_callable`` + which captures I/O and records to the logger. 
""" if self._executor is not None: if not self._active: return None - raw = self._executor.execute_callable(self._function, packet.as_dict()) + raw = self._executor.execute_callable( + self._function, packet.as_dict(), logger=logger + ) return self._build_output_packet(raw) return self.direct_call(packet) - async def async_call(self, packet: PacketProtocol) -> PacketProtocol | None: + async def async_call( + self, + packet: PacketProtocol, + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> PacketProtocol | None: """Async counterpart of ``call``.""" if self._executor is not None: if not self._active: return None raw = await self._executor.async_execute_callable( - self._function, packet.as_dict() + self._function, packet.as_dict(), logger=logger ) return self._build_output_packet(raw) return await self.direct_async_call(packet) - def direct_call(self, packet: PacketProtocol) -> PacketProtocol | None: - """Execute the function on *packet* synchronously. + def direct_call( + self, packet: PacketProtocol + ) -> PacketProtocol | None: + """Execute the function on *packet* synchronously (no executor path). + On user-function failure the exception is re-raised. For async functions, the coroutine is driven to completion via ``asyncio.run()`` (or a helper thread when already inside an event loop). """ if not self._active: return None + if self._is_async: - values = self._call_async_function_sync(packet) + raw_result = self._call_async_function_sync(packet) else: - values = self._function(**packet.as_dict()) - return self._build_output_packet(values) + raw_result = self._function(**packet.as_dict()) + return self._build_output_packet(raw_result) - async def direct_async_call(self, packet: PacketProtocol) -> PacketProtocol | None: - """Execute the function on *packet* asynchronously. + async def direct_async_call( + self, packet: PacketProtocol + ) -> PacketProtocol | None: + """Execute the function on *packet* asynchronously (no executor path). 
Async functions are ``await``-ed directly. Sync functions are - offloaded to a thread pool via ``run_in_executor``. + offloaded to a thread pool via ``run_in_executor``. On failure, + the exception is re-raised. """ + import asyncio + if not self._active: return None + if self._is_async: - values = await self._function(**packet.as_dict()) - return self._build_output_packet(values) + raw_result = await self._function(**packet.as_dict()) else: - import asyncio + import contextvars + import functools loop = asyncio.get_running_loop() - return await loop.run_in_executor(None, self.direct_call, packet) + task_ctx = contextvars.copy_context() + raw_result = await loop.run_in_executor( + None, + functools.partial( + task_ctx.run, + self._function, + **packet.as_dict(), + ), + ) + return self._build_output_packet(raw_result) def to_config(self) -> dict[str, Any]: """Serialize this packet function to a JSON-compatible config dict. @@ -624,8 +681,10 @@ class PacketFunctionWrapper(PacketFunctionBase[E]): """ def __init__(self, packet_function: PacketFunctionProtocol, **kwargs) -> None: - super().__init__(**kwargs) self._packet_function = packet_function + # Skip auto-executor assignment — wrappers delegate executor + # to the wrapped packet function which already has one. + super().__init__(_skip_auto_executor=True, **kwargs) def computed_label(self) -> str | None: return self._packet_function.label @@ -712,16 +771,30 @@ def executor(self, executor: PacketFunctionExecutorProtocol | None) -> None: # call/async_call which handles executor routing. direct_call / # direct_async_call bypass executor routing as their names imply. 
- def call(self, packet: PacketProtocol) -> PacketProtocol | None: - return self._packet_function.call(packet) + def call( + self, + packet: PacketProtocol, + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> PacketProtocol | None: + return self._packet_function.call(packet, logger=logger) - async def async_call(self, packet: PacketProtocol) -> PacketProtocol | None: - return await self._packet_function.async_call(packet) + async def async_call( + self, + packet: PacketProtocol, + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> PacketProtocol | None: + return await self._packet_function.async_call(packet, logger=logger) - def direct_call(self, packet: PacketProtocol) -> PacketProtocol | None: + def direct_call( + self, packet: PacketProtocol + ) -> PacketProtocol | None: return self._packet_function.direct_call(packet) - async def direct_async_call(self, packet: PacketProtocol) -> PacketProtocol | None: + async def direct_async_call( + self, packet: PacketProtocol + ) -> PacketProtocol | None: return await self._packet_function.direct_async_call(packet) @@ -764,58 +837,59 @@ def call( self, packet: PacketProtocol, *, + logger: "PacketExecutionLoggerProtocol | None" = None, skip_cache_lookup: bool = False, skip_cache_insert: bool = False, ) -> PacketProtocol | None: output_packet = None if not skip_cache_lookup: - logger.info("Checking for cache...") + _logger.info("Checking for cache...") output_packet = self._cache.lookup(packet) if output_packet is not None: - logger.info(f"Cache hit for {packet}!") - if output_packet is None: - output_packet = self._packet_function.call(packet) - if output_packet is not None: - if not skip_cache_insert: - self._cache.store( - packet, - output_packet, - variation_data=self.get_function_variation_data(), - execution_data=self.get_execution_data(), - ) - output_packet = output_packet.with_meta_columns( - **{self.RESULT_COMPUTED_FLAG: True} + _logger.info(f"Cache hit for {packet}!") + return 
output_packet + output_packet = self._packet_function.call(packet, logger=logger) + if output_packet is not None: + if not skip_cache_insert: + self._cache.store( + packet, + output_packet, + variation_data=self.get_function_variation_data(), + execution_data=self.get_execution_data(), ) - + output_packet = output_packet.with_meta_columns( + **{self.RESULT_COMPUTED_FLAG: True} + ) return output_packet async def async_call( self, packet: PacketProtocol, *, + logger: "PacketExecutionLoggerProtocol | None" = None, skip_cache_lookup: bool = False, skip_cache_insert: bool = False, ) -> PacketProtocol | None: """Async counterpart of ``call`` with cache check and recording.""" output_packet = None if not skip_cache_lookup: - logger.info("Checking for cache...") + _logger.info("Checking for cache...") output_packet = self._cache.lookup(packet) if output_packet is not None: - logger.info(f"Cache hit for {packet}!") - if output_packet is None: - output_packet = await self._packet_function.async_call(packet) - if output_packet is not None: - if not skip_cache_insert: - self._cache.store( - packet, - output_packet, - variation_data=self.get_function_variation_data(), - execution_data=self.get_execution_data(), - ) - output_packet = output_packet.with_meta_columns( - **{self.RESULT_COMPUTED_FLAG: True} + _logger.info(f"Cache hit for {packet}!") + return output_packet + output_packet = await self._packet_function.async_call(packet, logger=logger) + if output_packet is not None: + if not skip_cache_insert: + self._cache.store( + packet, + output_packet, + variation_data=self.get_function_variation_data(), + execution_data=self.get_execution_data(), ) + output_packet = output_packet.with_meta_columns( + **{self.RESULT_COMPUTED_FLAG: True} + ) return output_packet def get_cached_output_for_packet( diff --git a/src/orcapod/core/packet_function_proxy.py b/src/orcapod/core/packet_function_proxy.py index 615b94a9..2827ea55 100644 --- a/src/orcapod/core/packet_function_proxy.py +++ 
b/src/orcapod/core/packet_function_proxy.py @@ -10,13 +10,16 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any from orcapod.core.packet_function import PacketFunctionBase from orcapod.errors import PacketFunctionUnavailableError from orcapod.protocols.core_protocols import PacketFunctionProtocol from orcapod.types import ContentHash, Schema +if TYPE_CHECKING: + from orcapod.protocols.observability_protocols import PacketExecutionLoggerProtocol + class PacketFunctionProxy(PacketFunctionBase): """Stand-in for an unavailable packet function. @@ -55,8 +58,10 @@ def __init__( # Call super().__init__ so that major_version and # output_packet_schema_hash are available for URI fallback. + # Skip auto-executor: the proxy delegates executor to the bound + # function (when one is bound). version = inner["version"] - super().__init__(version=version) + super().__init__(version=version, _skip_auto_executor=True) # URI: read from config if present, otherwise compute from metadata. uri_list = config.get("uri") @@ -129,16 +134,26 @@ def _raise_unavailable(self) -> None: f"Use bind() to attach a real function, or access cached results only." 
) - def call(self, packet: "PacketProtocol") -> "PacketProtocol | None": + def call( + self, + packet: "PacketProtocol", + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> "PacketProtocol | None": """Process a single packet; delegates to bound function or raises.""" if self._bound_function is not None: - return self._bound_function.call(packet) + return self._bound_function.call(packet, logger=logger) self._raise_unavailable() - async def async_call(self, packet: "PacketProtocol") -> "PacketProtocol | None": + async def async_call( + self, + packet: "PacketProtocol", + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> "PacketProtocol | None": """Async counterpart of ``call``.""" if self._bound_function is not None: - return await self._bound_function.async_call(packet) + return await self._bound_function.async_call(packet, logger=logger) self._raise_unavailable() def direct_call(self, packet: "PacketProtocol") -> "PacketProtocol | None": diff --git a/src/orcapod/pipeline/__init__.py b/src/orcapod/pipeline/__init__.py index 714ea9e7..67dddfe1 100644 --- a/src/orcapod/pipeline/__init__.py +++ b/src/orcapod/pipeline/__init__.py @@ -1,11 +1,14 @@ from .async_orchestrator import AsyncPipelineOrchestrator from .graph import Pipeline +from .logging_observer import LoggingObserver, PacketLogger from .serialization import LoadStatus, PIPELINE_FORMAT_VERSION from .sync_orchestrator import SyncPipelineOrchestrator __all__ = [ "AsyncPipelineOrchestrator", "LoadStatus", + "LoggingObserver", + "PacketLogger", "PIPELINE_FORMAT_VERSION", "Pipeline", "SyncPipelineOrchestrator", diff --git a/src/orcapod/pipeline/async_orchestrator.py b/src/orcapod/pipeline/async_orchestrator.py index 04a6e7d3..c6e35d41 100644 --- a/src/orcapod/pipeline/async_orchestrator.py +++ b/src/orcapod/pipeline/async_orchestrator.py @@ -8,6 +8,7 @@ from __future__ import annotations import asyncio +import uuid import logging from collections import defaultdict from typing 
import TYPE_CHECKING, Any @@ -23,7 +24,7 @@ if TYPE_CHECKING: import networkx as nx - from orcapod.pipeline.observer import ExecutionObserver + from orcapod.protocols.observability_protocols import ExecutionObserverProtocol from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol logger = logging.getLogger(__name__) @@ -48,16 +49,19 @@ class AsyncPipelineOrchestrator: def __init__( self, - observer: "ExecutionObserver | None" = None, + observer: "ExecutionObserverProtocol | None" = None, buffer_size: int = 64, + error_policy: str = "continue", ) -> None: self._observer = observer self._buffer_size = buffer_size + self._error_policy = error_policy def run( self, graph: "nx.DiGraph", materialize_results: bool = True, + run_id: str | None = None, ) -> OrchestratorResult: """Synchronous entry point — runs the async pipeline to completion. @@ -65,36 +69,46 @@ def run( graph: A NetworkX DiGraph with GraphNode objects as vertices. materialize_results: If True, collect all node outputs into the result. If False, return empty node_outputs. + run_id: Optional run identifier. If not provided, a UUID is + generated automatically. Returns: OrchestratorResult with node outputs. """ - return asyncio.run(self._run_async(graph, materialize_results)) + return asyncio.run(self._run_async(graph, materialize_results, run_id=run_id)) async def run_async( self, graph: "nx.DiGraph", materialize_results: bool = True, + run_id: str | None = None, ) -> OrchestratorResult: """Async entry point for callers already inside an event loop. Args: graph: A NetworkX DiGraph with GraphNode objects as vertices. materialize_results: If True, collect all node outputs. + run_id: Optional run identifier. If not provided, a UUID is + generated automatically. Returns: OrchestratorResult with node outputs. 
""" - return await self._run_async(graph, materialize_results) + return await self._run_async(graph, materialize_results, run_id=run_id) async def _run_async( self, graph: "nx.DiGraph", materialize_results: bool, + run_id: str | None = None, ) -> OrchestratorResult: """Core async logic: wire channels, launch tasks, collect results.""" import networkx as nx + effective_run_id = run_id or str(uuid.uuid4()) + if self._observer is not None: + self._observer.on_run_start(effective_run_id) + topo_order = list(nx.topological_sort(graph)) buf = self._buffer_size @@ -186,6 +200,8 @@ async def _run_async( for ch in terminal_channels: tg.create_task(ch.reader.collect()) + if self._observer is not None: + self._observer.on_run_end(effective_run_id) return OrchestratorResult( node_outputs=collectors if materialize_results else {} ) diff --git a/src/orcapod/pipeline/logging_capture.py b/src/orcapod/pipeline/logging_capture.py new file mode 100644 index 00000000..e2d8cee2 --- /dev/null +++ b/src/orcapod/pipeline/logging_capture.py @@ -0,0 +1,235 @@ +"""Capture infrastructure for observability logging. + +Provides context-variable-local capture of stdout, stderr, and Python logging +for use in FunctionNode execution. Thread-safe and asyncio-task-safe via +``contextvars.ContextVar`` — captures from concurrent packets never intermingle. + +CapturedLogs travel as part of the return type through the call chain +(``direct_call`` → ``call`` → ``process_packet`` → FunctionNode) so there +is no ContextVar side-channel for logs. Each executor's ``execute_callable`` +returns ``(raw_result, CapturedLogs)``, and ``direct_call`` returns +``(output_packet, CapturedLogs)`` — catching user-function exceptions +internally rather than re-raising. + +Typical usage +------------- +Call ``install_capture_streams()`` once when a logging Observer is created. 
+The executor or ``direct_call`` wraps function execution in +``LocalCaptureContext`` and returns CapturedLogs alongside the result:: + + result, captured = packet_function.call(packet) + pkt_logger.record(captured) +""" + +from __future__ import annotations + +import contextvars +import io +import logging +import sys +from dataclasses import dataclass +from typing import Any + + +# --------------------------------------------------------------------------- +# CapturedLogs +# --------------------------------------------------------------------------- + + +@dataclass +class CapturedLogs: + """I/O captured from a single packet function execution.""" + + stdout: str = "" + stderr: str = "" + python_logs: str = "" + traceback: str | None = None + success: bool = True + + +# --------------------------------------------------------------------------- +# Context variables +# --------------------------------------------------------------------------- +# Each asyncio task and thread gets its own copy of these variables, so +# captures from concurrent packets never intermingle. + +_stdout_capture: contextvars.ContextVar[io.StringIO | None] = contextvars.ContextVar( + "_stdout_capture", default=None +) +_stderr_capture: contextvars.ContextVar[io.StringIO | None] = contextvars.ContextVar( + "_stderr_capture", default=None +) +_log_capture: contextvars.ContextVar[list[str] | None] = contextvars.ContextVar( + "_log_capture", default=None +) + + +# --------------------------------------------------------------------------- +# ContextLocalTeeStream +# --------------------------------------------------------------------------- + + +class ContextLocalTeeStream: + """A stream that writes to the original *and* a per-context capture buffer. + + All writes go to *original* (terminal output is preserved) and also to a + ``StringIO`` buffer active for the current asyncio task / thread (selected + via ``capture_var``). 
Concurrent tasks each have their own buffer and do + not interfere with each other. + """ + + def __init__( + self, + original: Any, + capture_var: contextvars.ContextVar[io.StringIO | None], + ) -> None: + self._original = original + self._capture_var = capture_var + + def write(self, s: str) -> int: + buf = self._capture_var.get() + if buf is not None: + buf.write(s) + return self._original.write(s) + + def flush(self) -> None: + buf = self._capture_var.get() + if buf is not None: + buf.flush() + self._original.flush() + + def __getattr__(self, name: str) -> Any: + return getattr(self._original, name) + + +# --------------------------------------------------------------------------- +# ContextVarLoggingHandler +# --------------------------------------------------------------------------- + + +class ContextVarLoggingHandler(logging.Handler): + """A logging handler that captures records into a per-context buffer. + + When a capture buffer is active for the current context (asyncio task or + thread), log records are formatted and appended to it. When no buffer is + active the record is silently discarded (not duplicated to other handlers). + """ + + def emit(self, record: logging.LogRecord) -> None: + buf = _log_capture.get() + if buf is not None: + buf.append(self.format(record)) + + +# --------------------------------------------------------------------------- +# Global installation (idempotent) +# --------------------------------------------------------------------------- + +_installed = False +_logging_handler: ContextVarLoggingHandler | None = None + + +def install_capture_streams() -> None: + """Install tee streams and the logging handler globally. + + Idempotent — safe to call multiple times. Should be called once when a + concrete logging Observer is instantiated. + + After installation: + + * ``sys.stdout`` / ``sys.stderr`` tee writes to per-context buffers while + also forwarding to the original stream (terminal output preserved). 
+ * The root logger gains a ``ContextVarLoggingHandler`` that captures + records to per-context buffers (covering Python ``logging`` calls). + + .. note:: + Subprocess and C-extension output bypasses Python's stream objects and + goes directly to file descriptors 1/2. For local execution these are + *not* captured (but are still visible in the terminal). Ray remote + execution uses fd-level capture via + ``RayExecutor._make_capture_wrapper``. + + The stream check runs on every call so that if something (e.g. a test + harness) replaces ``sys.stdout``/``sys.stderr`` between calls we + re-wrap the new stream. The logging handler is only added once. + """ + global _installed, _logging_handler + + # Always re-check in case sys.stdout/stderr was replaced (e.g. by pytest). + if not isinstance(sys.stdout, ContextLocalTeeStream): + sys.stdout = ContextLocalTeeStream(sys.stdout, _stdout_capture) + if not isinstance(sys.stderr, ContextLocalTeeStream): + sys.stderr = ContextLocalTeeStream(sys.stderr, _stderr_capture) + + if _installed: + return + + _logging_handler = ContextVarLoggingHandler() + _logging_handler.setFormatter( + logging.Formatter("%(levelname)s:%(name)s:%(message)s") + ) + logging.getLogger().addHandler(_logging_handler) + + _installed = True + + +# --------------------------------------------------------------------------- +# LocalCaptureContext +# --------------------------------------------------------------------------- + + +class LocalCaptureContext: + """Context manager that activates per-context capture for one packet. + + Requires ``install_capture_streams()`` to have been called; without it the + ContextVars are set but nothing tees into them, so captured strings will be + empty (acceptable when no logging Observer is configured). 
+ + Example:: + + ctx = LocalCaptureContext() + try: + with ctx: + result = call_something() + except Exception: + captured = ctx.get_captured(success=False, tb=traceback.format_exc()) + else: + captured = ctx.get_captured(success=True) + """ + + def __init__(self) -> None: + self._stdout_buf = io.StringIO() + self._stderr_buf = io.StringIO() + self._log_buf: list[str] = [] + self._tokens: list[contextvars.Token] = [] + + def __enter__(self) -> "LocalCaptureContext": + self._tokens = [ + _stdout_capture.set(self._stdout_buf), + _stderr_capture.set(self._stderr_buf), + _log_capture.set(self._log_buf), + ] + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + for token, var in zip( + self._tokens, + [_stdout_capture, _stderr_capture, _log_capture], + ): + var.reset(token) + return False # do not suppress exceptions + + def get_captured( + self, + success: bool, + tb: str | None = None, + ) -> CapturedLogs: + """Return a :class:`CapturedLogs` from what was captured in this context.""" + return CapturedLogs( + stdout=self._stdout_buf.getvalue(), + stderr=self._stderr_buf.getvalue(), + python_logs="\n".join(self._log_buf) if self._log_buf else "", + traceback=tb, + success=success, + ) + diff --git a/src/orcapod/pipeline/logging_observer.py b/src/orcapod/pipeline/logging_observer.py new file mode 100644 index 00000000..af4054dd --- /dev/null +++ b/src/orcapod/pipeline/logging_observer.py @@ -0,0 +1,372 @@ +"""Concrete logging observer for orcapod pipelines. + +Provides :class:`LoggingObserver`, a drop-in observer that captures stdout, +stderr, Python logging, and tracebacks from every packet execution and writes +structured log rows to any :class:`~orcapod.protocols.database_protocols.ArrowDatabaseProtocol` +(in-memory, Delta Lake, etc.). 
+ +Typical usage:: + + from orcapod.pipeline.logging_observer import LoggingObserver + from orcapod.pipeline import SyncPipelineOrchestrator + from orcapod.databases import InMemoryArrowDatabase + + obs = LoggingObserver(log_database=InMemoryArrowDatabase()) + pipeline.run(orchestrator=SyncPipelineOrchestrator(observer=obs)) + + # Inspect captured logs + logs = obs.get_logs() # pyarrow.Table + logs.to_pandas() # pandas DataFrame + +Log schema (fixed columns) +-------------------------- +.. list-table:: + :header-rows: 1 + + * - Column + - Type + - Description + * - ``log_id`` + - ``large_utf8`` + - UUID unique to this log entry + * - ``run_id`` + - ``large_utf8`` + - UUID of the pipeline run (from ``on_run_start``) + * - ``node_label`` + - ``large_utf8`` + - Label of the function node + * - ``node_hash`` + - ``large_utf8`` + - Pipeline hash of the function node + * - ``stdout`` + - ``large_utf8`` + - Captured standard output + * - ``stderr`` + - ``large_utf8`` + - Captured standard error + * - ``python_logs`` + - ``large_utf8`` + - Python ``logging`` output captured during execution + * - ``traceback`` + - ``large_utf8`` + - Full traceback on failure; ``None`` on success + * - ``success`` + - ``bool_`` + - ``True`` if the packet function returned normally + * - ``timestamp`` + - ``large_utf8`` + - ISO-8601 UTC timestamp when ``record()`` was called + +In addition, each tag key from the packet's tag becomes a separate +``large_utf8`` column (queryable, not JSON-encoded). + +Log storage +----------- +Logs are stored at a pipeline-path-mirrored location: +``pipeline_path[:1] + ("logs",) + pipeline_path[1:]``. +Each function node gets its own log table. Use +``get_logs(pipeline_path=node.pipeline_path)`` to retrieve +node-specific logs. 
+""" + +from __future__ import annotations + +import logging +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any + +from uuid_utils import uuid7 + +from orcapod.pipeline.logging_capture import CapturedLogs, install_capture_streams + +if TYPE_CHECKING: + import pyarrow as pa + + from orcapod.protocols.database_protocols import ArrowDatabaseProtocol + +logger = logging.getLogger(__name__) + +# Default path (table name) within the database where log rows are stored. +DEFAULT_LOG_PATH: tuple[str, ...] = ("execution_logs",) + + +class PacketLogger: + """Context-bound logger created by :class:`LoggingObserver` per packet. + + Holds all context needed to write a structured log row + (run_id, node_label, node_hash, tag data) so the caller only needs to + pass the :class:`~orcapod.pipeline.logging_capture.CapturedLogs` payload. + + Tag data is stored as individual queryable columns (not JSON) alongside + the fixed log columns. + + This class is not intended to be instantiated directly — use + :meth:`LoggingObserver.create_packet_logger` instead. 
+ """ + + def __init__( + self, + db: "ArrowDatabaseProtocol", + log_path: tuple[str, ...], + run_id: str, + node_label: str, + node_hash: str, + tag_data: dict[str, Any], + ) -> None: + self._db = db + self._log_path = log_path + self._run_id = run_id + self._node_label = node_label + self._node_hash = node_hash + self._tag_data = tag_data + + def record(self, captured: CapturedLogs) -> None: + """Write one log row to the database.""" + import pyarrow as pa + + log_id = str(uuid7()) + timestamp = datetime.now(timezone.utc).isoformat() + + # Fixed columns + columns: dict[str, pa.Array] = { + "log_id": pa.array([log_id], type=pa.large_utf8()), + "run_id": pa.array([self._run_id], type=pa.large_utf8()), + "node_label": pa.array([self._node_label], type=pa.large_utf8()), + "node_hash": pa.array([self._node_hash], type=pa.large_utf8()), + "stdout": pa.array([captured.stdout], type=pa.large_utf8()), + "stderr": pa.array([captured.stderr], type=pa.large_utf8()), + "python_logs": pa.array([captured.python_logs], type=pa.large_utf8()), + "traceback": pa.array([captured.traceback], type=pa.large_utf8()), + "success": pa.array([captured.success], type=pa.bool_()), + "timestamp": pa.array([timestamp], type=pa.large_utf8()), + } + + # Dynamic tag columns — each tag key becomes its own column + for key, value in self._tag_data.items(): + columns[key] = pa.array([str(value)], type=pa.large_utf8()) + + row = pa.table(columns) + try: + self._db.add_record(self._log_path, log_id, row, flush=True) + except Exception: + logger.exception( + "LoggingObserver: failed to write log row for node=%s", + self._node_label, + ) + + +class _ContextualizedLoggingObserver: + """Lightweight wrapper holding parent observer + node identity context. + + Created by :meth:`LoggingObserver.contextualize`. All lifecycle hooks + and logger creation use the stamped ``node_hash`` and ``node_label``. 
+ """ + + def __init__( + self, + parent: "LoggingObserver", + node_hash: str, + node_label: str, + ) -> None: + self._parent = parent + self._node_hash = node_hash + self._node_label = node_label + + def contextualize( + self, node_hash: str, node_label: str + ) -> "_ContextualizedLoggingObserver": + """Re-contextualize (returns a new wrapper with updated identity).""" + return _ContextualizedLoggingObserver(self._parent, node_hash, node_label) + + def on_run_start(self, run_id: str) -> None: + self._parent.on_run_start(run_id) + + def on_run_end(self, run_id: str) -> None: + self._parent.on_run_end(run_id) + + def on_node_start(self, node_label: str, node_hash: str) -> None: + self._parent.on_node_start(node_label, node_hash) + + def on_node_end(self, node_label: str, node_hash: str) -> None: + self._parent.on_node_end(node_label, node_hash) + + def on_packet_start( + self, node_label: str, tag: Any, packet: Any + ) -> None: + self._parent.on_packet_start(node_label, tag, packet) + + def on_packet_end( + self, + node_label: str, + tag: Any, + input_packet: Any, + output_packet: Any, + cached: bool, + ) -> None: + self._parent.on_packet_end(node_label, tag, input_packet, output_packet, cached) + + def on_packet_crash( + self, node_label: str, tag: Any, packet: Any, error: Exception + ) -> None: + self._parent.on_packet_crash(node_label, tag, packet, error) + + def create_packet_logger( + self, + tag: Any, + packet: Any, + pipeline_path: tuple[str, ...] 
= (), + ) -> PacketLogger: + """Create a logger using context from this wrapper.""" + tag_data = dict(tag) + + # Compute mirrored log path + if pipeline_path: + log_path = pipeline_path[:1] + ("logs",) + pipeline_path[1:] + else: + log_path = self._parent._log_path + + return PacketLogger( + db=self._parent._db, + log_path=log_path, + run_id=self._parent._current_run_id, + node_label=self._node_label, + node_hash=self._node_hash, + tag_data=tag_data, + ) + + +class LoggingObserver: + """Concrete observer that writes packet execution logs to a database. + + Instantiate once, outside the pipeline, and pass to the orchestrator:: + + obs = LoggingObserver(log_database=InMemoryArrowDatabase()) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + # After the run, read back captured logs: + logs_table = obs.get_logs() # pyarrow.Table + + For async / Ray pipelines use :class:`~orcapod.pipeline.AsyncPipelineOrchestrator` + with the same observer:: + + orch = AsyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + Args: + log_database: Any :class:`~orcapod.protocols.database_protocols.ArrowDatabaseProtocol` + instance — :class:`~orcapod.databases.InMemoryArrowDatabase`, + a Delta Lake database, etc. + log_path: Tuple of strings identifying the table within the database. + Defaults to ``("execution_logs",)``. + + Note: + Construction calls :func:`~orcapod.pipeline.logging_capture.install_capture_streams` + so that stdout/stderr tee-capture is active from the moment the observer + is created. + """ + + def __init__( + self, + log_database: "ArrowDatabaseProtocol", + log_path: tuple[str, ...] | None = None, + ) -> None: + self._db = log_database + self._log_path = log_path or DEFAULT_LOG_PATH + self._current_run_id: str = "" + # Activate tee-capture as soon as the observer is created. 
+ install_capture_streams() + + # -- contextualize -- + + def contextualize( + self, node_hash: str, node_label: str + ) -> _ContextualizedLoggingObserver: + """Return a contextualized wrapper stamped with node identity.""" + return _ContextualizedLoggingObserver(self, node_hash, node_label) + + # -- lifecycle hooks -- + + def on_run_start(self, run_id: str) -> None: + self._current_run_id = run_id + + def on_run_end(self, run_id: str) -> None: + pass + + def on_node_start(self, node_label: str, node_hash: str) -> None: + pass + + def on_node_end(self, node_label: str, node_hash: str) -> None: + pass + + def on_packet_start(self, node_label: str, tag: Any, packet: Any) -> None: + pass + + def on_packet_end( + self, + node_label: str, + tag: Any, + input_packet: Any, + output_packet: Any, + cached: bool, + ) -> None: + pass + + def on_packet_crash( + self, node_label: str, tag: Any, packet: Any, error: Exception + ) -> None: + pass + + def create_packet_logger( + self, + tag: Any, + packet: Any, + pipeline_path: tuple[str, ...] = (), + ) -> PacketLogger: + """Return a :class:`PacketLogger` bound to *tag* context. + + Log rows are stored at a pipeline-path-mirrored location: + ``pipeline_path[:1] + ("logs",) + pipeline_path[1:]``. This gives + each function node its own log table in the database. + + Note: + When called directly on ``LoggingObserver`` (not a contextualized + wrapper), node_label and node_hash default to "unknown". Prefer + calling via a contextualized observer. + """ + tag_data = dict(tag) + + # Compute mirrored log path + if pipeline_path: + log_path = pipeline_path[:1] + ("logs",) + pipeline_path[1:] + else: + log_path = self._log_path + + return PacketLogger( + db=self._db, + log_path=log_path, + run_id=self._current_run_id, + node_label="unknown", + node_hash="unknown", + tag_data=tag_data, + ) + + # -- convenience -- + + def get_logs( + self, pipeline_path: tuple[str, ...] 
| None = None + ) -> "pa.Table | None": + """Read log rows from the database as a :class:`pyarrow.Table`. + + Args: + pipeline_path: If provided, reads logs for a specific node + (mirrored path). If ``None``, reads from the default + log path. + + Returns ``None`` if no logs have been written yet. + """ + if pipeline_path is not None: + log_path = pipeline_path[:1] + ("logs",) + pipeline_path[1:] + else: + log_path = self._log_path + return self._db.get_all_records(log_path) diff --git a/src/orcapod/pipeline/observer.py b/src/orcapod/pipeline/observer.py index 22a34138..c19ba366 100644 --- a/src/orcapod/pipeline/observer.py +++ b/src/orcapod/pipeline/observer.py @@ -1,57 +1,75 @@ -"""Execution observer protocol for pipeline orchestration. +"""No-op implementations of the observability protocols. -Provides hooks for monitoring node and packet-level execution events -during orchestrated pipeline runs. +Provides :class:`NoOpLogger` and :class:`NoOpObserver` — the defaults used +when no observability is configured. Every method is a zero-cost no-op. """ from __future__ import annotations -from typing import TYPE_CHECKING, Protocol, runtime_checkable +from typing import TYPE_CHECKING + +from orcapod.protocols.observability_protocols import ( # noqa: F401 (re-exported for convenience) + ExecutionObserverProtocol, + PacketExecutionLoggerProtocol, +) if TYPE_CHECKING: - from orcapod.core.nodes import GraphNode + from orcapod.pipeline.logging_capture import CapturedLogs from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol -@runtime_checkable -class ExecutionObserver(Protocol): - """Observer protocol for pipeline execution events. +# --------------------------------------------------------------------------- +# NoOpLogger +# --------------------------------------------------------------------------- + + +class NoOpLogger: + """Logger that discards all captured output. - ``on_packet_start`` / ``on_packet_end`` are only invoked for function - nodes. 
``on_node_start`` / ``on_node_end`` are invoked for all node - types. + Returned by :class:`NoOpObserver` when no logging sink is configured. """ - def on_node_start(self, node: "GraphNode") -> None: ... - def on_node_end(self, node: "GraphNode") -> None: ... - def on_packet_start( - self, - node: "GraphNode", - tag: "TagProtocol", - packet: "PacketProtocol", - ) -> None: ... - def on_packet_end( - self, - node: "GraphNode", - tag: "TagProtocol", - input_packet: "PacketProtocol", - output_packet: "PacketProtocol | None", - cached: bool, - ) -> None: ... + def record(self, captured: "CapturedLogs") -> None: + pass + + +# Singleton — NoOpLogger carries no state so one instance is enough. +_NOOP_LOGGER = NoOpLogger() + + +# --------------------------------------------------------------------------- +# NoOpObserver +# --------------------------------------------------------------------------- class NoOpObserver: - """Default observer that does nothing.""" + """Observer that does nothing. - def on_node_start(self, node: "GraphNode") -> None: + Satisfies :class:`~orcapod.protocols.observability_protocols.ExecutionObserverProtocol` + and is the default when no observability is configured. + ``create_packet_logger`` returns the shared :data:`_NOOP_LOGGER` singleton. 
+ """ + + def contextualize( + self, node_hash: str, node_label: str + ) -> "NoOpObserver": + return self + + def on_run_start(self, run_id: str) -> None: + pass + + def on_run_end(self, run_id: str) -> None: pass - def on_node_end(self, node: "GraphNode") -> None: + def on_node_start(self, node_label: str, node_hash: str) -> None: + pass + + def on_node_end(self, node_label: str, node_hash: str) -> None: pass def on_packet_start( self, - node: "GraphNode", + node_label: str, tag: "TagProtocol", packet: "PacketProtocol", ) -> None: @@ -59,10 +77,27 @@ def on_packet_start( def on_packet_end( self, - node: "GraphNode", + node_label: str, tag: "TagProtocol", input_packet: "PacketProtocol", output_packet: "PacketProtocol | None", cached: bool, ) -> None: pass + + def on_packet_crash( + self, + node_label: str, + tag: "TagProtocol", + packet: "PacketProtocol", + error: Exception, + ) -> None: + pass + + def create_packet_logger( + self, + tag: "TagProtocol", + packet: "PacketProtocol", + pipeline_path: tuple[str, ...] = (), + ) -> NoOpLogger: + return _NOOP_LOGGER diff --git a/src/orcapod/pipeline/sync_orchestrator.py b/src/orcapod/pipeline/sync_orchestrator.py index e788421c..b3cf8d74 100644 --- a/src/orcapod/pipeline/sync_orchestrator.py +++ b/src/orcapod/pipeline/sync_orchestrator.py @@ -7,6 +7,7 @@ from __future__ import annotations import logging +import uuid from typing import TYPE_CHECKING, Any from orcapod.pipeline.result import OrchestratorResult @@ -19,7 +20,7 @@ if TYPE_CHECKING: import networkx as nx - from orcapod.pipeline.observer import ExecutionObserver + from orcapod.protocols.observability_protocols import ExecutionObserverProtocol from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol logger = logging.getLogger(__name__) @@ -39,13 +40,19 @@ class SyncPipelineOrchestrator: observer: Optional execution observer forwarded to nodes. 
""" - def __init__(self, observer: "ExecutionObserver | None" = None) -> None: + def __init__( + self, + observer: "ExecutionObserverProtocol | None" = None, + error_policy: str = "continue", + ) -> None: self._observer = observer + self._error_policy = error_policy def run( self, graph: "nx.DiGraph", materialize_results: bool = True, + run_id: str | None = None, ) -> OrchestratorResult: """Execute the node graph synchronously. @@ -54,12 +61,18 @@ def run( materialize_results: If True, keep all node outputs in memory and return them. If False, discard buffers after downstream consumption (only DB-persisted results survive). + run_id: Optional run identifier. If not provided, a UUID is + generated automatically. Returns: OrchestratorResult with node outputs. """ import networkx as nx + effective_run_id = run_id or str(uuid.uuid4()) + if self._observer is not None: + self._observer.on_run_start(effective_run_id) + topo_order = list(nx.topological_sort(graph)) buffers: dict[Any, list[tuple[TagProtocol, PacketProtocol]]] = {} processed: set[Any] = set() @@ -71,7 +84,11 @@ def run( upstream_buf = self._gather_upstream(node, graph, buffers) upstream_node = list(graph.predecessors(node))[0] input_stream = self._materialize_as_stream(upstream_buf, upstream_node) - buffers[node] = node.execute(input_stream, observer=self._observer) + buffers[node] = node.execute( + input_stream, + observer=self._observer, + error_policy=self._error_policy, + ) elif is_operator_node(node): upstream_buffers = self._gather_upstream_multi(node, graph, buffers) input_streams = [ @@ -91,6 +108,9 @@ def run( if not materialize_results: buffers.clear() + + if self._observer is not None: + self._observer.on_run_end(effective_run_id) return OrchestratorResult(node_outputs=buffers) @staticmethod diff --git a/src/orcapod/protocols/__init__.py b/src/orcapod/protocols/__init__.py index e69de29b..7f8ba7c7 100644 --- a/src/orcapod/protocols/__init__.py +++ b/src/orcapod/protocols/__init__.py @@ -0,0 +1,4 
@@ +from orcapod.protocols.observability_protocols import ( + ExecutionObserverProtocol, + PacketExecutionLoggerProtocol, +) diff --git a/src/orcapod/protocols/core_protocols/executor.py b/src/orcapod/protocols/core_protocols/executor.py index 2260898a..8290cc26 100644 --- a/src/orcapod/protocols/core_protocols/executor.py +++ b/src/orcapod/protocols/core_protocols/executor.py @@ -7,6 +7,7 @@ if TYPE_CHECKING: from orcapod.protocols.core_protocols.packet_function import PacketFunctionProtocol + from orcapod.protocols.observability_protocols import PacketExecutionLoggerProtocol @runtime_checkable @@ -36,20 +37,25 @@ def supports(self, packet_function_type_id: str) -> bool: def execute( self, - packet_function: PacketFunctionProtocol, + packet_function: "PacketFunctionProtocol", packet: PacketProtocol, + *, + logger: "PacketExecutionLoggerProtocol | None" = None, ) -> PacketProtocol | None: """Synchronously execute *packet_function* on *packet*. The executor should invoke ``packet_function.direct_call(packet)`` - in the appropriate execution environment. + in the appropriate execution environment and return the result. + If a logger is provided, the executor records captured I/O to it. """ ... async def async_execute( self, - packet_function: PacketFunctionProtocol, + packet_function: "PacketFunctionProtocol", packet: PacketProtocol, + *, + logger: "PacketExecutionLoggerProtocol | None" = None, ) -> PacketProtocol | None: """Asynchronous counterpart of ``execute``.""" ... @@ -98,17 +104,22 @@ def execute_callable( fn: Callable[..., Any], kwargs: dict[str, Any], executor_options: dict[str, Any] | None = None, + *, + logger: "PacketExecutionLoggerProtocol | None" = None, ) -> Any: - """Synchronously execute *fn* with *kwargs*. + """Synchronously execute *fn* with *kwargs*, capturing I/O. Args: fn: The Python callable to execute. kwargs: Keyword arguments to pass to *fn*. executor_options: Optional per-call options (e.g. resource overrides). 
+            logger: Optional logger to record captured I/O.
 
         Returns:
-            The raw return value of *fn*.
+            The raw return value of *fn*. On failure, the executor
+            re-raises the original exception after recording the
+            captured logs; no value is returned in that case.
         """
         ...
 
@@ -117,13 +128,16 @@ async def async_execute_callable(
         fn: Callable[..., Any],
         kwargs: dict[str, Any],
         executor_options: dict[str, Any] | None = None,
+        *,
+        logger: "PacketExecutionLoggerProtocol | None" = None,
     ) -> Any:
-        """Asynchronously execute *fn* with *kwargs*.
+        """Asynchronously execute *fn* with *kwargs*, capturing I/O.
 
         Args:
             fn: The Python callable to execute.
             kwargs: Keyword arguments to pass to *fn*.
             executor_options: Optional per-call options.
+            logger: Optional logger to record captured I/O.
 
         Returns:
             The raw return value of *fn*.
diff --git a/src/orcapod/protocols/core_protocols/packet_function.py b/src/orcapod/protocols/core_protocols/packet_function.py
index 37ab97f0..7792e873 100644
--- a/src/orcapod/protocols/core_protocols/packet_function.py
+++ b/src/orcapod/protocols/core_protocols/packet_function.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Any, Protocol, runtime_checkable
+from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
 
 from orcapod.protocols.core_protocols.datagrams import PacketProtocol
 from orcapod.protocols.core_protocols.executor import PacketFunctionExecutorProtocol
@@ -11,6 +11,9 @@
 )
 from orcapod.types import Schema
 
+if TYPE_CHECKING:
+    from orcapod.protocols.observability_protocols import PacketExecutionLoggerProtocol
+
 
 @runtime_checkable
 class PacketFunctionProtocol(
@@ -78,28 +81,35 @@
     def call(
         self,
         packet: PacketProtocol,
+        *,
+        logger: "PacketExecutionLoggerProtocol | None" = None,
     ) -> PacketProtocol | None:
         """Process a single packet, routing through the executor if one is set.
 
         Args:
             packet: The data payload to process.
+ logger: Optional logger for recording captured I/O. Returns: - The processed packet, or ``None`` to filter it out. + The output packet, or ``None`` when the function filters the + packet out or when execution failed (exception is re-raised). """ ... async def async_call( self, packet: PacketProtocol, + *, + logger: "PacketExecutionLoggerProtocol | None" = None, ) -> PacketProtocol | None: """Asynchronously process a single packet, routing through the executor if set. Args: packet: The data payload to process. + logger: Optional logger for recording captured I/O. Returns: - The processed packet, or ``None`` to filter it out. + The output packet, or ``None``. """ ... @@ -110,12 +120,13 @@ def direct_call( """Execute the function's native computation on *packet*. This is the method executors invoke, bypassing executor routing. + On user-function failure the exception is re-raised. Args: packet: The data payload to process. Returns: - The processed packet, or ``None`` to filter it out. + The output packet, or ``None`` if filtered. """ ... diff --git a/src/orcapod/protocols/node_protocols.py b/src/orcapod/protocols/node_protocols.py index 95677412..f51e6de6 100644 --- a/src/orcapod/protocols/node_protocols.py +++ b/src/orcapod/protocols/node_protocols.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: from orcapod.channels import ReadableChannel, WritableChannel from orcapod.core.nodes import GraphNode - from orcapod.pipeline.observer import ExecutionObserver + from orcapod.protocols.observability_protocols import ExecutionObserverProtocol from orcapod.protocols.core_protocols import ( PacketProtocol, StreamProtocol, @@ -34,14 +34,14 @@ class SourceNodeProtocol(Protocol): def execute( self, *, - observer: "ExecutionObserver | None" = None, + observer: "ExecutionObserverProtocol | None" = None, ) -> list[tuple["TagProtocol", "PacketProtocol"]]: ... 
async def async_execute( self, output: "WritableChannel[tuple[TagProtocol, PacketProtocol]]", *, - observer: "ExecutionObserver | None" = None, + observer: "ExecutionObserverProtocol | None" = None, ) -> None: ... @@ -55,7 +55,8 @@ def execute( self, input_stream: "StreamProtocol", *, - observer: "ExecutionObserver | None" = None, + observer: "ExecutionObserverProtocol | None" = None, + error_policy: str = "continue", ) -> list[tuple["TagProtocol", "PacketProtocol"]]: ... async def async_execute( @@ -63,7 +64,7 @@ async def async_execute( input_channel: "ReadableChannel[tuple[TagProtocol, PacketProtocol]]", output: "WritableChannel[tuple[TagProtocol, PacketProtocol]]", *, - observer: "ExecutionObserver | None" = None, + observer: "ExecutionObserverProtocol | None" = None, ) -> None: ... @@ -76,7 +77,7 @@ class OperatorNodeProtocol(Protocol): def execute( self, *input_streams: "StreamProtocol", - observer: "ExecutionObserver | None" = None, + observer: "ExecutionObserverProtocol | None" = None, ) -> list[tuple["TagProtocol", "PacketProtocol"]]: ... async def async_execute( @@ -84,7 +85,7 @@ async def async_execute( inputs: "Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]]", output: "WritableChannel[tuple[TagProtocol, PacketProtocol]]", *, - observer: "ExecutionObserver | None" = None, + observer: "ExecutionObserverProtocol | None" = None, ) -> None: ... diff --git a/src/orcapod/protocols/observability_protocols.py b/src/orcapod/protocols/observability_protocols.py new file mode 100644 index 00000000..3b4a740f --- /dev/null +++ b/src/orcapod/protocols/observability_protocols.py @@ -0,0 +1,156 @@ +"""Observability protocols for pipeline execution tracking and logging. + +Defines: + +* :class:`PacketExecutionLoggerProtocol` — receives captured I/O from a single + packet execution and persists it to a configured sink. 
+* :class:`ExecutionObserverProtocol` — lifecycle hooks for pipeline/node/packet + events, plus a factory method for creating context-bound loggers. + +Both follow the same runtime-checkable Protocol pattern used throughout the +rest of the orcapod codebase. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +if TYPE_CHECKING: + from orcapod.pipeline.logging_capture import CapturedLogs + from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol + + +@runtime_checkable +class PacketExecutionLoggerProtocol(Protocol): + """Receives captured execution output and persists it. + + A logger is *bound* to a specific packet execution context (node, tag, + packet) when created by the Observer. It knows the destination (e.g. a + Delta Lake table) but does not know how the logs were collected — that is + the executor's responsibility. + """ + + def record(self, captured: "CapturedLogs") -> None: + """Persist the captured logs from a packet function execution. + + Called after every packet execution (success or failure), except for + cache hits when ``log_cache_hits=False`` (the default). + """ + ... + + +@runtime_checkable +class ExecutionObserverProtocol(Protocol): + """Observer protocol for pipeline execution lifecycle events. + + Instantiated once outside the pipeline and injected into the orchestrator. + Provides hooks for lifecycle events at the run, node, and packet level, and + acts as a factory for context-specific loggers. + + ``on_packet_start`` / ``on_packet_end`` / ``on_packet_crash`` are invoked + only for function nodes. ``on_node_start`` / ``on_node_end`` are invoked + for all node types. + + Observers are *contextualized* per node via :meth:`contextualize`, which + returns a lightweight wrapper stamped with node identity. The contextualized + observer is used for all hooks and logger creation within that node. 
+ """ + + def contextualize( + self, node_hash: str, node_label: str + ) -> "ExecutionObserverProtocol": + """Return a contextualized copy stamped with node identity. + + Args: + node_hash: The pipeline hash of the node (stable identity). + node_label: Human-readable label of the node. + + Returns: + An observer (possibly a lightweight wrapper) that carries + node_hash and node_label context for all subsequent calls. + """ + ... + + def on_run_start(self, run_id: str) -> None: + """Called at the very start of an orchestrator ``run()`` call. + + Args: + run_id: A UUID string unique to this execution run. All loggers + created during the run will be stamped with this ID. + """ + ... + + def on_run_end(self, run_id: str) -> None: + """Called at the very end of an orchestrator ``run()`` call. + + Args: + run_id: The same UUID passed to ``on_run_start``. + """ + ... + + def on_node_start(self, node_label: str, node_hash: str) -> None: + """Called before a node begins processing its packets.""" + ... + + def on_node_end(self, node_label: str, node_hash: str) -> None: + """Called after a node finishes processing all packets.""" + ... + + def on_packet_start( + self, + node_label: str, + tag: "TagProtocol", + packet: "PacketProtocol", + ) -> None: + """Called before a packet is processed by a function node.""" + ... + + def on_packet_end( + self, + node_label: str, + tag: "TagProtocol", + input_packet: "PacketProtocol", + output_packet: "PacketProtocol | None", + cached: bool, + ) -> None: + """Called after a packet is successfully processed (or served from cache). + + Args: + cached: ``True`` when the result came from a database cache and + the user function was not executed. + """ + ... + + def on_packet_crash( + self, + node_label: str, + tag: "TagProtocol", + packet: "PacketProtocol", + error: Exception, + ) -> None: + """Called when a packet's execution fails. + + Covers both user-function exceptions (captured on the worker) and + system-level crashes (e.g. 
``WorkerCrashedError`` from Ray). The + pipeline continues processing remaining packets rather than aborting. + """ + ... + + def create_packet_logger( + self, + tag: "TagProtocol", + packet: "PacketProtocol", + pipeline_path: tuple[str, ...] = (), + ) -> PacketExecutionLoggerProtocol: + """Create a context-bound logger for a single packet execution. + + The returned logger is pre-stamped with the node label, run ID, and + packet identity so every ``record()`` call writes the correct context + without the executor needing to know anything about the pipeline. + + Args: + tag: The tag for the packet being processed. + packet: The input packet being processed. + pipeline_path: The node's pipeline path for log storage scoping. + """ + ... diff --git a/tests/test_channels/test_async_execute.py b/tests/test_channels/test_async_execute.py index 6aa4124d..a92a76ed 100644 --- a/tests/test_channels/test_async_execute.py +++ b/tests/test_channels/test_async_execute.py @@ -186,7 +186,8 @@ def record_thread(x: int) -> int: return x pf = PythonPacketFunction(record_thread, output_keys="result") - await pf.direct_async_call(Packet({"x": 42})) + result = await pf.direct_async_call(Packet({"x": 42})) + assert result is not None assert len(call_threads) == 1 @@ -727,8 +728,8 @@ async def push(stream, ch): class TestErrorPropagation: @pytest.mark.asyncio - async def test_function_exception_propagates(self): - """An exception in the packet function should propagate out.""" + async def test_function_exception_returns_none(self): + """An exception in the packet function is caught by process_packet — no raise.""" def failing(x: int) -> int: if x == 2: @@ -743,13 +744,24 @@ def failing(x: int) -> int: output_ch = Channel(buffer_size=16) await feed_stream_to_channel(stream, input_ch) + await pod.async_execute([input_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + # The failing packet (x=2) is silently dropped; 4 of 5 succeed + assert len(results) == 4 + 
values = sorted(pkt.as_dict()["result"] for _, pkt in results) + assert values == [0, 1, 3, 4] - with pytest.raises(ExceptionGroup) as exc_info: - await pod.async_execute([input_ch.reader], output_ch.writer) + @pytest.mark.asyncio + async def test_direct_async_call_raises_on_failure(self): + """direct_async_call re-raises the exception on error.""" - # Should contain the ValueError from the function - causes = exc_info.value.exceptions - assert any(isinstance(e, ValueError) and "boom" in str(e) for e in causes) + def failing(x: int) -> int: + raise ValueError("boom") + + pf = PythonPacketFunction(failing, output_keys="result") + with pytest.raises(ValueError, match="boom"): + await pf.direct_async_call(Packet({"x": 1})) # --------------------------------------------------------------------------- diff --git a/tests/test_channels/test_copilot_review_issues.py b/tests/test_channels/test_copilot_review_issues.py index 3c786f09..4a1bd915 100644 --- a/tests/test_channels/test_copilot_review_issues.py +++ b/tests/test_channels/test_copilot_review_issues.py @@ -142,7 +142,7 @@ async def run_in_loop(): nonlocal creation_count with patch.object(ThreadPoolExecutor, "__init__", counting_init): for _ in range(3): - pf.direct_call(Packet({"x": 1, "y": 2})) + _result = pf.direct_call(Packet({"x": 1, "y": 2})) asyncio.run(run_in_loop()) # Current code creates a new executor per call, so creation_count == 3. 
@@ -213,7 +213,7 @@ def instrumented_call(packet): # Run inside an event loop to trigger the ThreadPoolExecutor path async def run_in_loop(): - pf.direct_call(Packet({"x": 5})) + _result = pf.direct_call(Packet({"x": 5})) asyncio.run(run_in_loop()) diff --git a/tests/test_channels/test_node_async_execute.py b/tests/test_channels/test_node_async_execute.py index f8d2f185..1173e6fa 100644 --- a/tests/test_channels/test_node_async_execute.py +++ b/tests/test_channels/test_node_async_execute.py @@ -138,11 +138,11 @@ def counting_double(x: int) -> int: cpf = CachedPacketFunction(pf, result_database=db) packet = Packet({"x": 5}) - await cpf.async_call(packet) + _result1 = await cpf.async_call(packet) assert call_count == 1 # With skip_cache_lookup, should recompute - await cpf.async_call(packet, skip_cache_lookup=True) + _result2 = await cpf.async_call(packet, skip_cache_lookup=True) assert call_count == 2 @pytest.mark.asyncio @@ -613,9 +613,9 @@ async def test_function_node_async_uses_async_process_packet_internal(self): original = node._async_process_packet_internal - async def patched(tag, packet): + async def patched(tag, packet, **kwargs): call_log.append("_async_process_packet_internal") - return await original(tag, packet) + return await original(tag, packet, **kwargs) node._async_process_packet_internal = patched diff --git a/tests/test_core/packet_function/test_executor.py b/tests/test_core/packet_function/test_executor.py index 317a2cb1..78ee2d2b 100644 --- a/tests/test_core/packet_function/test_executor.py +++ b/tests/test_core/packet_function/test_executor.py @@ -60,11 +60,11 @@ def execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> PacketProtocol | None: + ) -> "PacketProtocol | None": self.calls.append((packet_function, packet)) return packet_function.direct_call(packet) - def execute_callable(self, fn, kwargs, executor_options=None): + def execute_callable(self, fn, kwargs, executor_options=None, **kw): 
self.calls.append((fn, kwargs)) return fn(**kwargs) @@ -83,7 +83,7 @@ def execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> PacketProtocol | None: + ) -> "PacketProtocol | None": return packet_function.direct_call(packet) @@ -101,7 +101,7 @@ def execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> PacketProtocol | None: + ) -> "PacketProtocol | None": return packet_function.direct_call(packet) @@ -205,8 +205,8 @@ def test_with_options_returns_new_instance(self, local_executor: LocalExecutor): class TestExecutorProperty: - def test_default_executor_is_none(self, add_pf: PythonPacketFunction): - assert add_pf.executor is None + def test_default_executor_is_local(self, add_pf: PythonPacketFunction): + assert isinstance(add_pf.executor, LocalExecutor) def test_set_executor( self, add_pf: PythonPacketFunction, spy_executor: SpyExecutor @@ -214,7 +214,7 @@ def test_set_executor( add_pf.executor = spy_executor assert add_pf.executor is spy_executor - def test_unset_executor( + def test_unset_executor_to_none( self, add_pf: PythonPacketFunction, spy_executor: SpyExecutor ): add_pf.executor = spy_executor @@ -407,15 +407,16 @@ def test_pod_executor_set_targets_packet_function(self): pod.executor = spy assert pf.executor is spy - def test_pod_executor_unset(self): + def test_pod_executor_swap(self): from orcapod.core.function_pod import FunctionPod spy = SpyExecutor() + spy2 = SpyExecutor() pf = PythonPacketFunction(add, output_keys="result") pf.executor = spy pod = FunctionPod(pf) - pod.executor = None - assert pf.executor is None + pod.executor = spy2 + assert pf.executor is spy2 def test_pod_process_uses_executor(self): from orcapod.core.function_pod import FunctionPod @@ -545,14 +546,14 @@ def test_decorator_incompatible_executor_raises(self): def my_add(x: int, y: int) -> int: return x + y - def test_decorator_without_executor_defaults_to_none(self): + def 
test_decorator_without_executor_defaults_to_local(self): from orcapod.core.function_pod import function_pod @function_pod(output_keys="result") def my_add(x: int, y: int) -> int: return x + y - assert my_add.pod.executor is None + assert isinstance(my_add.pod.executor, LocalExecutor) # --------------------------------------------------------------------------- @@ -599,7 +600,7 @@ def execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> PacketProtocol | None: + ) -> "PacketProtocol | None": self.sync_calls.append(packet) return packet_function.direct_call(packet) @@ -607,15 +608,15 @@ async def async_execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> PacketProtocol | None: + ) -> "PacketProtocol | None": self.async_calls.append(packet) return packet_function.direct_call(packet) - def execute_callable(self, fn, kwargs, executor_options=None): + def execute_callable(self, fn, kwargs, executor_options=None, **kw): self.sync_calls.append(kwargs) return fn(**kwargs) - async def async_execute_callable(self, fn, kwargs, executor_options=None): + async def async_execute_callable(self, fn, kwargs, executor_options=None, **kw): self.async_calls.append(kwargs) return fn(**kwargs) @@ -785,7 +786,8 @@ def test_set_executor_accepts_compatible_protocol(self): assert pf.executor is executor def test_set_executor_accepts_none(self): - pf = PythonPacketFunction(add, output_keys="result", executor=LocalExecutor()) + spy = SpyExecutor() + pf = PythonPacketFunction(add, output_keys="result", executor=spy) pf.set_executor(None) assert pf.executor is None diff --git a/tests/test_core/packet_function/test_packet_function.py b/tests/test_core/packet_function/test_packet_function.py index 4ab5cf1c..636dd29b 100644 --- a/tests/test_core/packet_function/test_packet_function.py +++ b/tests/test_core/packet_function/test_packet_function.py @@ -461,7 +461,8 @@ def test_source_info_record_id_is_uuid(self, add_pf, add_packet): 
def test_inactive_returns_none(self, add_pf, add_packet): add_pf.set_active(False) - assert add_pf.call(add_packet) is None + result = add_pf.call(add_packet) + assert result is None def test_multiple_output_keys(self, multi_pf): packet = Packet({"a": 3, "b": 4}) @@ -607,7 +608,8 @@ def test_call_returns_correct_result(self, async_add_pf, add_packet): def test_inactive_returns_none(self, async_add_pf, add_packet): async_add_pf.set_active(False) - assert async_add_pf.call(add_packet) is None + result = async_add_pf.call(add_packet) + assert result is None def test_multiple_outputs(self, async_multi_pf): packet = Packet({"a": 3, "b": 4}) diff --git a/tests/test_core/test_regression_fixes.py b/tests/test_core/test_regression_fixes.py index 793f0efe..824a7bcf 100644 --- a/tests/test_core/test_regression_fixes.py +++ b/tests/test_core/test_regression_fixes.py @@ -78,11 +78,13 @@ def execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, + *, + logger=None, ) -> PacketProtocol | None: self.calls.append((packet_function, packet)) return packet_function.direct_call(packet) - def execute_callable(self, fn, kwargs, executor_options=None): + def execute_callable(self, fn, kwargs, executor_options=None, *, logger=None): self.calls.append((fn, kwargs)) return fn(**kwargs) @@ -97,8 +99,9 @@ class TestAsyncExecuteChannelCloseOnError: @pytest.mark.asyncio async def test_unary_operator_closes_channel_on_error(self): - """SelectPacketColumns with a column that doesn't exist should fail, - but the output channel must still be closed.""" + """When a packet function raises, process_packet catches the exception + and returns (tag, None). 
The output channel is closed + normally and no exception propagates.""" def failing(x: int) -> int: raise ValueError("boom") @@ -112,15 +115,14 @@ def failing(x: int) -> int: await feed_stream_to_channel(stream, input_ch) - with pytest.raises(ExceptionGroup): - await pod.async_execute([input_ch.reader], output_ch.writer) + # No exception raised — failures are caught inside direct_call + await pod.async_execute([input_ch.reader], output_ch.writer) - # The output channel should be closed despite the error. - # Attempting to collect should return whatever was sent before error, - # and not hang forever. + # The output channel should be closed. results = await output_ch.reader.collect() - # We don't assert content, just that it doesn't hang. + # All packets failed so no results were sent assert isinstance(results, list) + assert len(results) == 0 @pytest.mark.asyncio async def test_operator_closes_channel_on_static_process_error(self): @@ -308,13 +310,13 @@ def double(x: int) -> int: # Patch async_call to use our concurrency-tracking function original_async_call = pf.async_call - async def tracked_async_call(packet: PacketProtocol) -> PacketProtocol | None: + async def tracked_async_call(packet: PacketProtocol, **kwargs): nonlocal concurrent_count, max_concurrent concurrent_count += 1 max_concurrent = max(max_concurrent, concurrent_count) await asyncio.sleep(0.01) concurrent_count -= 1 - return await original_async_call(packet) + return await original_async_call(packet, **kwargs) pf.async_call = tracked_async_call # type: ignore @@ -458,20 +460,22 @@ def test_ensure_ray_initialized_skips_when_already_initialized(self): mock_ray.init.assert_not_called() def test_async_execute_uses_wrap_future(self): - """async_execute should use ref.future() + asyncio.wrap_future, - not bare 'await ref'.""" + """async_execute_callable should use ref.future() + asyncio.wrap_future, + not bare 'await ref'. 
async_execute delegates to async_execute_callable.""" import inspect from orcapod.core.executors.ray import RayExecutor - source = inspect.getsource(RayExecutor.async_execute) + source = inspect.getsource(RayExecutor.async_execute_callable) assert "ref.future()" in source, ( - "async_execute should use ref.future() for asyncio compatibility" + "async_execute_callable should use ref.future() for asyncio compatibility" + ) + assert "wrap_future" in source, ( + "async_execute_callable should use asyncio.wrap_future" ) - assert "wrap_future" in source, "async_execute should use asyncio.wrap_future" # Should NOT do bare 'await ref' assert "return await ref\n" not in source, ( - "async_execute should not use bare 'await ref'" + "async_execute_callable should not use bare 'await ref'" ) diff --git a/tests/test_pipeline/test_logging_capture.py b/tests/test_pipeline/test_logging_capture.py new file mode 100644 index 00000000..1b9ab9bf --- /dev/null +++ b/tests/test_pipeline/test_logging_capture.py @@ -0,0 +1,332 @@ +"""Tests for logging_capture — CapturedLogs, tee streams, LocalCaptureContext, +and the Ray worker-side wrapper.""" + +from __future__ import annotations + +import asyncio +import concurrent.futures +import io +import logging +import sys + +import pytest + +from orcapod.pipeline.logging_capture import ( + CapturedLogs, + ContextLocalTeeStream, + ContextVarLoggingHandler, + LocalCaptureContext, + _log_capture, + _stderr_capture, + _stdout_capture, + install_capture_streams, +) + + +# --------------------------------------------------------------------------- +# CapturedLogs +# --------------------------------------------------------------------------- + + +class TestCapturedLogs: + def test_defaults(self): + c = CapturedLogs() + assert c.stdout == "" + assert c.stderr == "" + assert c.python_logs == "" + assert c.traceback is None + assert c.success is True + + def test_fields(self): + c = CapturedLogs(stdout="out", stderr="err", traceback="tb", success=False) + 
assert c.stdout == "out" + assert c.stderr == "err" + assert c.traceback == "tb" + assert c.success is False + + +# --------------------------------------------------------------------------- +# ContextLocalTeeStream +# --------------------------------------------------------------------------- + + +class TestContextLocalTeeStream: + def test_tees_to_buffer_and_original(self): + buf = io.StringIO() + original = io.StringIO() + token = _stdout_capture.set(buf) + tee = ContextLocalTeeStream(original, _stdout_capture) + try: + tee.write("hello") + tee.flush() + finally: + _stdout_capture.reset(token) + + assert buf.getvalue() == "hello" + assert original.getvalue() == "hello" + + def test_no_buffer_writes_only_original(self): + original = io.StringIO() + tee = ContextLocalTeeStream(original, _stdout_capture) + # No ContextVar buffer set — should just write to original + tee.write("world") + assert original.getvalue() == "world" + + def test_isolation_between_contexts(self): + """Two concurrent contexts each capture only their own writes.""" + buf_a = io.StringIO() + buf_b = io.StringIO() + original = io.StringIO() + tee = ContextLocalTeeStream(original, _stdout_capture) + + token_a = _stdout_capture.set(buf_a) + tee.write("from-A") + _stdout_capture.reset(token_a) + + token_b = _stdout_capture.set(buf_b) + tee.write("from-B") + _stdout_capture.reset(token_b) + + assert buf_a.getvalue() == "from-A" + assert buf_b.getvalue() == "from-B" + assert original.getvalue() == "from-Afrom-B" + + def test_proxy_attributes(self): + original = io.StringIO() + tee = ContextLocalTeeStream(original, _stdout_capture) + # Attribute delegation should work + assert tee.writable() == original.writable() + + +# --------------------------------------------------------------------------- +# ContextVarLoggingHandler +# --------------------------------------------------------------------------- + + +class TestContextVarLoggingHandler: + def test_captures_to_active_buffer(self): + handler = 
ContextVarLoggingHandler() + handler.setFormatter(logging.Formatter("%(message)s")) + buf: list[str] = [] + token = _log_capture.set(buf) + try: + record = logging.LogRecord( + name="test", level=logging.INFO, + pathname="", lineno=0, msg="hello log", args=(), exc_info=None + ) + handler.emit(record) + finally: + _log_capture.reset(token) + + assert buf == ["hello log"] + + def test_discards_when_no_buffer(self): + handler = ContextVarLoggingHandler() + handler.setFormatter(logging.Formatter("%(message)s")) + # No ContextVar buffer — emit should be a no-op + record = logging.LogRecord( + name="test", level=logging.INFO, + pathname="", lineno=0, msg="ignored", args=(), exc_info=None + ) + handler.emit(record) # should not raise + + +# --------------------------------------------------------------------------- +# LocalCaptureContext +# --------------------------------------------------------------------------- + + +class TestLocalCaptureContext: + def test_captures_nothing_without_install(self): + """Without install_capture_streams(), capture returns empty strings.""" + ctx = LocalCaptureContext() + with ctx: + print("not captured") + captured = ctx.get_captured(success=True) + # Since streams are not installed, buffer is empty + assert captured.stdout == "" + assert captured.success is True + + def test_captures_after_install(self): + install_capture_streams() + ctx = LocalCaptureContext() + with ctx: + print("captured output") + captured = ctx.get_captured(success=True) + assert "captured output" in captured.stdout + assert captured.success is True + + def test_captures_stderr_after_install(self): + install_capture_streams() + ctx = LocalCaptureContext() + with ctx: + print("error output", file=sys.stderr) + captured = ctx.get_captured(success=True) + assert "error output" in captured.stderr + + def test_exception_does_not_suppress(self): + install_capture_streams() + ctx = LocalCaptureContext() + with pytest.raises(ValueError, match="boom"): + with ctx: + 
print("before error") + raise ValueError("boom") + + def test_captures_partial_output_on_exception(self): + install_capture_streams() + ctx = LocalCaptureContext() + try: + with ctx: + print("before error") + raise ValueError("boom") + except ValueError: + pass + captured = ctx.get_captured(success=False, tb="traceback text") + assert "before error" in captured.stdout + assert captured.success is False + assert captured.traceback == "traceback text" + + def test_isolation_between_concurrent_asyncio_tasks(self): + """Each asyncio task captures only its own output.""" + install_capture_streams() + + async def run_task(label: str) -> str: + ctx = LocalCaptureContext() + with ctx: + await asyncio.sleep(0) # yield to other tasks + print(label) + await asyncio.sleep(0) + return ctx.get_captured(success=True).stdout + + async def main(): + results = await asyncio.gather( + run_task("task-A"), run_task("task-B") + ) + return results + + a_out, b_out = asyncio.run(main()) + assert "task-A" in a_out + assert "task-B" not in a_out + assert "task-B" in b_out + assert "task-A" not in b_out + + def test_isolation_between_threads(self): + """Each thread captures only its own output.""" + install_capture_streams() + results: dict[str, str] = {} + + def worker(label: str) -> None: + ctx = LocalCaptureContext() + with ctx: + print(label) + results[label] = ctx.get_captured(success=True).stdout + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + futs = [pool.submit(worker, lbl) for lbl in ("thread-X", "thread-Y")] + for f in futs: + f.result() + + assert "thread-X" in results["thread-X"] + assert "thread-Y" not in results["thread-X"] + assert "thread-Y" in results["thread-Y"] + assert "thread-X" not in results["thread-Y"] + + def test_python_logging_captured(self): + install_capture_streams() + test_logger = logging.getLogger("test.capture") + ctx = LocalCaptureContext() + with ctx: + test_logger.warning("log message") + captured = ctx.get_captured(success=True) 
+ assert "log message" in captured.python_logs + + def test_resets_contextvar_after_exit(self): + install_capture_streams() + ctx = LocalCaptureContext() + with ctx: + pass + # ContextVars should be None after context exits + assert _stdout_capture.get() is None + assert _stderr_capture.get() is None + assert _log_capture.get() is None + + +# --------------------------------------------------------------------------- +# RayExecutor._make_capture_wrapper (local process — no Ray cluster needed) +# --------------------------------------------------------------------------- + + +class TestRayCaptureWrapper: + """Tests exercise :meth:`RayExecutor._make_capture_wrapper` directly. + + The wrapper runs in the same process (no Ray cluster needed) and + returns a plain 6-tuple that the driver reassembles into CapturedLogs. + """ + + @staticmethod + def _call_wrapper(fn, kwargs): + """Helper: call the Ray capture wrapper and return (raw, CapturedLogs).""" + from orcapod.core.executors.ray import RayExecutor + + wrapper = RayExecutor._make_capture_wrapper() + raw, stdout, stderr, python_logs, tb, success = wrapper(fn, kwargs) + return raw, CapturedLogs( + stdout=stdout, stderr=stderr, python_logs=python_logs, + traceback=tb, success=success, + ) + + def test_captures_stdout(self): + def fn(x): + print(f"result={x}") + return x * 2 + + raw, captured = self._call_wrapper(fn, {"x": 3}) + assert raw == 6 + assert "result=3" in captured.stdout + assert captured.success is True + assert captured.traceback is None + + def test_captures_stderr(self): + def fn(): + import sys + print("err line", file=sys.stderr) + return 1 + + raw, captured = self._call_wrapper(fn, {}) + assert raw == 1 + assert "err line" in captured.stderr + + def test_captures_exception(self): + def fn(): + raise RuntimeError("worker blew up") + + raw, captured = self._call_wrapper(fn, {}) + assert raw is None + assert captured.success is False + assert "RuntimeError" in captured.traceback + assert "worker blew up" 
in captured.traceback + + def test_captures_python_logging(self): + def fn(): + logging.getLogger("ray_wrapper_test").error("log from worker") + return True + + raw, captured = self._call_wrapper(fn, {}) + assert raw is True + assert "log from worker" in captured.python_logs + + def test_restores_fds_after_exception(self): + """File descriptors 1/2 must be restored even when fn raises.""" + import os + + original_stdout_fd = os.dup(1) + try: + def fn(): + raise ValueError("oops") + + self._call_wrapper(fn, {}) + + # Write to fd 1 — should succeed (not be pointing at a closed temp file) + os.write(1, b"") + finally: + os.close(original_stdout_fd) diff --git a/tests/test_pipeline/test_logging_observer_integration.py b/tests/test_pipeline/test_logging_observer_integration.py new file mode 100644 index 00000000..5d298c66 --- /dev/null +++ b/tests/test_pipeline/test_logging_observer_integration.py @@ -0,0 +1,392 @@ +"""Integration tests for LoggingObserver with real pipelines. + +Exercises the full logging pipeline: capture → CapturedLogs return type → +FunctionNode → observer → PacketLogger → database, using InMemoryArrowDatabase +and real Pipeline objects. 
+""" + +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.core.function_pod import FunctionPod +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.core.sources.arrow_table_source import ArrowTableSource +from orcapod.databases import InMemoryArrowDatabase +from orcapod.pipeline import ( + AsyncPipelineOrchestrator, + Pipeline, + SyncPipelineOrchestrator, +) +from orcapod.pipeline.logging_observer import LoggingObserver + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_source(n: int = 3) -> ArrowTableSource: + table = pa.table({ + "id": pa.array([str(i) for i in range(n)], type=pa.large_string()), + "x": pa.array([10 * (i + 1) for i in range(n)], type=pa.int64()), + }) + return ArrowTableSource(table, tag_columns=["id"]) + + +def _get_function_node(pipeline: Pipeline): + """Return the first function node from the pipeline graph.""" + import networkx as nx + + for node in nx.topological_sort(pipeline._node_graph): + if node.node_type == "function": + return node + raise RuntimeError("No function node found") + + +# --------------------------------------------------------------------------- +# 1. 
Sync pipeline — succeeding packets → log rows with stdout captured +# --------------------------------------------------------------------------- + + +class TestSyncPipelineSuccessLogs: + def test_success_logs_captured(self): + db = InMemoryArrowDatabase() + source = _make_source() + + def double(x: int) -> int: + print(f"doubling {x}") + return x * 2 + + pf = PythonPacketFunction(double, output_keys="result") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_logs", pipeline_database=db) + with pipeline: + pod(source, label="doubler") + + obs = LoggingObserver(log_database=db) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + fn_node = _get_function_node(pipeline) + logs = obs.get_logs(pipeline_path=fn_node.pipeline_path) + + assert logs is not None + assert logs.num_rows == 3 + success_col = logs.column("success").to_pylist() + assert all(success_col) + + +# --------------------------------------------------------------------------- +# 2. Failing packets → success=False, traceback populated +# --------------------------------------------------------------------------- + + +class TestFailingPacketsLogged: + def test_failure_logged_with_traceback(self): + db = InMemoryArrowDatabase() + source = _make_source(2) + + def failing(x: int) -> int: + raise ValueError("boom") + + pf = PythonPacketFunction(failing, output_keys="result") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_fail", pipeline_database=db) + with pipeline: + pod(source, label="failing") + + obs = LoggingObserver(log_database=db) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + fn_node = _get_function_node(pipeline) + logs = obs.get_logs(pipeline_path=fn_node.pipeline_path) + + assert logs is not None + assert logs.num_rows == 2 + for row_idx in range(logs.num_rows): + assert logs.column("success").to_pylist()[row_idx] is False + tb = logs.column("traceback").to_pylist()[row_idx] + assert tb is not None + assert 
"ValueError" in tb + assert "boom" in tb + + +# --------------------------------------------------------------------------- +# 3. Pipeline-path-mirrored storage +# --------------------------------------------------------------------------- + + +class TestPipelinePathMirroredStorage: + def test_log_path_mirrors_pipeline_path(self): + db = InMemoryArrowDatabase() + source = _make_source(1) + + def identity(x: int) -> int: + return x + + pf = PythonPacketFunction(identity, output_keys="result") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_mirror", pipeline_database=db) + with pipeline: + pod(source, label="ident") + + obs = LoggingObserver(log_database=db) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + fn_node = _get_function_node(pipeline) + pp = fn_node.pipeline_path + expected_log_path = pp[:1] + ("logs",) + pp[1:] + + # Verify the log path is correct by reading directly from the DB + raw = db.get_all_records(expected_log_path) + assert raw is not None + assert raw.num_rows == 1 + + +# --------------------------------------------------------------------------- +# 4. 
Queryable tag columns (not JSON) +# --------------------------------------------------------------------------- + + +class TestQueryableTagColumns: + def test_tag_columns_in_log_table(self): + db = InMemoryArrowDatabase() + source = _make_source(2) + + def identity(x: int) -> int: + return x + + pf = PythonPacketFunction(identity, output_keys="result") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_tags", pipeline_database=db) + with pipeline: + pod(source, label="ident") + + obs = LoggingObserver(log_database=db) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + fn_node = _get_function_node(pipeline) + logs = obs.get_logs(pipeline_path=fn_node.pipeline_path) + + assert logs is not None + # "id" tag column should be a separate column, not JSON + assert "id" in logs.column_names + assert "tags" not in logs.column_names + id_values = sorted(logs.column("id").to_pylist()) + assert id_values == ["0", "1"] + + +# --------------------------------------------------------------------------- +# 5. Async orchestrator logs +# --------------------------------------------------------------------------- + + +class TestAsyncOrchestratorLogs: + def test_async_pipeline_captures_logs(self): + db = InMemoryArrowDatabase() + source = _make_source(2) + + def double(x: int) -> int: + return x * 2 + + pf = PythonPacketFunction(double, output_keys="result") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_async_logs", pipeline_database=db) + with pipeline: + pod(source, label="doubler") + + obs = LoggingObserver(log_database=db) + orch = AsyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + fn_node = _get_function_node(pipeline) + logs = obs.get_logs(pipeline_path=fn_node.pipeline_path) + + assert logs is not None + assert logs.num_rows == 2 + assert all(logs.column("success").to_pylist()) + + +# --------------------------------------------------------------------------- +# 6. 
fail_fast error policy +# --------------------------------------------------------------------------- + + +class TestFailFastErrorPolicy: + def test_fail_fast_aborts_and_logs(self): + db = InMemoryArrowDatabase() + source = _make_source(3) + + def failing(x: int) -> int: + raise RuntimeError("crash") + + pf = PythonPacketFunction(failing, output_keys="result") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_failfast", pipeline_database=db) + with pipeline: + pod(source, label="failing") + + obs = LoggingObserver(log_database=db) + orch = SyncPipelineOrchestrator(observer=obs, error_policy="fail_fast") + + with pytest.raises(RuntimeError, match="crash"): + pipeline.run(orchestrator=orch) + + fn_node = _get_function_node(pipeline) + logs = obs.get_logs(pipeline_path=fn_node.pipeline_path) + + # At least the first crash should be logged before abort + assert logs is not None + assert logs.num_rows >= 1 + assert logs.column("success").to_pylist()[0] is False + + +# --------------------------------------------------------------------------- +# 7. 
Mixed success/failure — correct log row per packet +# --------------------------------------------------------------------------- + + +class TestMixedSuccessFailure: + def test_mixed_results_logged_correctly(self): + db = InMemoryArrowDatabase() + source = ArrowTableSource( + pa.table({ + "id": pa.array(["0", "1", "2"], type=pa.large_string()), + "x": pa.array([10, -1, 30], type=pa.int64()), + }), + tag_columns=["id"], + ) + + def safe_div(x: int) -> float: + return 100 / x + + pf = PythonPacketFunction(safe_div, output_keys="result") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_mixed", pipeline_database=db) + with pipeline: + pod(source, label="divider") + + obs = LoggingObserver(log_database=db) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + fn_node = _get_function_node(pipeline) + logs = obs.get_logs(pipeline_path=fn_node.pipeline_path) + + assert logs is not None + # x=10 succeeds, x=-1 succeeds (100/-1 = -100), x=30 succeeds + # All three should succeed since division by these values is fine + assert logs.num_rows == 3 + assert all(logs.column("success").to_pylist()) + + +# --------------------------------------------------------------------------- +# 8. 
Multiple function nodes — each gets own log table +# --------------------------------------------------------------------------- + + +class TestMultipleFunctionNodesSeparateLogs: + def test_two_nodes_separate_log_tables(self): + db = InMemoryArrowDatabase() + source = _make_source(2) + + def double(x: int) -> int: + return x * 2 + + def triple(result: int) -> int: + return result * 3 + + pf1 = PythonPacketFunction(double, output_keys="result") + pod1 = FunctionPod(pf1) + pf2 = PythonPacketFunction(triple, output_keys="final") + pod2 = FunctionPod(pf2) + + pipeline = Pipeline(name="test_multi", pipeline_database=db) + with pipeline: + s1 = pod1(source, label="doubler") + pod2(s1, label="tripler") + + obs = LoggingObserver(log_database=db) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + import networkx as nx + + fn_nodes = [ + n + for n in nx.topological_sort(pipeline._node_graph) + if n.node_type == "function" + ] + assert len(fn_nodes) == 2 + + logs1 = obs.get_logs(pipeline_path=fn_nodes[0].pipeline_path) + logs2 = obs.get_logs(pipeline_path=fn_nodes[1].pipeline_path) + + assert logs1 is not None + assert logs2 is not None + assert logs1.num_rows == 2 + assert logs2.num_rows == 2 + + # Verify they are at different paths + assert fn_nodes[0].pipeline_path != fn_nodes[1].pipeline_path + + +# --------------------------------------------------------------------------- +# 9. 
get_logs(pipeline_path) retrieves node-specific logs +# --------------------------------------------------------------------------- + + +class TestGetLogsNodeSpecific: + def test_get_logs_filters_by_node(self): + db = InMemoryArrowDatabase() + source = _make_source(2) + + def double(x: int) -> int: + return x * 2 + + def triple(result: int) -> int: + return result * 3 + + pf1 = PythonPacketFunction(double, output_keys="result") + pod1 = FunctionPod(pf1) + pf2 = PythonPacketFunction(triple, output_keys="final") + pod2 = FunctionPod(pf2) + + pipeline = Pipeline(name="test_filter", pipeline_database=db) + with pipeline: + s1 = pod1(source, label="doubler") + pod2(s1, label="tripler") + + obs = LoggingObserver(log_database=db) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + import networkx as nx + + fn_nodes = [ + n + for n in nx.topological_sort(pipeline._node_graph) + if n.node_type == "function" + ] + + # Each node's logs contain only that node's label + logs1 = obs.get_logs(pipeline_path=fn_nodes[0].pipeline_path) + logs2 = obs.get_logs(pipeline_path=fn_nodes[1].pipeline_path) + + labels1 = set(logs1.column("node_label").to_pylist()) + labels2 = set(logs2.column("node_label").to_pylist()) + + assert labels1 == {"doubler"} + assert labels2 == {"tripler"} diff --git a/tests/test_pipeline/test_node_protocols.py b/tests/test_pipeline/test_node_protocols.py index e05c743c..20d210bc 100644 --- a/tests/test_pipeline/test_node_protocols.py +++ b/tests/test_pipeline/test_node_protocols.py @@ -157,17 +157,18 @@ def test_execute_with_observer(self): events = [] class Obs: - def on_node_start(self, n): - events.append(("start", n.node_type)) - def on_node_end(self, n): - events.append(("end", n.node_type)) - def on_packet_start(self, n, t, p): + def on_node_start(self, node_label, node_hash): + events.append(("start", node_label)) + def on_node_end(self, node_label, node_hash): + events.append(("end", node_label)) + def 
on_packet_start(self, node_label, t, p): pass - def on_packet_end(self, n, t, ip, op, cached): + def on_packet_end(self, node_label, t, ip, op, cached): pass node.execute(observer=Obs()) - assert events == [("start", "source"), ("end", "source")] + assert events[0][0] == "start" + assert events[1][0] == "end" def test_execute_without_observer(self): """execute() works fine with no observer.""" @@ -206,13 +207,13 @@ async def test_async_execute_with_observer(self): events = [] class Obs: - def on_node_start(self, n): + def on_node_start(self, node_label, node_hash): events.append("start") - def on_node_end(self, n): + def on_node_end(self, node_label, node_hash): events.append("end") - def on_packet_start(self, n, t, p): + def on_packet_start(self, node_label, t, p): pass - def on_packet_end(self, n, t, ip, op, cached): + def on_packet_end(self, node_label, t, ip, op, cached): pass output_ch = Channel(buffer_size=16) @@ -249,21 +250,28 @@ def test_execute_with_observer(self): events = [] class Obs: - def on_node_start(self, n): - events.append(("node_start", n.node_type)) - def on_node_end(self, n): - events.append(("node_end", n.node_type)) - def on_packet_start(self, n, t, p): + def contextualize(self, node_hash, node_label): + return self + def on_node_start(self, node_label, node_hash): + events.append(("node_start", node_label)) + def on_node_end(self, node_label, node_hash): + events.append(("node_end", node_label)) + def on_packet_start(self, node_label, t, p): events.append(("packet_start",)) - def on_packet_end(self, n, t, ip, op, cached): + def on_packet_end(self, node_label, t, ip, op, cached): events.append(("packet_end", cached)) + def on_packet_crash(self, node_label, t, p, exc): + pass + def create_packet_logger(self, t, p, **kwargs): + from orcapod.pipeline.observer import _NOOP_LOGGER + return _NOOP_LOGGER input_stream = node._input_stream result = node.execute(input_stream, observer=Obs()) assert len(result) == 2 - assert events[0] == 
("node_start", "function") - assert events[-1] == ("node_end", "function") + assert events[0][0] == "node_start" + assert events[-1][0] == "node_end" packet_events = [e for e in events if e[0].startswith("packet")] assert len(packet_events) == 4 # 2 start + 2 end @@ -321,14 +329,21 @@ async def test_async_execute_with_observer(self): events = [] class Obs: - def on_node_start(self, n): + def contextualize(self, node_hash, node_label): + return self + def on_node_start(self, node_label, node_hash): events.append("node_start") - def on_node_end(self, n): + def on_node_end(self, node_label, node_hash): events.append("node_end") - def on_packet_start(self, n, t, p): + def on_packet_start(self, node_label, t, p): events.append("pkt_start") - def on_packet_end(self, n, t, ip, op, cached): + def on_packet_end(self, node_label, t, ip, op, cached): events.append("pkt_end") + def on_packet_crash(self, node_label, t, p, exc): + pass + def create_packet_logger(self, t, p, **kwargs): + from orcapod.pipeline.observer import _NOOP_LOGGER + return _NOOP_LOGGER input_ch = Channel(buffer_size=16) output_ch = Channel(buffer_size=16) @@ -368,18 +383,19 @@ def test_execute_with_observer(self): events = [] class Obs: - def on_node_start(self, n): - events.append(("node_start", n.node_type)) - def on_node_end(self, n): - events.append(("node_end", n.node_type)) - def on_packet_start(self, n, t, p): + def on_node_start(self, node_label, node_hash): + events.append(("node_start", node_label)) + def on_node_end(self, node_label, node_hash): + events.append(("node_end", node_label)) + def on_packet_start(self, node_label, t, p): pass - def on_packet_end(self, n, t, ip, op, cached): + def on_packet_end(self, node_label, t, ip, op, cached): pass result = node.execute(*node._input_streams, observer=Obs()) assert len(result) == 2 - assert events == [("node_start", "operator"), ("node_end", "operator")] + assert events[0][0] == "node_start" + assert events[1][0] == "node_end" def 
test_execute_without_observer(self): node = self._make_join_node() @@ -408,13 +424,13 @@ async def test_async_execute_with_observer(self): events = [] class Obs: - def on_node_start(self, n): + def on_node_start(self, node_label, node_hash): events.append("start") - def on_node_end(self, n): + def on_node_end(self, node_label, node_hash): events.append("end") - def on_packet_start(self, n, t, p): + def on_packet_start(self, node_label, t, p): pass - def on_packet_end(self, n, t, ip, op, cached): + def on_packet_end(self, node_label, t, ip, op, cached): pass input_ch = Channel(buffer_size=16) diff --git a/tests/test_pipeline/test_observer.py b/tests/test_pipeline/test_observer.py index 91237ca5..7d0f177c 100644 --- a/tests/test_pipeline/test_observer.py +++ b/tests/test_pipeline/test_observer.py @@ -1,8 +1,12 @@ -"""Tests for ExecutionObserver protocol and NoOpObserver.""" +"""Tests for ExecutionObserverProtocol, NoOpObserver, and NoOpLogger.""" from __future__ import annotations -from orcapod.pipeline.observer import ExecutionObserver, NoOpObserver +from orcapod.pipeline.observer import NoOpLogger, NoOpObserver, _NOOP_LOGGER +from orcapod.protocols.observability_protocols import ( + ExecutionObserverProtocol, + PacketExecutionLoggerProtocol, +) class TestNoOpObserver: @@ -10,20 +14,56 @@ class TestNoOpObserver: def test_satisfies_protocol(self): observer = NoOpObserver() - assert isinstance(observer, ExecutionObserver) + assert isinstance(observer, ExecutionObserverProtocol) + + def test_on_run_start_noop(self): + NoOpObserver().on_run_start("run-123") + + def test_on_run_end_noop(self): + NoOpObserver().on_run_end("run-123") def test_on_node_start_noop(self): - observer = NoOpObserver() - observer.on_node_start(None) # type: ignore[arg-type] + NoOpObserver().on_node_start("label", "hash") def test_on_node_end_noop(self): - observer = NoOpObserver() - observer.on_node_end(None) # type: ignore[arg-type] + NoOpObserver().on_node_end("label", "hash") def 
test_on_packet_start_noop(self): - observer = NoOpObserver() - observer.on_packet_start(None, None, None) # type: ignore[arg-type] + NoOpObserver().on_packet_start("label", None, None) # type: ignore[arg-type] def test_on_packet_end_noop(self): - observer = NoOpObserver() - observer.on_packet_end(None, None, None, None, cached=False) # type: ignore[arg-type] + NoOpObserver().on_packet_end("label", None, None, None, cached=False) # type: ignore[arg-type] + + def test_on_packet_crash_noop(self): + NoOpObserver().on_packet_crash("label", None, None, RuntimeError("boom")) # type: ignore[arg-type] + + def test_create_packet_logger_returns_noop(self): + logger = NoOpObserver().create_packet_logger(None, None) # type: ignore[arg-type] + assert logger is _NOOP_LOGGER + + def test_create_packet_logger_satisfies_protocol(self): + logger = NoOpObserver().create_packet_logger(None, None) # type: ignore[arg-type] + assert isinstance(logger, PacketExecutionLoggerProtocol) + + def test_contextualize_returns_self(self): + obs = NoOpObserver() + assert obs.contextualize("hash", "label") is obs + + +class TestNoOpLogger: + """NoOpLogger satisfies the protocol and discards everything.""" + + def test_satisfies_protocol(self): + assert isinstance(NoOpLogger(), PacketExecutionLoggerProtocol) + + def test_record_noop(self): + from orcapod.pipeline.logging_capture import CapturedLogs + + NoOpLogger().record(CapturedLogs(stdout="hello", success=True)) + + def test_noop_logger_singleton_identity(self): + # The singleton returned by create_packet_logger is always the same object. 
+ obs = NoOpObserver() + l1 = obs.create_packet_logger(None, None) # type: ignore[arg-type] + l2 = obs.create_packet_logger(None, None) # type: ignore[arg-type] + assert l1 is l2 diff --git a/tests/test_pipeline/test_orchestrator.py b/tests/test_pipeline/test_orchestrator.py index 6c6d39f5..587b4d1a 100644 --- a/tests/test_pipeline/test_orchestrator.py +++ b/tests/test_pipeline/test_orchestrator.py @@ -428,9 +428,10 @@ def test_single_terminal_source(self): class TestAsyncOrchestratorErrorPropagation: - """Node failures should propagate correctly.""" + """Failed packets do not abort the pipeline; they are handled per-packet.""" - def test_node_failure_propagates(self): + def test_node_failure_does_not_abort_pipeline(self): + """A crashing packet function is skipped; the pipeline completes normally.""" def failing_fn(value: int) -> int: raise ValueError("intentional failure") @@ -445,8 +446,36 @@ def failing_fn(value: int) -> int: pipeline.compile() orch = AsyncPipelineOrchestrator() - with pytest.raises(ExceptionGroup): - orch.run(pipeline._node_graph) + # Pipeline must complete without raising; failing packet is silently dropped. 
+ orch.run(pipeline._node_graph) + + def test_node_failure_calls_on_packet_crash(self): + """When an observer is set, on_packet_crash is called for the failing packet.""" + from orcapod.pipeline.observer import NoOpObserver + + def failing_fn(value: int) -> int: + raise ValueError("intentional failure") + + src = _make_source("key", "value", {"key": ["a"], "value": [1]}) + pf = PythonPacketFunction(failing_fn, output_keys="result") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="error2", pipeline_database=InMemoryArrowDatabase()) + with pipeline: + pod(src, label="failer") + + crashes = [] + + class CrashRecorder(NoOpObserver): + def on_packet_crash(self, node_label, tag, packet, error): + crashes.append(error) + + pipeline.compile() + orch = AsyncPipelineOrchestrator(observer=CrashRecorder()) + orch.run(pipeline._node_graph) + + assert len(crashes) == 1 + assert isinstance(crashes[0], (ValueError, RuntimeError)) class TestAsyncOrchestratorObserverInjection: @@ -465,32 +494,37 @@ def test_linear_pipeline_observer_hooks(self): events = [] class RecordingObserver: - def on_node_start(self, node): - events.append(("node_start", node.node_type)) - - def on_node_end(self, node): - events.append(("node_end", node.node_type)) - - def on_packet_start(self, node, tag, packet): - events.append(("packet_start", node.node_type)) - - def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): - events.append(("packet_end", node.node_type, cached)) + def on_run_start(self, run_id): pass + def on_run_end(self, run_id): pass + def on_node_start(self, node_label, node_hash): + events.append(("node_start", node_label)) + def on_node_end(self, node_label, node_hash): + events.append(("node_end", node_label)) + def on_packet_start(self, node_label, tag, packet): + events.append(("packet_start", node_label)) + def on_packet_end(self, node_label, tag, input_pkt, output_pkt, cached): + events.append(("packet_end", node_label, cached)) + def on_packet_crash(self, node_label, 
tag, packet, exc): pass + def create_packet_logger(self, tag, packet, **kwargs): + from orcapod.pipeline.observer import _NOOP_LOGGER + return _NOOP_LOGGER + def contextualize(self, node_hash, node_label): + return self pipeline.compile() orch = AsyncPipelineOrchestrator(observer=RecordingObserver()) orch.run(pipeline._node_graph) - # Source fires node_start/node_end - assert ("node_start", "source") in events - assert ("node_end", "source") in events + # Source fires node_start/node_end (label contains "ArrowTableSource" or similar) + source_starts = [e for e in events if e[0] == "node_start" and e[1] != "doubler"] + assert len(source_starts) >= 1 # Function fires node_start, per-packet hooks, node_end - assert ("node_start", "function") in events - assert ("node_end", "function") in events + assert ("node_start", "doubler") in events + assert ("node_end", "doubler") in events fn_packet_ends = [ e for e in events - if e[0] == "packet_end" and e[1] == "function" + if e[0] == "packet_end" and e[1] == "doubler" ] assert len(fn_packet_ends) == 2 # All should be cached=False (first run, no DB) @@ -517,31 +551,36 @@ def double_val(val: int) -> int: events = [] class RecordingObserver: - def on_node_start(self, node): - events.append(("node_start", node.node_type)) - - def on_node_end(self, node): - events.append(("node_end", node.node_type)) - - def on_packet_start(self, node, tag, packet): - events.append(("packet_start", node.node_type)) - - def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): - events.append(("packet_end", node.node_type)) + def on_run_start(self, run_id): pass + def on_run_end(self, run_id): pass + def on_node_start(self, node_label, node_hash): + events.append(("node_start", node_label)) + def on_node_end(self, node_label, node_hash): + events.append(("node_end", node_label)) + def on_packet_start(self, node_label, tag, packet): + events.append(("packet_start", node_label)) + def on_packet_end(self, node_label, tag, input_pkt, 
output_pkt, cached): + events.append(("packet_end", node_label)) + def on_packet_crash(self, node_label, tag, packet, exc): pass + def create_packet_logger(self, tag, packet, **kwargs): + from orcapod.pipeline.observer import _NOOP_LOGGER + return _NOOP_LOGGER + def contextualize(self, node_hash, node_label): + return self pipeline.compile() orch = AsyncPipelineOrchestrator(observer=RecordingObserver()) orch.run(pipeline._node_graph) - # All three node types fire start/end - for node_type in ("source", "operator", "function"): - assert ("node_start", node_type) in events - assert ("node_end", node_type) in events + # All labeled nodes fire start/end + assert ("node_start", "mapper") in events + assert ("node_end", "mapper") in events + assert ("node_start", "doubler") in events + assert ("node_end", "doubler") in events # Only function nodes fire packet-level hooks - assert ("packet_start", "function") in events - assert ("packet_start", "source") not in events - assert ("packet_start", "operator") not in events + assert ("packet_start", "doubler") in events + assert ("packet_start", "mapper") not in events def test_no_observer_works(self): """Async pipeline runs fine with no observer.""" diff --git a/tests/test_pipeline/test_pipeline.py b/tests/test_pipeline/test_pipeline.py index 99ffd956..203d9ec9 100644 --- a/tests/test_pipeline/test_pipeline.py +++ b/tests/test_pipeline/test_pipeline.py @@ -1164,7 +1164,9 @@ def execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> PacketProtocol | None: + *, + logger=None, + ) -> "PacketProtocol | None": self.sync_calls.append(packet) return packet_function.direct_call(packet) @@ -1172,15 +1174,17 @@ async def async_execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> PacketProtocol | None: + *, + logger=None, + ) -> "PacketProtocol | None": self.async_calls.append(packet) return packet_function.direct_call(packet) - def execute_callable(self, fn, kwargs, 
executor_options=None): + def execute_callable(self, fn, kwargs, executor_options=None, *, logger=None): self.sync_calls.append(kwargs) return fn(**kwargs) - async def async_execute_callable(self, fn, kwargs, executor_options=None): + async def async_execute_callable(self, fn, kwargs, executor_options=None, *, logger=None): self.async_calls.append(kwargs) return fn(**kwargs) diff --git a/tests/test_pipeline/test_sync_orchestrator.py b/tests/test_pipeline/test_sync_orchestrator.py index cd5c7974..bdd58efb 100644 --- a/tests/test_pipeline/test_sync_orchestrator.py +++ b/tests/test_pipeline/test_sync_orchestrator.py @@ -130,27 +130,34 @@ def test_observer_hooks_fire(self): events = [] class RecordingObserver: - def on_node_start(self, node): - events.append(("node_start", node.node_type)) - - def on_node_end(self, node): - events.append(("node_end", node.node_type)) - - def on_packet_start(self, node, tag, packet): + def on_run_start(self, run_id): pass + def on_run_end(self, run_id): pass + def on_node_start(self, node_label, node_hash): + events.append(("node_start", node_label)) + def on_node_end(self, node_label, node_hash): + events.append(("node_end", node_label)) + def on_packet_start(self, node_label, tag, packet): events.append(("packet_start",)) - - def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): + def on_packet_end(self, node_label, tag, input_pkt, output_pkt, cached): events.append(("packet_end", cached)) + def on_packet_crash(self, node_label, tag, packet, exc): pass + def create_packet_logger(self, tag, packet, **kwargs): + from orcapod.pipeline.observer import _NOOP_LOGGER + return _NOOP_LOGGER + def contextualize(self, node_hash, node_label): + return self orch = SyncPipelineOrchestrator(observer=RecordingObserver()) orch.run(pipeline._node_graph) - assert events[0] == ("node_start", "source") - assert events[1] == ("node_end", "source") - assert events[2] == ("node_start", "function") + # First two events are source node start/end + 
assert events[0][0] == "node_start" + assert events[1][0] == "node_end" + # Then function node start, packet hooks, function node end + assert events[2] == ("node_start", "doubler") assert events[3] == ("packet_start",) assert events[4] == ("packet_end", False) - assert events[5] == ("node_end", "function") + assert events[5] == ("node_end", "doubler") class TestSyncOrchestratorUnknownNodeType: @@ -204,17 +211,22 @@ def test_run_with_explicit_orchestrator(self): events = [] class RecordingObserver: - def on_node_start(self, node): - events.append(("node_start", node.node_type)) - - def on_node_end(self, node): - events.append(("node_end", node.node_type)) - - def on_packet_start(self, node, tag, packet): + def on_run_start(self, run_id): pass + def on_run_end(self, run_id): pass + def on_node_start(self, node_label, node_hash): + events.append(("node_start", node_label)) + def on_node_end(self, node_label, node_hash): + events.append(("node_end", node_label)) + def on_packet_start(self, node_label, tag, packet): events.append(("packet_start",)) - - def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): + def on_packet_end(self, node_label, tag, input_pkt, output_pkt, cached): events.append(("packet_end",)) + def on_packet_crash(self, node_label, tag, packet, exc): pass + def create_packet_logger(self, tag, packet, **kwargs): + from orcapod.pipeline.observer import _NOOP_LOGGER + return _NOOP_LOGGER + def contextualize(self, node_hash, node_label): + return self orch = SyncPipelineOrchestrator(observer=RecordingObserver()) pipeline.run(orchestrator=orch) @@ -409,35 +421,35 @@ def double_val(val: int) -> int: events = [] class RecordingObserver: - def on_node_start(self, node): - events.append(("node_start", node.node_type)) - - def on_node_end(self, node): - events.append(("node_end", node.node_type)) - - def on_packet_start(self, node, tag, packet): - events.append(("packet_start", node.node_type)) - - def on_packet_end(self, node, tag, input_pkt, 
output_pkt, cached): - events.append(("packet_end", node.node_type, cached)) + def on_run_start(self, run_id): pass + def on_run_end(self, run_id): pass + def on_node_start(self, node_label, node_hash): + events.append(("node_start", node_label)) + def on_node_end(self, node_label, node_hash): + events.append(("node_end", node_label)) + def on_packet_start(self, node_label, tag, packet): + events.append(("packet_start", node_label)) + def on_packet_end(self, node_label, tag, input_pkt, output_pkt, cached): + events.append(("packet_end", node_label, cached)) + def on_packet_crash(self, node_label, tag, packet, exc): pass + def create_packet_logger(self, tag, packet, **kwargs): + from orcapod.pipeline.observer import _NOOP_LOGGER + return _NOOP_LOGGER + def contextualize(self, node_hash, node_label): + return self orch = SyncPipelineOrchestrator(observer=RecordingObserver()) orch.run(pipeline._node_graph) - # Source fires node_start/node_end only (no packet-level hooks) - assert ("node_start", "source") in events - assert ("node_end", "source") in events - assert ("packet_start", "source") not in events - - # Operator fires node_start/node_end only - assert ("node_start", "operator") in events - assert ("node_end", "operator") in events - assert ("packet_start", "operator") not in events + # Mapper fires node_start/node_end only (no packet-level hooks) + assert ("node_start", "mapper") in events + assert ("node_end", "mapper") in events + assert ("packet_start", "mapper") not in events # Function fires node_start, per-packet hooks, node_end - assert ("node_start", "function") in events - assert ("node_end", "function") in events - fn_packet_events = [e for e in events if e[0] == "packet_start" and e[1] == "function"] + assert ("node_start", "doubler") in events + assert ("node_end", "doubler") in events + fn_packet_events = [e for e in events if e[0] == "packet_start" and e[1] == "doubler"] assert len(fn_packet_events) == 2 # 2 packets def 
test_function_node_cached_flag(self): @@ -455,12 +467,20 @@ def test_function_node_cached_flag(self): events1 = [] class Obs1: - def on_node_start(self, node): pass - def on_node_end(self, node): pass - def on_packet_start(self, node, tag, packet): pass - def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): - if node.node_type == "function": + def on_run_start(self, run_id): pass + def on_run_end(self, run_id): pass + def on_node_start(self, node_label, node_hash): pass + def on_node_end(self, node_label, node_hash): pass + def on_packet_start(self, node_label, tag, packet): pass + def on_packet_end(self, node_label, tag, input_pkt, output_pkt, cached): + if node_label == "doubler": events1.append(cached) + def on_packet_crash(self, node_label, tag, packet, exc): pass + def create_packet_logger(self, tag, packet, **kwargs): + from orcapod.pipeline.observer import _NOOP_LOGGER + return _NOOP_LOGGER + def contextualize(self, node_hash, node_label): + return self SyncPipelineOrchestrator(observer=Obs1()).run(pipeline._node_graph) assert events1 == [False] @@ -469,12 +489,20 @@ def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): events2 = [] class Obs2: - def on_node_start(self, node): pass - def on_node_end(self, node): pass - def on_packet_start(self, node, tag, packet): pass - def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): - if node.node_type == "function": + def on_run_start(self, run_id): pass + def on_run_end(self, run_id): pass + def on_node_start(self, node_label, node_hash): pass + def on_node_end(self, node_label, node_hash): pass + def on_packet_start(self, node_label, tag, packet): pass + def on_packet_end(self, node_label, tag, input_pkt, output_pkt, cached): + if node_label == "doubler": events2.append(cached) + def on_packet_crash(self, node_label, tag, packet, exc): pass + def create_packet_logger(self, tag, packet, **kwargs): + from orcapod.pipeline.observer import _NOOP_LOGGER + return _NOOP_LOGGER + 
def contextualize(self, node_hash, node_label): + return self SyncPipelineOrchestrator(observer=Obs2()).run(pipeline._node_graph) assert events2 == [True] @@ -494,23 +522,33 @@ def test_diamond_dag_observer_event_order(self): node_order = [] class OrderObserver: - def on_node_start(self, node): - node_order.append(("start", node.node_type)) - def on_node_end(self, node): - node_order.append(("end", node.node_type)) - def on_packet_start(self, node, tag, packet): pass - def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): pass + def on_run_start(self, run_id): pass + def on_run_end(self, run_id): pass + def on_node_start(self, node_label, node_hash): + node_order.append(("start", node_label)) + def on_node_end(self, node_label, node_hash): + node_order.append(("end", node_label)) + def on_packet_start(self, node_label, tag, packet): pass + def on_packet_end(self, node_label, tag, input_pkt, output_pkt, cached): pass + def on_packet_crash(self, node_label, tag, packet, exc): pass + def create_packet_logger(self, tag, packet, **kwargs): + from orcapod.pipeline.observer import _NOOP_LOGGER + return _NOOP_LOGGER + def contextualize(self, node_hash, node_label): + return self SyncPipelineOrchestrator(observer=OrderObserver()).run(pipeline._node_graph) - # Extract just the node types in start order - starts = [nt for event, nt in node_order if event == "start"] - # Sources first (order between them doesn't matter), then operator, then function - assert starts.count("source") == 2 - assert starts.index("operator") > max( - i for i, s in enumerate(starts) if s == "source" - ) - assert starts.index("function") > starts.index("operator") + # Extract just the node labels in start order + starts = [label for event, label in node_order if event == "start"] + # Sources first, then operator ("join"), then function ("adder") + join_idx = starts.index("join") + adder_idx = starts.index("adder") + # join and adder should come after source labels + source_labels = [s for 
s in starts if s not in ("join", "adder")] + assert len(source_labels) == 2 + assert join_idx > max(starts.index(s) for s in source_labels) + assert adder_idx > join_idx def test_no_observer_works(self): """Pipeline runs fine with no observer (None)."""