From fc734c1f267235e28f7bc544cebf96eef230abdd Mon Sep 17 00:00:00 2001 From: Brian Arnold Date: Tue, 17 Mar 2026 18:39:52 +0000 Subject: [PATCH 1/2] feat(logging): add logging infrastructure --- src/orcapod/__init__.py | 2 + src/orcapod/core/cached_function_pod.py | 55 +++ src/orcapod/core/executors/base.py | 31 +- src/orcapod/core/executors/local.py | 50 ++- src/orcapod/core/executors/ray.py | 193 +++++++-- src/orcapod/core/function_pod.py | 30 +- src/orcapod/core/nodes/function_node.py | 174 +++++--- src/orcapod/core/nodes/operator_node.py | 2 +- src/orcapod/core/nodes/source_node.py | 2 +- src/orcapod/core/packet_function.py | 222 +++++++--- src/orcapod/pipeline/__init__.py | 3 + src/orcapod/pipeline/async_orchestrator.py | 24 +- src/orcapod/pipeline/logging_capture.py | 235 +++++++++++ src/orcapod/pipeline/logging_observer.py | 272 ++++++++++++ src/orcapod/pipeline/observer.py | 86 ++-- src/orcapod/pipeline/sync_orchestrator.py | 26 +- src/orcapod/protocols/__init__.py | 4 + .../protocols/core_protocols/executor.py | 23 +- .../core_protocols/packet_function.py | 23 +- src/orcapod/protocols/node_protocols.py | 15 +- .../protocols/observability_protocols.py | 140 +++++++ tests/test_channels/test_async_execute.py | 36 +- .../test_copilot_review_issues.py | 4 +- .../test_channels/test_node_async_execute.py | 12 +- .../test_cached_packet_function.py | 24 +- .../packet_function/test_executor.py | 55 ++- .../packet_function/test_packet_function.py | 44 +- tests/test_core/test_regression_fixes.py | 45 +- tests/test_core/test_result_cache.py | 6 +- tests/test_pipeline/test_logging_capture.py | 332 +++++++++++++++ .../test_logging_observer_integration.py | 392 ++++++++++++++++++ tests/test_pipeline/test_node_protocols.py | 10 + tests/test_pipeline/test_observer.py | 58 ++- tests/test_pipeline/test_orchestrator.py | 55 ++- tests/test_pipeline/test_pipeline.py | 12 +- tests/test_pipeline/test_sync_orchestrator.py | 45 +- 36 files changed, 2376 insertions(+), 366 deletions(-) create mode 100644 src/orcapod/pipeline/logging_capture.py create mode 100644 src/orcapod/pipeline/logging_observer.py create mode 100644 src/orcapod/protocols/observability_protocols.py create mode 100644 tests/test_pipeline/test_logging_capture.py create mode 100644 tests/test_pipeline/test_logging_observer_integration.py diff --git a/src/orcapod/__init__.py b/src/orcapod/__init__.py index 85956537..8afbf1ef 100644 --- a/src/orcapod/__init__.py +++ b/src/orcapod/__init__.py @@ -2,6 +2,7 @@ FunctionPod, function_pod, ) +from .core.sources.arrow_table_source import ArrowTableSource from .pipeline import Pipeline # Subpackage re-exports for clean public API @@ -13,6 +14,7 @@ from . import types # noqa: F401 __all__ = [ + "ArrowTableSource", "FunctionPod", "function_pod", "Pipeline", diff --git a/src/orcapod/core/cached_function_pod.py b/src/orcapod/core/cached_function_pod.py index 76d2a27d..771009dc 100644 --- a/src/orcapod/core/cached_function_pod.py +++ b/src/orcapod/core/cached_function_pod.py @@ -18,6 +18,8 @@ if TYPE_CHECKING: import pyarrow as pa + from orcapod.pipeline.logging_capture import CapturedLogs + logger = logging.getLogger(__name__) @@ -125,6 +127,59 @@ async def async_process_packet( output = output.with_meta_columns(**{self.RESULT_COMPUTED_FLAG: True}) return tag, output + def process_packet_with_capture( + self, tag: TagProtocol, packet: PacketProtocol + ) -> "tuple[TagProtocol, PacketProtocol | None, CapturedLogs]": + """Process with pod-level caching, returning CapturedLogs alongside. + + On cache hit, returns empty CapturedLogs (no function was executed). + """ + from orcapod.pipeline.logging_capture import CapturedLogs + + cached = self._cache.lookup(packet) + if cached is not None: + logger.info("Pod-level cache hit") + return tag, cached, CapturedLogs(success=True) + + tag, output, captured = self._function_pod.process_packet_with_capture( + tag, packet + ) + if output is not None and captured.success: + pf = self._function_pod.packet_function + self._cache.store( + packet, + output, + variation_data=pf.get_function_variation_data(), + execution_data=pf.get_execution_data(), + ) + output = output.with_meta_columns(**{self.RESULT_COMPUTED_FLAG: True}) + return tag, output, captured + + async def async_process_packet_with_capture( + self, tag: TagProtocol, packet: PacketProtocol + ) -> "tuple[TagProtocol, PacketProtocol | None, CapturedLogs]": + """Async counterpart of ``process_packet_with_capture``.""" + from orcapod.pipeline.logging_capture import CapturedLogs + + cached = self._cache.lookup(packet) + if cached is not None: + logger.info("Pod-level cache hit") + return tag, cached, CapturedLogs(success=True) + + tag, output, captured = await self._function_pod.async_process_packet_with_capture( + tag, packet + ) + if output is not None and captured.success: + pf = self._function_pod.packet_function + self._cache.store( + packet, + output, + variation_data=pf.get_function_variation_data(), + execution_data=pf.get_execution_data(), + ) + output = output.with_meta_columns(**{self.RESULT_COMPUTED_FLAG: True}) + return tag, output, captured + def get_all_cached_outputs( self, include_system_columns: bool = False ) -> "pa.Table | None": diff --git a/src/orcapod/core/executors/base.py b/src/orcapod/core/executors/base.py index 386cbe1a..1a7d0d6b 100644 --- a/src/orcapod/core/executors/base.py +++ b/src/orcapod/core/executors/base.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from orcapod.pipeline.logging_capture import CapturedLogs from orcapod.protocols.core_protocols import PacketFunctionProtocol, PacketProtocol @@ -48,12 +49,12 @@ def execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> PacketProtocol | None: + ) -> "tuple[PacketProtocol | None, CapturedLogs]": """Synchronously execute *packet_function* on *packet*. Implementations should call ``packet_function.direct_call(packet)`` to invoke the function's native computation, bypassing executor - routing. + routing, and pass through the ``(result, CapturedLogs)`` tuple. """ ... @@ -61,7 +62,7 @@ async def async_execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> PacketProtocol | None: + ) -> "tuple[PacketProtocol | None, CapturedLogs]": """Asynchronous counterpart of ``execute``. The default implementation delegates to ``execute`` synchronously. @@ -96,11 +97,14 @@ def execute_callable( fn: Callable[..., Any], kwargs: dict[str, Any], executor_options: dict[str, Any] | None = None, - ) -> Any: - """Synchronously execute *fn* with *kwargs*. + ) -> "tuple[Any, CapturedLogs]": + """Synchronously execute *fn* with *kwargs*, returning captured I/O. - Default implementation calls ``fn(**kwargs)`` in-process. - Subclasses should override for remote/distributed execution. + Default implementation calls ``fn(**kwargs)`` with no capture and + returns empty :class:`~orcapod.pipeline.logging_capture.CapturedLogs`. + Exceptions propagate to the caller. Subclasses (e.g. + ``LocalExecutor``, ``RayExecutor``) override to add I/O capture and + exception swallowing. Args: fn: The Python callable to execute. @@ -108,21 +112,22 @@ def execute_callable( executor_options: Optional per-call options. Returns: - The raw return value of *fn*. + ``(raw_result, CapturedLogs)`` """ - return fn(**kwargs) + from orcapod.pipeline.logging_capture import CapturedLogs + + return fn(**kwargs), CapturedLogs() async def async_execute_callable( self, fn: Callable[..., Any], kwargs: dict[str, Any], executor_options: dict[str, Any] | None = None, - ) -> Any: - """Asynchronously execute *fn* with *kwargs*. + ) -> "tuple[Any, CapturedLogs]": + """Asynchronously execute *fn* with *kwargs*, returning captured I/O. Default implementation delegates to ``execute_callable`` - synchronously. Subclasses should override for truly async - execution. + synchronously. Subclasses should override for truly async execution. """ return self.execute_callable(fn, kwargs, executor_options) diff --git a/src/orcapod/core/executors/local.py b/src/orcapod/core/executors/local.py index fc97e477..ec05a086 100644 --- a/src/orcapod/core/executors/local.py +++ b/src/orcapod/core/executors/local.py @@ -2,12 +2,14 @@ import asyncio import inspect +import traceback as _traceback_module from collections.abc import Callable from typing import TYPE_CHECKING, Any from orcapod.core.executors.base import PacketFunctionExecutorBase if TYPE_CHECKING: + from orcapod.pipeline.logging_capture import CapturedLogs from orcapod.protocols.core_protocols import PacketFunctionProtocol, PacketProtocol @@ -29,14 +31,14 @@ def execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> PacketProtocol | None: + ) -> "tuple[PacketProtocol | None, CapturedLogs]": return packet_function.direct_call(packet) async def async_execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> PacketProtocol | None: + ) -> "tuple[PacketProtocol | None, CapturedLogs]": return await packet_function.direct_async_call(packet) # -- PythonFunctionExecutorProtocol -- @@ -46,10 +48,23 @@ def execute_callable( fn: Callable[..., Any], kwargs: dict[str, Any], executor_options: dict[str, Any] | None = None, - ) -> Any: - if inspect.iscoroutinefunction(fn): - return self._run_async_sync(fn, kwargs) - return fn(**kwargs) + ) -> "tuple[Any, CapturedLogs]": + from orcapod.pipeline.logging_capture import CapturedLogs, LocalCaptureContext + + ctx = LocalCaptureContext() + raw_result = None + success = True + tb: str | None = None + with ctx: + try: + if inspect.iscoroutinefunction(fn): + raw_result = self._run_async_sync(fn, kwargs) + else: + raw_result = fn(**kwargs) + except Exception: + success = False + tb = _traceback_module.format_exc() + return raw_result, ctx.get_captured(success=success, tb=tb) @staticmethod def _run_async_sync(fn: Callable[..., Any], kwargs: dict[str, Any]) -> Any: @@ -69,11 +84,24 @@ async def async_execute_callable( fn: Callable[..., Any], kwargs: dict[str, Any], executor_options: dict[str, Any] | None = None, - ) -> Any: - if inspect.iscoroutinefunction(fn): - return await fn(**kwargs) - loop = asyncio.get_running_loop() - return await loop.run_in_executor(None, lambda: fn(**kwargs)) + ) -> "tuple[Any, CapturedLogs]": + from orcapod.pipeline.logging_capture import CapturedLogs, LocalCaptureContext + + ctx = LocalCaptureContext() + raw_result = None + success = True + tb: str | None = None + with ctx: + try: + if inspect.iscoroutinefunction(fn): + raw_result = await fn(**kwargs) + else: + loop = asyncio.get_running_loop() + raw_result = await loop.run_in_executor(None, lambda: fn(**kwargs)) + except Exception: + success = False + tb = _traceback_module.format_exc() + return raw_result, ctx.get_captured(success=success, tb=tb) def with_options(self, **opts: Any) -> "LocalExecutor": """Return a new ``LocalExecutor``. diff --git a/src/orcapod/core/executors/ray.py b/src/orcapod/core/executors/ray.py index 998c2193..b1236fbc 100644 --- a/src/orcapod/core/executors/ray.py +++ b/src/orcapod/core/executors/ray.py @@ -8,6 +8,7 @@ if TYPE_CHECKING: from orcapod.core.packet_function import PythonPacketFunction + from orcapod.pipeline.logging_capture import CapturedLogs from orcapod.protocols.core_protocols import PacketFunctionProtocol, PacketProtocol @@ -77,9 +78,38 @@ def __init__( self._remote_opts.update(ray_remote_opts) def _ensure_ray_initialized(self) -> None: - """Initialize Ray if it has not been initialized yet.""" + """Initialize Ray if it has not been initialized yet. + + Also registers a cloudpickle dispatch for ``logging.Logger`` so that + user functions referencing loggers can be sent to Ray workers that + do not have orcapod installed. + + By default cloudpickle serializes Logger instances by value, which + traverses the parent chain to the root logger. After + ``install_capture_streams()`` the root logger has a + ``ContextVarLoggingHandler`` from ``orcapod``. Workers without + orcapod cannot deserialize that class. + + Registering loggers as ``(logging.getLogger, (name,))`` is the + correct semantic — loggers are name-keyed singletons — and produces + no orcapod dependency in the pickled bytes. + """ + import logging import ray + try: + import cloudpickle + + def _pickle_logger(l: logging.Logger) -> tuple: + # Root logger has name "root" but must be fetched as "" + name = "" if isinstance(l, logging.RootLogger) else l.name + return logging.getLogger, (name,) + + cloudpickle.CloudPickler.dispatch[logging.Logger] = _pickle_logger + cloudpickle.CloudPickler.dispatch[logging.RootLogger] = _pickle_logger + except Exception: + pass # cloudpickle not available or API changed — best effort + if not ray.is_initialized(): if self._ray_address is not None: ray.init(address=self._ray_address) @@ -124,67 +154,172 @@ def execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> PacketProtocol | None: - import ray + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + from orcapod.pipeline.logging_capture import CapturedLogs pf = self._as_python_packet_function(packet_function) if not pf.is_active(): - return None - - self._ensure_ray_initialized() + return None, CapturedLogs(success=True) - kwargs = packet.as_dict() - remote_fn = ray.remote(**self._build_remote_opts())(pf._function) - ref = remote_fn.remote(**kwargs) - raw_result = ray.get(ref) - return pf._build_output_packet(raw_result) + raw, captured = self.execute_callable(pf._function, packet.as_dict()) + if not captured.success: + return None, captured + return pf._build_output_packet(raw), captured async def async_execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> PacketProtocol | None: - import ray + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + from orcapod.pipeline.logging_capture import CapturedLogs pf = self._as_python_packet_function(packet_function) if not pf.is_active(): - return None - - self._ensure_ray_initialized() + return None, CapturedLogs(success=True) - kwargs = packet.as_dict() - remote_fn = ray.remote(**self._build_remote_opts())(pf._function) - ref = remote_fn.remote(**kwargs) - raw_result = await asyncio.wrap_future(ref.future()) - return pf._build_output_packet(raw_result) + raw, captured = await self.async_execute_callable(pf._function, packet.as_dict()) + if not captured.success: + return None, captured + return pf._build_output_packet(raw), captured # -- PythonFunctionExecutorProtocol -- + @staticmethod + def _make_capture_wrapper() -> Callable[..., Any]: + """Return an inline capture wrapper suitable for Ray remote execution. + + The wrapper is defined as a closure (not a module-level import) so that + cloudpickle serializes it by bytecode rather than by module reference. + This means the Ray cluster workers do **not** need ``orcapod`` installed + — only the standard library is required on the worker side. + + The wrapper returns a plain 6-tuple ``(raw_result, stdout, stderr, + python_logs, traceback_str, success)`` so no orcapod types cross the + Ray object store; the driver reconstructs :class:`CapturedLogs` from + the tuple. + """ + def _capture(fn: Any, kwargs: dict) -> tuple: + import io + import logging + import os + import sys + import tempfile + import traceback as _tb + + stdout_tmp = tempfile.TemporaryFile() + stderr_tmp = tempfile.TemporaryFile() + orig_stdout_fd = os.dup(1) + orig_stderr_fd = os.dup(2) + orig_sys_stdout = sys.stdout + orig_sys_stderr = sys.stderr + sys_stdout_buf = io.StringIO() + sys_stderr_buf = io.StringIO() + log_records: list = [] + + fmt = logging.Formatter("%(levelname)s:%(name)s:%(message)s") + + class _H(logging.Handler): + def emit(self, record: logging.LogRecord) -> None: + log_records.append(fmt.format(record)) + + handler = _H() + root_logger = logging.getLogger() + orig_level = root_logger.level + root_logger.setLevel(logging.DEBUG) + root_logger.addHandler(handler) + + raw_result = None + success = True + tb_str = None + try: + sys.stdout.flush() + sys.stderr.flush() + os.dup2(stdout_tmp.fileno(), 1) + os.dup2(stderr_tmp.fileno(), 2) + sys.stdout = sys_stdout_buf + sys.stderr = sys_stderr_buf + try: + raw_result = fn(**kwargs) + except Exception: + success = False + tb_str = _tb.format_exc() + finally: + sys.stdout = orig_sys_stdout + sys.stderr = orig_sys_stderr + os.dup2(orig_stdout_fd, 1) + os.dup2(orig_stderr_fd, 2) + os.close(orig_stdout_fd) + os.close(orig_stderr_fd) + root_logger.removeHandler(handler) + root_logger.setLevel(orig_level) + stdout_tmp.seek(0) + stderr_tmp.seek(0) + cap_stdout = ( + stdout_tmp.read().decode("utf-8", errors="replace") + + sys_stdout_buf.getvalue() + ) + cap_stderr = ( + stderr_tmp.read().decode("utf-8", errors="replace") + + sys_stderr_buf.getvalue() + ) + stdout_tmp.close() + stderr_tmp.close() + + return raw_result, cap_stdout, cap_stderr, "\n".join(log_records), tb_str, success + + return _capture + def execute_callable( self, fn: Callable[..., Any], kwargs: dict[str, Any], executor_options: dict[str, Any] | None = None, - ) -> Any: + ) -> "tuple[Any, CapturedLogs]": + """Execute *fn* on the Ray cluster with fd-level I/O capture. + + The capture wrapper is serialized by bytecode (not module reference) so + the Ray cluster workers do not need ``orcapod`` installed. + """ import ray + from orcapod.pipeline.logging_capture import CapturedLogs + self._ensure_ray_initialized() - remote_fn = ray.remote(**self._build_remote_opts())(fn) - ref = remote_fn.remote(**kwargs) - return ray.get(ref) + wrapper = self._make_capture_wrapper() + wrapper.__name__ = fn.__name__ + wrapper.__qualname__ = fn.__qualname__ + remote_fn = ray.remote(**self._build_remote_opts())(wrapper) + ref = remote_fn.remote(fn, kwargs) + raw, stdout, stderr, python_logs, tb, success = ray.get(ref) + return raw, CapturedLogs( + stdout=stdout, stderr=stderr, python_logs=python_logs, + traceback=tb, success=success, + ) async def async_execute_callable( self, fn: Callable[..., Any], kwargs: dict[str, Any], executor_options: dict[str, Any] | None = None, - ) -> Any: + ) -> "tuple[Any, CapturedLogs]": + """Async counterpart of :meth:`execute_callable`.""" import ray + from orcapod.pipeline.logging_capture import CapturedLogs + self._ensure_ray_initialized() - remote_fn = ray.remote(**self._build_remote_opts())(fn) - ref = remote_fn.remote(**kwargs) - return await asyncio.wrap_future(ref.future()) + wrapper = self._make_capture_wrapper() + wrapper.__name__ = fn.__name__ + wrapper.__qualname__ = fn.__qualname__ + remote_fn = ray.remote(**self._build_remote_opts())(wrapper) + ref = remote_fn.remote(fn, kwargs) + raw, stdout, stderr, python_logs, tb, success = await asyncio.wrap_future( + ref.future() + ) + return raw, CapturedLogs( + stdout=stdout, stderr=stderr, python_logs=python_logs, + traceback=tb, success=success, + ) def with_options(self, **opts: Any) -> "RayExecutor": """Return a new ``RayExecutor`` with the given options merged in. diff --git a/src/orcapod/core/function_pod.py b/src/orcapod/core/function_pod.py index 5a1db639..7c525188 100644 --- a/src/orcapod/core/function_pod.py +++ b/src/orcapod/core/function_pod.py @@ -152,15 +152,39 @@ def process_packet( Returns: A ``(tag, output_packet)`` tuple; output_packet is ``None`` if - the function filters the packet out. + the function filters the packet out. CapturedLogs are discarded + (only relevant for node-level execution with observers). """ - return tag, self.packet_function.call(packet) + result, _captured = self.packet_function.call(packet) + return tag, result async def async_process_packet( self, tag: TagProtocol, packet: PacketProtocol ) -> tuple[TagProtocol, PacketProtocol | None]: """Async counterpart of ``process_packet``.""" - return tag, await self.packet_function.async_call(packet) + result, _captured = await self.packet_function.async_call(packet) + return tag, result + + def process_packet_with_capture( + self, tag: TagProtocol, packet: PacketProtocol + ) -> tuple[TagProtocol, PacketProtocol | None, "CapturedLogs"]: + """Process a single packet and return CapturedLogs alongside the result. + + Used by FunctionNode to get logs without a ContextVar side-channel. + """ + from orcapod.pipeline.logging_capture import CapturedLogs + + result, captured = self.packet_function.call(packet) + return tag, result, captured + + async def async_process_packet_with_capture( + self, tag: TagProtocol, packet: PacketProtocol + ) -> tuple[TagProtocol, PacketProtocol | None, "CapturedLogs"]: + """Async counterpart of ``process_packet_with_capture``.""" + from orcapod.pipeline.logging_capture import CapturedLogs + + result, captured = await self.packet_function.async_call(packet) + return tag, result, captured def handle_input_streams(self, *streams: StreamProtocol) -> StreamProtocol: """Handle multiple input streams by joining them if necessary. diff --git a/src/orcapod/core/nodes/function_node.py b/src/orcapod/core/nodes/function_node.py index ca3b198c..630dac05 100644 --- a/src/orcapod/core/nodes/function_node.py +++ b/src/orcapod/core/nodes/function_node.py @@ -39,7 +39,8 @@ import polars as pl import pyarrow as pa - from orcapod.pipeline.observer import ExecutionObserver + from orcapod.pipeline.logging_capture import CapturedLogs + from orcapod.protocols.observability_protocols import ExecutionObserverProtocol else: pa = LazyModule("pyarrow") pl = LazyModule("polars") @@ -480,26 +481,32 @@ def execute_packet( packet: The input packet to process. Returns: - A ``(tag, output_packet)`` tuple. + A ``(tag, output_packet)`` tuple. CapturedLogs are discarded. """ - return self._process_packet_internal(tag, packet) + tag_out, result, _captured = self._process_packet_internal(tag, packet) + return tag_out, result def execute( self, input_stream: StreamProtocol, *, - observer: "ExecutionObserver | None" = None, + observer: "ExecutionObserverProtocol | None" = None, + error_policy: str = "continue", ) -> list[tuple[TagProtocol, PacketProtocol]]: """Execute all packets from a stream: compute, persist, and cache. Args: input_stream: The input stream to process. observer: Optional execution observer for hooks. + error_policy: ``"continue"`` (default) skips failed packets; + ``"fail_fast"`` re-raises on the first failure. Returns: Materialized list of (tag, output_packet) pairs, excluding - None outputs. + None outputs and failed packets. """ + from orcapod.pipeline.observer import _NOOP_LOGGER + if observer is not None: observer.on_node_start(self) @@ -511,10 +518,17 @@ def execute( entry_ids = [eid for _, _, eid in upstream_entries] cached = self.get_cached_results(entry_ids=entry_ids) + pp = self.pipeline_path if self._pipeline_database is not None else () + output: list[tuple[TagProtocol, PacketProtocol]] = [] for tag, packet, entry_id in upstream_entries: if observer is not None: observer.on_packet_start(self, tag, packet) + pkt_logger = observer.create_packet_logger( + self, tag, packet, pipeline_path=pp + ) + else: + pkt_logger = _NOOP_LOGGER if entry_id in cached: tag_out, result = cached[entry_id] @@ -522,11 +536,31 @@ def execute( observer.on_packet_end(self, tag, packet, result, cached=True) output.append((tag_out, result)) else: - tag_out, result = self._process_packet_internal(tag, packet) - if observer is not None: - observer.on_packet_end(self, tag, packet, result, cached=False) - if result is not None: - output.append((tag_out, result)) + tag_out, result, captured = self._process_packet_internal(tag, packet) + pkt_logger.record(captured) + if not captured.success: + if observer is not None: + observer.on_packet_crash( + self, + tag, + packet, + RuntimeError( + captured.traceback or "packet function failed" + ), + ) + if error_policy == "fail_fast": + if observer is not None: + observer.on_node_end(self) + raise RuntimeError( + captured.traceback or "packet function failed" + ) + else: + if observer is not None: + observer.on_packet_end( + self, tag, packet, result, cached=False + ) + if result is not None: + output.append((tag_out, result)) if observer is not None: observer.on_node_end(self) @@ -537,12 +571,15 @@ def _process_packet_internal( tag: TagProtocol, packet: PacketProtocol, cache_index: int | None = None, - ) -> tuple[TagProtocol, PacketProtocol | None]: + ) -> "tuple[TagProtocol, PacketProtocol | None, CapturedLogs]": """Core compute + persist + cache. Used by ``execute_packet``, ``execute``, and ``iter_packets``. No input validation is performed — the caller guarantees correctness. + Returns: + A ``(tag, output_packet, captured_logs)`` 3-tuple. + Args: tag: The input tag. packet: The input packet. @@ -550,11 +587,11 @@ def _process_packet_internal( When ``None``, auto-assigns at ``len(_cached_output_packets)``. """ if self._cached_function_pod is not None: - tag_out, output_packet = self._cached_function_pod.process_packet( - tag, packet + tag_out, output_packet, captured = ( + self._cached_function_pod.process_packet_with_capture(tag, packet) ) - if output_packet is not None: + if output_packet is not None and captured.success: result_computed = bool( output_packet.get_meta_value( self._cached_function_pod.RESULT_COMPUTED_FLAG, False @@ -567,7 +604,9 @@ def _process_packet_internal( computed=result_computed, ) else: - tag_out, output_packet = self._function_pod.process_packet(tag, packet) + tag_out, output_packet, captured = ( + self._function_pod.process_packet_with_capture(tag, packet) + ) # Cache internally and invalidate derived caches idx = ( @@ -579,7 +618,7 @@ def _process_packet_internal( self._cached_output_table = None self._cached_content_hash_column = None - return tag_out, output_packet + return tag_out, output_packet, captured def get_cached_results( self, entry_ids: list[str] @@ -669,12 +708,15 @@ async def _async_process_packet_internal( tag: TagProtocol, packet: PacketProtocol, cache_index: int | None = None, - ) -> tuple[TagProtocol, PacketProtocol | None]: + ) -> "tuple[TagProtocol, PacketProtocol | None, CapturedLogs]": """Async counterpart of ``_process_packet_internal``. Computes via async path, writes pipeline provenance, and caches internally — no schema validation. + Returns: + A ``(tag, output_packet, captured_logs)`` 3-tuple. + Args: tag: The input tag. packet: The input packet. @@ -682,12 +724,13 @@ async def _async_process_packet_internal( When ``None``, auto-assigns at ``len(_cached_output_packets)``. """ if self._cached_function_pod is not None: - ( - tag_out, - output_packet, - ) = await self._cached_function_pod.async_process_packet(tag, packet) + tag_out, output_packet, captured = ( + await self._cached_function_pod.async_process_packet_with_capture( + tag, packet + ) + ) - if output_packet is not None: + if output_packet is not None and captured.success: result_computed = bool( output_packet.get_meta_value( self._cached_function_pod.RESULT_COMPUTED_FLAG, False @@ -700,8 +743,8 @@ async def _async_process_packet_internal( computed=result_computed, ) else: - tag_out, output_packet = await self._function_pod.async_process_packet( - tag, packet + tag_out, output_packet, captured = ( + await self._function_pod.async_process_packet_with_capture(tag, packet) ) # Cache internally and invalidate derived caches @@ -714,7 +757,7 @@ async def _async_process_packet_internal( self._cached_output_table = None self._cached_content_hash_column = None - return tag_out, output_packet + return tag_out, output_packet, captured def compute_pipeline_entry_id( self, tag: TagProtocol, input_packet: PacketProtocol @@ -961,7 +1004,9 @@ def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: entry_id = self.compute_pipeline_entry_id(tag, packet) if entry_id in existing_entry_ids: continue - tag, output_packet = self._process_packet_internal(tag, packet) + tag, output_packet, _captured = self._process_packet_internal( + tag, packet + ) if output_packet is not None: yield tag, output_packet @@ -998,7 +1043,9 @@ def _iter_packets_sequential( if packet is not None: yield tag, packet else: - tag, output_packet = self._process_packet_internal(tag, packet) + tag, output_packet, _captured = self._process_packet_internal( + tag, packet + ) if output_packet is not None: yield tag, output_packet self._cached_input_iterator = None @@ -1165,7 +1212,7 @@ async def async_execute( input_channel: ReadableChannel[tuple[TagProtocol, PacketProtocol]], output: WritableChannel[tuple[TagProtocol, PacketProtocol]], *, - observer: "ExecutionObserver | None" = None, + observer: "ExecutionObserverProtocol | None" = None, ) -> None: """Streaming async execution for FunctionNode. @@ -1245,39 +1292,62 @@ async def async_execute( ) await output.send((tag_out, result_packet)) else: - if observer is not None: - observer.on_packet_start(self, tag, packet) - ( - tag_out, - result_packet, - ) = await self._async_process_packet_internal(tag, packet) - if observer is not None: - observer.on_packet_end( - self, tag, packet, result_packet, cached=False - ) - if result_packet is not None: - await output.send((tag_out, result_packet)) + await self._async_execute_one_packet( + tag, packet, output, observer=observer + ) else: # Simple async execution without DB async for tag, packet in input_channel: - if observer is not None: - observer.on_packet_start(self, tag, packet) - ( - tag_out, - result_packet, - ) = await self._async_process_packet_internal(tag, packet) - if observer is not None: - observer.on_packet_end( - self, tag, packet, result_packet, cached=False - ) - if result_packet is not None: - await output.send((tag_out, result_packet)) + await self._async_execute_one_packet( + tag, packet, output, observer=observer + ) if observer is not None: observer.on_node_end(self) finally: await output.close() + async def _async_execute_one_packet( + self, + tag: TagProtocol, + packet: PacketProtocol, + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + *, + observer: "ExecutionObserverProtocol | None" = None, + ) -> None: + """Process one non-cached packet in the async execute path.""" + from orcapod.pipeline.observer import _NOOP_LOGGER + + pp = self.pipeline_path if self._pipeline_database is not None else () + + if observer is not None: + observer.on_packet_start(self, tag, packet) + pkt_logger = observer.create_packet_logger( + self, tag, packet, pipeline_path=pp + ) + else: + pkt_logger = _NOOP_LOGGER + + tag_out, result_packet, captured = await self._async_process_packet_internal( + tag, packet + ) + pkt_logger.record(captured) + if not captured.success: + if observer is not None: + observer.on_packet_crash( + self, + tag, + packet, + RuntimeError(captured.traceback or "packet function failed"), + ) + else: + if observer is not None: + observer.on_packet_end( + self, tag, packet, result_packet, cached=False + ) + if result_packet is not None: + await output.send((tag_out, result_packet)) + def __repr__(self) -> str: return ( f"{type(self).__name__}(packet_function={self._packet_function!r}, " diff --git a/src/orcapod/core/nodes/operator_node.py b/src/orcapod/core/nodes/operator_node.py index b5145d60..e16da95a 100644 --- a/src/orcapod/core/nodes/operator_node.py +++ b/src/orcapod/core/nodes/operator_node.py @@ -31,7 +31,7 @@ if TYPE_CHECKING: import pyarrow as pa - from orcapod.pipeline.observer import ExecutionObserver + from orcapod.protocols.observability_protocols import ExecutionObserverProtocol else: pa = LazyModule("pyarrow") diff --git a/src/orcapod/core/nodes/source_node.py b/src/orcapod/core/nodes/source_node.py index df8d47b1..504f29f7 100644 --- a/src/orcapod/core/nodes/source_node.py +++ b/src/orcapod/core/nodes/source_node.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: import pyarrow as pa - from orcapod.pipeline.observer import ExecutionObserver + from orcapod.protocols.observability_protocols import ExecutionObserverProtocol class SourceNode(StreamBase): diff --git a/src/orcapod/core/packet_function.py b/src/orcapod/core/packet_function.py index bd26ea11..db874de2 100644 --- a/src/orcapod/core/packet_function.py +++ b/src/orcapod/core/packet_function.py @@ -36,6 +36,8 @@ if TYPE_CHECKING: import pyarrow as pa import pyarrow.compute as pc + + from orcapod.pipeline.logging_capture import CapturedLogs else: pa = LazyModule("pyarrow") pc = LazyModule("pyarrow.compute") @@ -285,35 +287,51 @@ def set_executor(self, executor: PacketFunctionExecutorProtocol | None) -> None: # ==================== Execution ==================== - def call(self, packet: PacketProtocol) -> PacketProtocol | None: + def call( + self, packet: PacketProtocol + ) -> "tuple[PacketProtocol | None, CapturedLogs]": """Process a single packet, routing through the executor if one is set. Subclasses should override ``direct_call`` instead of this method. + + Returns: + A ``(output_packet, captured_logs)`` tuple. """ if self._executor is not None: return self._executor.execute(self, packet) return self.direct_call(packet) - async def async_call(self, packet: PacketProtocol) -> PacketProtocol | None: + async def async_call( + self, packet: PacketProtocol + ) -> "tuple[PacketProtocol | None, CapturedLogs]": """Asynchronously process a single packet, routing through the executor if set. Subclasses should override ``direct_async_call`` instead of this method. + + Returns: + A ``(output_packet, captured_logs)`` tuple. """ if self._executor is not None: return await self._executor.async_execute(self, packet) return await self.direct_async_call(packet) @abstractmethod - def direct_call(self, packet: PacketProtocol) -> PacketProtocol | None: + def direct_call( + self, packet: PacketProtocol + ) -> "tuple[PacketProtocol | None, CapturedLogs]": """Execute the function's native computation on *packet*. This is the method executors invoke. It bypasses executor routing - and runs the computation directly. Subclasses must implement this. + and runs the computation directly. On user-function failure the + exception is caught internally and ``(None, captured_failure)`` + is returned — no re-raise. Subclasses must implement this. """ ... @abstractmethod - async def direct_async_call(self, packet: PacketProtocol) -> PacketProtocol | None: + async def direct_async_call( + self, packet: PacketProtocol + ) -> "tuple[PacketProtocol | None, CapturedLogs]": """Asynchronous counterpart of ``direct_call``.""" ... @@ -508,61 +526,116 @@ def _call_async_function_sync(self, packet: PacketProtocol) -> Any: _get_sync_executor().submit(lambda: asyncio.run(fn(**kwargs))).result() ) - def call(self, packet: PacketProtocol) -> PacketProtocol | None: + def call( + self, packet: PacketProtocol + ) -> "tuple[PacketProtocol | None, CapturedLogs]": """Process a single packet, routing through the executor if one is set. When an executor implementing ``PythonFunctionExecutorProtocol`` is - set, the raw callable and kwargs are handed to - ``execute_callable`` and the result is wrapped into an output packet. + set, the raw callable and kwargs are handed to ``execute_callable`` + which returns ``(raw_result, CapturedLogs)``. The output packet is + built from ``raw_result``. """ + from orcapod.pipeline.logging_capture import CapturedLogs + if self._executor is not None: if not self._active: - return None - raw = self._executor.execute_callable(self._function, packet.as_dict()) - return self._build_output_packet(raw) + return None, CapturedLogs(success=True) + raw, captured = self._executor.execute_callable( + self._function, packet.as_dict() + ) + if not captured.success: + return None, captured + return self._build_output_packet(raw), captured return self.direct_call(packet) - async def async_call(self, packet: PacketProtocol) -> PacketProtocol | None: + async def async_call( + self, packet: PacketProtocol + ) -> "tuple[PacketProtocol | None, CapturedLogs]": """Async counterpart of ``call``.""" + from orcapod.pipeline.logging_capture import CapturedLogs + if self._executor is not None: if not self._active: - return None - raw = await self._executor.async_execute_callable( + return None, CapturedLogs(success=True) + raw, captured = await self._executor.async_execute_callable( self._function, packet.as_dict() ) - return self._build_output_packet(raw) + if not captured.success: + return None, captured + return self._build_output_packet(raw), captured return await self.direct_async_call(packet) - def direct_call(self, packet: PacketProtocol) -> PacketProtocol | None: - """Execute the function on *packet* synchronously. + def direct_call( + self, packet: PacketProtocol + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + """Execute the function on *packet* synchronously (no executor path). + Uses :class:`~orcapod.pipeline.logging_capture.LocalCaptureContext` + for I/O capture. On user-function failure the exception is caught + internally and ``(None, captured_failure)`` is returned — no re-raise. For async functions, the coroutine is driven to completion via ``asyncio.run()`` (or a helper thread when already inside an event loop). """ - if not self._active: - return None - if self._is_async: - values = self._call_async_function_sync(packet) - else: - values = self._function(**packet.as_dict()) - return self._build_output_packet(values) + import traceback as _tb - async def direct_async_call(self, packet: PacketProtocol) -> PacketProtocol | None: - """Execute the function on *packet* asynchronously. + from orcapod.pipeline.logging_capture import CapturedLogs, LocalCaptureContext + + if not self._active: + return None, CapturedLogs(success=True) + + ctx = LocalCaptureContext() + raw_result = None + with ctx: + try: + if self._is_async: + raw_result = self._call_async_function_sync(packet) + else: + raw_result = self._function(**packet.as_dict()) + except Exception: + return None, ctx.get_captured(success=False, tb=_tb.format_exc()) + return self._build_output_packet(raw_result), ctx.get_captured(success=True) + + async def direct_async_call( + self, packet: PacketProtocol + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + """Execute the function on *packet* asynchronously (no executor path). Async functions are ``await``-ed directly. Sync functions are - offloaded to a thread pool via ``run_in_executor``. + offloaded to a thread pool via ``run_in_executor``. On failure, + ``(None, captured_failure)`` is returned — no re-raise. """ - if not self._active: - return None - if self._is_async: - values = await self._function(**packet.as_dict()) - return self._build_output_packet(values) - else: - import asyncio + import asyncio + import traceback as _tb - loop = asyncio.get_running_loop() - return await loop.run_in_executor(None, self.direct_call, packet) + from orcapod.pipeline.logging_capture import CapturedLogs, LocalCaptureContext + + if not self._active: + return None, CapturedLogs(success=True) + + ctx = LocalCaptureContext() + raw_result = None + with ctx: + try: + if self._is_async: + raw_result = await self._function(**packet.as_dict()) + else: + import contextvars + import functools + + loop = asyncio.get_running_loop() + task_ctx = contextvars.copy_context() + raw_result = await loop.run_in_executor( + None, + functools.partial( + task_ctx.run, + self._function, + **packet.as_dict(), + ), + ) + except Exception: + return None, ctx.get_captured(success=False, tb=_tb.format_exc()) + return self._build_output_packet(raw_result), ctx.get_captured(success=True) def to_config(self) -> dict[str, Any]: """Serialize this packet function to a JSON-compatible config dict. @@ -711,16 +784,24 @@ def executor(self, executor: PacketFunctionExecutorProtocol | None) -> None: # call/async_call which handles executor routing. direct_call / # direct_async_call bypass executor routing as their names imply. - def call(self, packet: PacketProtocol) -> PacketProtocol | None: + def call( + self, packet: PacketProtocol + ) -> "tuple[PacketProtocol | None, CapturedLogs]": return self._packet_function.call(packet) - async def async_call(self, packet: PacketProtocol) -> PacketProtocol | None: + async def async_call( + self, packet: PacketProtocol + ) -> "tuple[PacketProtocol | None, CapturedLogs]": return await self._packet_function.async_call(packet) - def direct_call(self, packet: PacketProtocol) -> PacketProtocol | None: + def direct_call( + self, packet: PacketProtocol + ) -> "tuple[PacketProtocol | None, CapturedLogs]": return self._packet_function.direct_call(packet) - async def direct_async_call(self, packet: PacketProtocol) -> PacketProtocol | None: + async def direct_async_call( + self, packet: PacketProtocol + ) -> "tuple[PacketProtocol | None, CapturedLogs]": return await self._packet_function.direct_async_call(packet) @@ -765,28 +846,29 @@ def call( *, skip_cache_lookup: bool = False, skip_cache_insert: bool = False, - ) -> PacketProtocol | None: + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + from orcapod.pipeline.logging_capture import CapturedLogs + output_packet = None if not skip_cache_lookup: logger.info("Checking for cache...") output_packet = self._cache.lookup(packet) if output_packet is not None: logger.info(f"Cache hit for {packet}!") - if output_packet is None: - output_packet = self._packet_function.call(packet) - if output_packet is not None: - if not skip_cache_insert: - self._cache.store( - packet, - output_packet, - variation_data=self.get_function_variation_data(), - execution_data=self.get_execution_data(), - ) - output_packet = output_packet.with_meta_columns( - **{self.RESULT_COMPUTED_FLAG: True} + return output_packet, CapturedLogs(success=True) + output_packet, captured = self._packet_function.call(packet) + if output_packet is not None: + if not skip_cache_insert: + self._cache.store( + packet, + output_packet, + variation_data=self.get_function_variation_data(), + execution_data=self.get_execution_data(), ) - - return output_packet + output_packet = output_packet.with_meta_columns( + **{self.RESULT_COMPUTED_FLAG: True} + ) + return output_packet, captured async def async_call( self, @@ -794,28 +876,30 @@ async def async_call( *, skip_cache_lookup: bool = False, skip_cache_insert: bool = False, - ) -> PacketProtocol | None: + ) -> "tuple[PacketProtocol | None, CapturedLogs]": """Async counterpart of ``call`` with cache check and recording.""" + from orcapod.pipeline.logging_capture import CapturedLogs + output_packet = None if not skip_cache_lookup: logger.info("Checking for cache...") output_packet = self._cache.lookup(packet) if output_packet is not None: logger.info(f"Cache hit for {packet}!") - if output_packet is None: - output_packet = await self._packet_function.async_call(packet) - if output_packet is not None: - if not skip_cache_insert: - self._cache.store( - packet, - output_packet, - variation_data=self.get_function_variation_data(), - execution_data=self.get_execution_data(), - ) - output_packet = output_packet.with_meta_columns( - **{self.RESULT_COMPUTED_FLAG: True} + return output_packet, CapturedLogs(success=True) + output_packet, captured = await self._packet_function.async_call(packet) + if output_packet is not None: + if not skip_cache_insert: + self._cache.store( + packet, + output_packet, + variation_data=self.get_function_variation_data(), + execution_data=self.get_execution_data(), ) - return output_packet + output_packet = output_packet.with_meta_columns( + **{self.RESULT_COMPUTED_FLAG: True} + ) + return output_packet, captured def get_cached_output_for_packet( self, input_packet: PacketProtocol diff --git a/src/orcapod/pipeline/__init__.py b/src/orcapod/pipeline/__init__.py index 714ea9e7..67dddfe1 100644 --- a/src/orcapod/pipeline/__init__.py +++ b/src/orcapod/pipeline/__init__.py @@ -1,11 +1,14 @@ from .async_orchestrator import AsyncPipelineOrchestrator from .graph import Pipeline +from .logging_observer import LoggingObserver, PacketLogger from .serialization import LoadStatus, PIPELINE_FORMAT_VERSION from .sync_orchestrator import SyncPipelineOrchestrator __all__ = [ "AsyncPipelineOrchestrator", "LoadStatus", + "LoggingObserver", + "PacketLogger", "PIPELINE_FORMAT_VERSION", "Pipeline", "SyncPipelineOrchestrator", diff --git a/src/orcapod/pipeline/async_orchestrator.py b/src/orcapod/pipeline/async_orchestrator.py index 04a6e7d3..c6e35d41 100644 --- a/src/orcapod/pipeline/async_orchestrator.py +++ b/src/orcapod/pipeline/async_orchestrator.py @@ -8,6 +8,7 @@ from __future__ import annotations import asyncio +import uuid import logging from collections import defaultdict from typing import TYPE_CHECKING, Any @@ -23,7 +24,7 @@ if TYPE_CHECKING: import networkx as nx - from orcapod.pipeline.observer import ExecutionObserver + from orcapod.protocols.observability_protocols import ExecutionObserverProtocol from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol logger = logging.getLogger(__name__) @@ -48,16 +49,19 @@ class AsyncPipelineOrchestrator: def __init__( self, - observer: "ExecutionObserver | None" = None, + observer: "ExecutionObserverProtocol | None" = None, buffer_size: int = 64, + error_policy: str = "continue", ) -> None: self._observer = observer self._buffer_size = buffer_size + self._error_policy = error_policy def run( self, graph: "nx.DiGraph", materialize_results: bool = True, + run_id: str | None = None, ) -> OrchestratorResult: """Synchronous entry point — runs the async pipeline to completion. @@ -65,36 +69,46 @@ def run( graph: A NetworkX DiGraph with GraphNode objects as vertices. materialize_results: If True, collect all node outputs into the result. If False, return empty node_outputs. + run_id: Optional run identifier. If not provided, a UUID is + generated automatically. Returns: OrchestratorResult with node outputs. """ - return asyncio.run(self._run_async(graph, materialize_results)) + return asyncio.run(self._run_async(graph, materialize_results, run_id=run_id)) async def run_async( self, graph: "nx.DiGraph", materialize_results: bool = True, + run_id: str | None = None, ) -> OrchestratorResult: """Async entry point for callers already inside an event loop. Args: graph: A NetworkX DiGraph with GraphNode objects as vertices. materialize_results: If True, collect all node outputs. + run_id: Optional run identifier. If not provided, a UUID is + generated automatically. Returns: OrchestratorResult with node outputs. """ - return await self._run_async(graph, materialize_results) + return await self._run_async(graph, materialize_results, run_id=run_id) async def _run_async( self, graph: "nx.DiGraph", materialize_results: bool, + run_id: str | None = None, ) -> OrchestratorResult: """Core async logic: wire channels, launch tasks, collect results.""" import networkx as nx + effective_run_id = run_id or str(uuid.uuid4()) + if self._observer is not None: + self._observer.on_run_start(effective_run_id) + topo_order = list(nx.topological_sort(graph)) buf = self._buffer_size @@ -186,6 +200,8 @@ async def _run_async( for ch in terminal_channels: tg.create_task(ch.reader.collect()) + if self._observer is not None: + self._observer.on_run_end(effective_run_id) return OrchestratorResult( node_outputs=collectors if materialize_results else {} ) diff --git a/src/orcapod/pipeline/logging_capture.py b/src/orcapod/pipeline/logging_capture.py new file mode 100644 index 00000000..e2d8cee2 --- /dev/null +++ b/src/orcapod/pipeline/logging_capture.py @@ -0,0 +1,235 @@ +"""Capture infrastructure for observability logging. + +Provides context-variable-local capture of stdout, stderr, and Python logging +for use in FunctionNode execution. Thread-safe and asyncio-task-safe via +``contextvars.ContextVar`` — captures from concurrent packets never intermingle. + +CapturedLogs travel as part of the return type through the call chain +(``direct_call`` → ``call`` → ``process_packet`` → FunctionNode) so there +is no ContextVar side-channel for logs. Each executor's ``execute_callable`` +returns ``(raw_result, CapturedLogs)``, and ``direct_call`` returns +``(output_packet, CapturedLogs)`` — catching user-function exceptions +internally rather than re-raising. + +Typical usage +------------- +Call ``install_capture_streams()`` once when a logging Observer is created. +The executor or ``direct_call`` wraps function execution in +``LocalCaptureContext`` and returns CapturedLogs alongside the result:: + + result, captured = packet_function.call(packet) + pkt_logger.record(captured) +""" + +from __future__ import annotations + +import contextvars +import io +import logging +import sys +from dataclasses import dataclass +from typing import Any + + +# --------------------------------------------------------------------------- +# CapturedLogs +# --------------------------------------------------------------------------- + + +@dataclass +class CapturedLogs: + """I/O captured from a single packet function execution.""" + + stdout: str = "" + stderr: str = "" + python_logs: str = "" + traceback: str | None = None + success: bool = True + + +# --------------------------------------------------------------------------- +# Context variables +# --------------------------------------------------------------------------- +# Each asyncio task and thread gets its own copy of these variables, so +# captures from concurrent packets never intermingle. + +_stdout_capture: contextvars.ContextVar[io.StringIO | None] = contextvars.ContextVar( + "_stdout_capture", default=None +) +_stderr_capture: contextvars.ContextVar[io.StringIO | None] = contextvars.ContextVar( + "_stderr_capture", default=None +) +_log_capture: contextvars.ContextVar[list[str] | None] = contextvars.ContextVar( + "_log_capture", default=None +) + + +# --------------------------------------------------------------------------- +# ContextLocalTeeStream +# --------------------------------------------------------------------------- + + +class ContextLocalTeeStream: + """A stream that writes to the original *and* a per-context capture buffer. + + All writes go to *original* (terminal output is preserved) and also to a + ``StringIO`` buffer active for the current asyncio task / thread (selected + via ``capture_var``). Concurrent tasks each have their own buffer and do + not interfere with each other. + """ + + def __init__( + self, + original: Any, + capture_var: contextvars.ContextVar[io.StringIO | None], + ) -> None: + self._original = original + self._capture_var = capture_var + + def write(self, s: str) -> int: + buf = self._capture_var.get() + if buf is not None: + buf.write(s) + return self._original.write(s) + + def flush(self) -> None: + buf = self._capture_var.get() + if buf is not None: + buf.flush() + self._original.flush() + + def __getattr__(self, name: str) -> Any: + return getattr(self._original, name) + + +# --------------------------------------------------------------------------- +# ContextVarLoggingHandler +# --------------------------------------------------------------------------- + + +class ContextVarLoggingHandler(logging.Handler): + """A logging handler that captures records into a per-context buffer. + + When a capture buffer is active for the current context (asyncio task or + thread), log records are formatted and appended to it. When no buffer is + active the record is silently discarded (not duplicated to other handlers). + """ + + def emit(self, record: logging.LogRecord) -> None: + buf = _log_capture.get() + if buf is not None: + buf.append(self.format(record)) + + +# --------------------------------------------------------------------------- +# Global installation (idempotent) +# --------------------------------------------------------------------------- + +_installed = False +_logging_handler: ContextVarLoggingHandler | None = None + + +def install_capture_streams() -> None: + """Install tee streams and the logging handler globally. + + Idempotent — safe to call multiple times. Should be called once when a + concrete logging Observer is instantiated. + + After installation: + + * ``sys.stdout`` / ``sys.stderr`` tee writes to per-context buffers while + also forwarding to the original stream (terminal output preserved). + * The root logger gains a ``ContextVarLoggingHandler`` that captures + records to per-context buffers (covering Python ``logging`` calls). + + .. note:: + Subprocess and C-extension output bypasses Python's stream objects and + goes directly to file descriptors 1/2. For local execution these are + *not* captured (but are still visible in the terminal). Ray remote + execution uses fd-level capture via + ``RayExecutor._make_capture_wrapper``. + + The stream check runs on every call so that if something (e.g. a test + harness) replaces ``sys.stdout``/``sys.stderr`` between calls we + re-wrap the new stream. The logging handler is only added once. + """ + global _installed, _logging_handler + + # Always re-check in case sys.stdout/stderr was replaced (e.g. by pytest). + if not isinstance(sys.stdout, ContextLocalTeeStream): + sys.stdout = ContextLocalTeeStream(sys.stdout, _stdout_capture) + if not isinstance(sys.stderr, ContextLocalTeeStream): + sys.stderr = ContextLocalTeeStream(sys.stderr, _stderr_capture) + + if _installed: + return + + _logging_handler = ContextVarLoggingHandler() + _logging_handler.setFormatter( + logging.Formatter("%(levelname)s:%(name)s:%(message)s") + ) + logging.getLogger().addHandler(_logging_handler) + + _installed = True + + +# --------------------------------------------------------------------------- +# LocalCaptureContext +# --------------------------------------------------------------------------- + + +class LocalCaptureContext: + """Context manager that activates per-context capture for one packet. + + Requires ``install_capture_streams()`` to have been called; without it the + ContextVars are set but nothing tees into them, so captured strings will be + empty (acceptable when no logging Observer is configured). + + Example:: + + ctx = LocalCaptureContext() + try: + with ctx: + result = call_something() + except Exception: + captured = ctx.get_captured(success=False, tb=traceback.format_exc()) + else: + captured = ctx.get_captured(success=True) + """ + + def __init__(self) -> None: + self._stdout_buf = io.StringIO() + self._stderr_buf = io.StringIO() + self._log_buf: list[str] = [] + self._tokens: list[contextvars.Token] = [] + + def __enter__(self) -> "LocalCaptureContext": + self._tokens = [ + _stdout_capture.set(self._stdout_buf), + _stderr_capture.set(self._stderr_buf), + _log_capture.set(self._log_buf), + ] + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + for token, var in zip( + self._tokens, + [_stdout_capture, _stderr_capture, _log_capture], + ): + var.reset(token) + return False # do not suppress exceptions + + def get_captured( + self, + success: bool, + tb: str | None = None, + ) -> CapturedLogs: + """Return a :class:`CapturedLogs` from what was captured in this context.""" + return CapturedLogs( + stdout=self._stdout_buf.getvalue(), + stderr=self._stderr_buf.getvalue(), + python_logs="\n".join(self._log_buf) if self._log_buf else "", + traceback=tb, + success=success, + ) + diff --git a/src/orcapod/pipeline/logging_observer.py b/src/orcapod/pipeline/logging_observer.py new file mode 100644 index 00000000..6db61312 --- /dev/null +++ b/src/orcapod/pipeline/logging_observer.py @@ -0,0 +1,272 @@ +"""Concrete logging observer for orcapod pipelines. + +Provides :class:`LoggingObserver`, a drop-in observer that captures stdout, +stderr, Python logging, and tracebacks from every packet execution and writes +structured log rows to any :class:`~orcapod.protocols.database_protocols.ArrowDatabaseProtocol` +(in-memory, Delta Lake, etc.). + +Typical usage:: + + from orcapod.pipeline.logging_observer import LoggingObserver + from orcapod.pipeline import SyncPipelineOrchestrator + from orcapod.databases import InMemoryArrowDatabase + + obs = LoggingObserver(log_database=InMemoryArrowDatabase()) + pipeline.run(orchestrator=SyncPipelineOrchestrator(observer=obs)) + + # Inspect captured logs + logs = obs.get_logs() # pyarrow.Table + logs.to_pandas() # pandas DataFrame + +Log schema (fixed columns) +-------------------------- +.. list-table:: + :header-rows: 1 + + * - Column + - Type + - Description + * - ``log_id`` + - ``large_utf8`` + - UUID unique to this log entry + * - ``run_id`` + - ``large_utf8`` + - UUID of the pipeline run (from ``on_run_start``) + * - ``node_label`` + - ``large_utf8`` + - Label of the function node + * - ``stdout`` + - ``large_utf8`` + - Captured standard output + * - ``stderr`` + - ``large_utf8`` + - Captured standard error + * - ``python_logs`` + - ``large_utf8`` + - Python ``logging`` output captured during execution + * - ``traceback`` + - ``large_utf8`` + - Full traceback on failure; ``None`` on success + * - ``success`` + - ``bool_`` + - ``True`` if the packet function returned normally + * - ``timestamp`` + - ``large_utf8`` + - ISO-8601 UTC timestamp when ``record()`` was called + +In addition, each tag key from the packet's tag becomes a separate +``large_utf8`` column (queryable, not JSON-encoded). + +Log storage +----------- +Logs are stored at a pipeline-path-mirrored location: +``pipeline_path[:1] + ("logs",) + pipeline_path[1:]``. +Each function node gets its own log table. Use +``get_logs(pipeline_path=node.pipeline_path)`` to retrieve +node-specific logs. +""" + +from __future__ import annotations + +import logging +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any + +from uuid_utils import uuid7 + +from orcapod.pipeline.logging_capture import CapturedLogs, install_capture_streams + +if TYPE_CHECKING: + import pyarrow as pa + + from orcapod.protocols.database_protocols import ArrowDatabaseProtocol + +logger = logging.getLogger(__name__) + +# Default path (table name) within the database where log rows are stored. +DEFAULT_LOG_PATH: tuple[str, ...] = ("execution_logs",) + + +class PacketLogger: + """Context-bound logger created by :class:`LoggingObserver` per packet. + + Holds all context needed to write a structured log row + (run_id, node_label, tag data) so the caller only needs to pass the + :class:`~orcapod.pipeline.logging_capture.CapturedLogs` payload. + + Tag data is stored as individual queryable columns (not JSON) alongside + the fixed log columns. + + This class is not intended to be instantiated directly — use + :meth:`LoggingObserver.create_packet_logger` instead. + """ + + def __init__( + self, + db: "ArrowDatabaseProtocol", + log_path: tuple[str, ...], + run_id: str, + node_label: str, + tag_data: dict[str, Any], + ) -> None: + self._db = db + self._log_path = log_path + self._run_id = run_id + self._node_label = node_label + self._tag_data = tag_data + + def record(self, captured: CapturedLogs) -> None: + """Write one log row to the database.""" + import pyarrow as pa + + log_id = str(uuid7()) + timestamp = datetime.now(timezone.utc).isoformat() + + # Fixed columns + columns: dict[str, pa.Array] = { + "log_id": pa.array([log_id], type=pa.large_utf8()), + "run_id": pa.array([self._run_id], type=pa.large_utf8()), + "node_label": pa.array([self._node_label], type=pa.large_utf8()), + "stdout": pa.array([captured.stdout], type=pa.large_utf8()), + "stderr": pa.array([captured.stderr], type=pa.large_utf8()), + "python_logs": pa.array([captured.python_logs], type=pa.large_utf8()), + "traceback": pa.array([captured.traceback], type=pa.large_utf8()), + "success": pa.array([captured.success], type=pa.bool_()), + "timestamp": pa.array([timestamp], type=pa.large_utf8()), + } + + # Dynamic tag columns — each tag key becomes its own column + for key, value in self._tag_data.items(): + columns[key] = pa.array([str(value)], type=pa.large_utf8()) + + row = pa.table(columns) + try: + self._db.add_record(self._log_path, log_id, row, flush=True) + except Exception: + logger.exception( + "LoggingObserver: failed to write log row for node=%s", + self._node_label, + ) + + +class LoggingObserver: + """Concrete observer that writes packet execution logs to a database. + + Instantiate once, outside the pipeline, and pass to the orchestrator:: + + obs = LoggingObserver(log_database=InMemoryArrowDatabase()) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + # After the run, read back captured logs: + logs_table = obs.get_logs() # pyarrow.Table + + For async / Ray pipelines use :class:`~orcapod.pipeline.AsyncPipelineOrchestrator` + with the same observer:: + + orch = AsyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + Args: + log_database: Any :class:`~orcapod.protocols.database_protocols.ArrowDatabaseProtocol` + instance — :class:`~orcapod.databases.InMemoryArrowDatabase`, + a Delta Lake database, etc. + log_path: Tuple of strings identifying the table within the database. + Defaults to ``("execution_logs",)``. + + Note: + Construction calls :func:`~orcapod.pipeline.logging_capture.install_capture_streams` + so that stdout/stderr tee-capture is active from the moment the observer + is created. + """ + + def __init__( + self, + log_database: "ArrowDatabaseProtocol", + log_path: tuple[str, ...] | None = None, + ) -> None: + self._db = log_database + self._log_path = log_path or DEFAULT_LOG_PATH + self._current_run_id: str = "" + # Activate tee-capture as soon as the observer is created. + install_capture_streams() + + # -- lifecycle hooks -- + + def on_run_start(self, run_id: str) -> None: + self._current_run_id = run_id + + def on_run_end(self, run_id: str) -> None: + pass + + def on_node_start(self, node: Any) -> None: + pass + + def on_node_end(self, node: Any) -> None: + pass + + def on_packet_start(self, node: Any, tag: Any, packet: Any) -> None: + pass + + def on_packet_end( + self, + node: Any, + tag: Any, + input_packet: Any, + output_packet: Any, + cached: bool, + ) -> None: + pass + + def on_packet_crash(self, node: Any, tag: Any, packet: Any, error: Exception) -> None: + pass + + def create_packet_logger( + self, + node: Any, + tag: Any, + packet: Any, + pipeline_path: tuple[str, ...] = (), + ) -> PacketLogger: + """Return a :class:`PacketLogger` bound to *node* + *tag* context. + + Log rows are stored at a pipeline-path-mirrored location: + ``pipeline_path[:1] + ("logs",) + pipeline_path[1:]``. This gives + each function node its own log table in the database. + """ + node_label = getattr(node, "label", None) or getattr(node, "node_type", "unknown") + tag_data = dict(tag) + + # Compute mirrored log path + if pipeline_path: + log_path = pipeline_path[:1] + ("logs",) + pipeline_path[1:] + else: + log_path = self._log_path + + return PacketLogger( + db=self._db, + log_path=log_path, + run_id=self._current_run_id, + node_label=node_label, + tag_data=tag_data, + ) + + # -- convenience -- + + def get_logs( + self, pipeline_path: tuple[str, ...] | None = None + ) -> "pa.Table | None": + """Read log rows from the database as a :class:`pyarrow.Table`. + + Args: + pipeline_path: If provided, reads logs for a specific node + (mirrored path). If ``None``, reads from the default + log path. + + Returns ``None`` if no logs have been written yet. + """ + if pipeline_path is not None: + log_path = pipeline_path[:1] + ("logs",) + pipeline_path[1:] + else: + log_path = self._log_path + return self._db.get_all_records(log_path) diff --git a/src/orcapod/pipeline/observer.py b/src/orcapod/pipeline/observer.py index 22a34138..e9df178c 100644 --- a/src/orcapod/pipeline/observer.py +++ b/src/orcapod/pipeline/observer.py @@ -1,47 +1,61 @@ -"""Execution observer protocol for pipeline orchestration. +"""No-op implementations of the observability protocols. -Provides hooks for monitoring node and packet-level execution events -during orchestrated pipeline runs. +Provides :class:`NoOpLogger` and :class:`NoOpObserver` — the defaults used +when no observability is configured. Every method is a zero-cost no-op. """ from __future__ import annotations -from typing import TYPE_CHECKING, Protocol, runtime_checkable +from typing import TYPE_CHECKING + +from orcapod.protocols.observability_protocols import ( # noqa: F401 (re-exported for convenience) + ExecutionObserverProtocol, + PacketExecutionLoggerProtocol, +) if TYPE_CHECKING: from orcapod.core.nodes import GraphNode + from orcapod.pipeline.logging_capture import CapturedLogs from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol -@runtime_checkable -class ExecutionObserver(Protocol): - """Observer protocol for pipeline execution events. +# --------------------------------------------------------------------------- +# NoOpLogger +# --------------------------------------------------------------------------- + + +class NoOpLogger: + """Logger that discards all captured output. - ``on_packet_start`` / ``on_packet_end`` are only invoked for function - nodes. ``on_node_start`` / ``on_node_end`` are invoked for all node - types. + Returned by :class:`NoOpObserver` when no logging sink is configured. """ - def on_node_start(self, node: "GraphNode") -> None: ... - def on_node_end(self, node: "GraphNode") -> None: ... - def on_packet_start( - self, - node: "GraphNode", - tag: "TagProtocol", - packet: "PacketProtocol", - ) -> None: ... - def on_packet_end( - self, - node: "GraphNode", - tag: "TagProtocol", - input_packet: "PacketProtocol", - output_packet: "PacketProtocol | None", - cached: bool, - ) -> None: ... + def record(self, captured: "CapturedLogs") -> None: + pass + + +# Singleton — NoOpLogger carries no state so one instance is enough. +_NOOP_LOGGER = NoOpLogger() + + +# --------------------------------------------------------------------------- +# NoOpObserver +# --------------------------------------------------------------------------- class NoOpObserver: - """Default observer that does nothing.""" + """Observer that does nothing. + + Satisfies :class:`~orcapod.protocols.observability_protocols.ExecutionObserverProtocol` + and is the default when no observability is configured. + ``create_packet_logger`` returns the shared :data:`_NOOP_LOGGER` singleton. + """ + + def on_run_start(self, run_id: str) -> None: + pass + + def on_run_end(self, run_id: str) -> None: + pass def on_node_start(self, node: "GraphNode") -> None: pass @@ -66,3 +80,21 @@ def on_packet_end( cached: bool, ) -> None: pass + + def on_packet_crash( + self, + node: "GraphNode", + tag: "TagProtocol", + packet: "PacketProtocol", + error: Exception, + ) -> None: + pass + + def create_packet_logger( + self, + node: "GraphNode", + tag: "TagProtocol", + packet: "PacketProtocol", + pipeline_path: tuple[str, ...] = (), + ) -> NoOpLogger: + return _NOOP_LOGGER diff --git a/src/orcapod/pipeline/sync_orchestrator.py b/src/orcapod/pipeline/sync_orchestrator.py index e788421c..b3cf8d74 100644 --- a/src/orcapod/pipeline/sync_orchestrator.py +++ b/src/orcapod/pipeline/sync_orchestrator.py @@ -7,6 +7,7 @@ from __future__ import annotations import logging +import uuid from typing import TYPE_CHECKING, Any from orcapod.pipeline.result import OrchestratorResult @@ -19,7 +20,7 @@ if TYPE_CHECKING: import networkx as nx - from orcapod.pipeline.observer import ExecutionObserver + from orcapod.protocols.observability_protocols import ExecutionObserverProtocol from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol logger = logging.getLogger(__name__) @@ -39,13 +40,19 @@ class SyncPipelineOrchestrator: observer: Optional execution observer forwarded to nodes. """ - def __init__(self, observer: "ExecutionObserver | None" = None) -> None: + def __init__( + self, + observer: "ExecutionObserverProtocol | None" = None, + error_policy: str = "continue", + ) -> None: self._observer = observer + self._error_policy = error_policy def run( self, graph: "nx.DiGraph", materialize_results: bool = True, + run_id: str | None = None, ) -> OrchestratorResult: """Execute the node graph synchronously. @@ -54,12 +61,18 @@ def run( materialize_results: If True, keep all node outputs in memory and return them. If False, discard buffers after downstream consumption (only DB-persisted results survive). + run_id: Optional run identifier. If not provided, a UUID is + generated automatically. Returns: OrchestratorResult with node outputs. """ import networkx as nx + effective_run_id = run_id or str(uuid.uuid4()) + if self._observer is not None: + self._observer.on_run_start(effective_run_id) + topo_order = list(nx.topological_sort(graph)) buffers: dict[Any, list[tuple[TagProtocol, PacketProtocol]]] = {} processed: set[Any] = set() @@ -71,7 +84,11 @@ def run( upstream_buf = self._gather_upstream(node, graph, buffers) upstream_node = list(graph.predecessors(node))[0] input_stream = self._materialize_as_stream(upstream_buf, upstream_node) - buffers[node] = node.execute(input_stream, observer=self._observer) + buffers[node] = node.execute( + input_stream, + observer=self._observer, + error_policy=self._error_policy, + ) elif is_operator_node(node): upstream_buffers = self._gather_upstream_multi(node, graph, buffers) input_streams = [ @@ -91,6 +108,9 @@ def run( if not materialize_results: buffers.clear() + + if self._observer is not None: + self._observer.on_run_end(effective_run_id) return OrchestratorResult(node_outputs=buffers) @staticmethod diff --git a/src/orcapod/protocols/__init__.py b/src/orcapod/protocols/__init__.py index e69de29b..7f8ba7c7 100644 --- a/src/orcapod/protocols/__init__.py +++ b/src/orcapod/protocols/__init__.py @@ -0,0 +1,4 @@ +from orcapod.protocols.observability_protocols import ( + ExecutionObserverProtocol, + PacketExecutionLoggerProtocol, +) diff --git a/src/orcapod/protocols/core_protocols/executor.py b/src/orcapod/protocols/core_protocols/executor.py index 2260898a..b56d4734 100644 --- a/src/orcapod/protocols/core_protocols/executor.py +++ b/src/orcapod/protocols/core_protocols/executor.py @@ -6,6 +6,7 @@ from orcapod.protocols.core_protocols.datagrams import PacketProtocol if TYPE_CHECKING: + from orcapod.pipeline.logging_capture import CapturedLogs from orcapod.protocols.core_protocols.packet_function import PacketFunctionProtocol @@ -38,11 +39,12 @@ def execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> PacketProtocol | None: + ) -> "tuple[PacketProtocol | None, CapturedLogs]": """Synchronously execute *packet_function* on *packet*. The executor should invoke ``packet_function.direct_call(packet)`` - in the appropriate execution environment. + in the appropriate execution environment and pass through its + ``(result, CapturedLogs)`` tuple. """ ... @@ -50,7 +52,7 @@ async def async_execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> PacketProtocol | None: + ) -> "tuple[PacketProtocol | None, CapturedLogs]": """Asynchronous counterpart of ``execute``.""" ... @@ -98,8 +100,8 @@ def execute_callable( fn: Callable[..., Any], kwargs: dict[str, Any], executor_options: dict[str, Any] | None = None, - ) -> Any: - """Synchronously execute *fn* with *kwargs*. + ) -> "tuple[Any, CapturedLogs]": + """Synchronously execute *fn* with *kwargs*, capturing I/O. Args: fn: The Python callable to execute. @@ -108,7 +110,10 @@ def execute_callable( overrides). Returns: - The raw return value of *fn*. + A ``(raw_result, CapturedLogs)`` tuple. ``raw_result`` is the + return value of *fn* (or ``None`` on failure). + ``CapturedLogs.success`` is ``False`` when the function raised; + the traceback is stored in ``CapturedLogs.traceback``. """ ... @@ -117,8 +122,8 @@ async def async_execute_callable( fn: Callable[..., Any], kwargs: dict[str, Any], executor_options: dict[str, Any] | None = None, - ) -> Any: - """Asynchronously execute *fn* with *kwargs*. + ) -> "tuple[Any, CapturedLogs]": + """Asynchronously execute *fn* with *kwargs*, capturing I/O. Args: fn: The Python callable to execute. @@ -126,6 +131,6 @@ async def async_execute_callable( executor_options: Optional per-call options. Returns: - The raw return value of *fn*. + A ``(raw_result, CapturedLogs)`` tuple. """ ... diff --git a/src/orcapod/protocols/core_protocols/packet_function.py b/src/orcapod/protocols/core_protocols/packet_function.py index 37ab97f0..c806597a 100644 --- a/src/orcapod/protocols/core_protocols/packet_function.py +++ b/src/orcapod/protocols/core_protocols/packet_function.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable from orcapod.protocols.core_protocols.datagrams import PacketProtocol from orcapod.protocols.core_protocols.executor import PacketFunctionExecutorProtocol @@ -11,6 +11,9 @@ ) from orcapod.types import Schema +if TYPE_CHECKING: + from orcapod.pipeline.logging_capture import CapturedLogs + @runtime_checkable class PacketFunctionProtocol( @@ -78,51 +81,55 @@ def executor(self, executor: PacketFunctionExecutorProtocol | None) -> None: def call( self, packet: PacketProtocol, - ) -> PacketProtocol | None: + ) -> "tuple[PacketProtocol | None, CapturedLogs]": """Process a single packet, routing through the executor if one is set. Args: packet: The data payload to process. Returns: - The processed packet, or ``None`` to filter it out. + A ``(output_packet, captured_logs)`` tuple. ``output_packet`` + is ``None`` when the function filters the packet out or when + the execution failed (check ``captured_logs.success``). """ ... async def async_call( self, packet: PacketProtocol, - ) -> PacketProtocol | None: + ) -> "tuple[PacketProtocol | None, CapturedLogs]": """Asynchronously process a single packet, routing through the executor if set. Args: packet: The data payload to process. Returns: - The processed packet, or ``None`` to filter it out. + A ``(output_packet, captured_logs)`` tuple. """ ... def direct_call( self, packet: PacketProtocol, - ) -> PacketProtocol | None: + ) -> "tuple[PacketProtocol | None, CapturedLogs]": """Execute the function's native computation on *packet*. This is the method executors invoke, bypassing executor routing. + On user-function failure the exception is caught internally and + ``(None, captured_with_success=False)`` is returned — no re-raise. Args: packet: The data payload to process. Returns: - The processed packet, or ``None`` to filter it out. + A ``(output_packet, captured_logs)`` tuple. """ ... async def direct_async_call( self, packet: PacketProtocol, - ) -> PacketProtocol | None: + ) -> "tuple[PacketProtocol | None, CapturedLogs]": """Asynchronous counterpart of ``direct_call``.""" ... diff --git a/src/orcapod/protocols/node_protocols.py b/src/orcapod/protocols/node_protocols.py index 95677412..f51e6de6 100644 --- a/src/orcapod/protocols/node_protocols.py +++ b/src/orcapod/protocols/node_protocols.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: from orcapod.channels import ReadableChannel, WritableChannel from orcapod.core.nodes import GraphNode - from orcapod.pipeline.observer import ExecutionObserver + from orcapod.protocols.observability_protocols import ExecutionObserverProtocol from orcapod.protocols.core_protocols import ( PacketProtocol, StreamProtocol, @@ -34,14 +34,14 @@ class SourceNodeProtocol(Protocol): def execute( self, *, - observer: "ExecutionObserver | None" = None, + observer: "ExecutionObserverProtocol | None" = None, ) -> list[tuple["TagProtocol", "PacketProtocol"]]: ... async def async_execute( self, output: "WritableChannel[tuple[TagProtocol, PacketProtocol]]", *, - observer: "ExecutionObserver | None" = None, + observer: "ExecutionObserverProtocol | None" = None, ) -> None: ... @@ -55,7 +55,8 @@ def execute( self, input_stream: "StreamProtocol", *, - observer: "ExecutionObserver | None" = None, + observer: "ExecutionObserverProtocol | None" = None, + error_policy: str = "continue", ) -> list[tuple["TagProtocol", "PacketProtocol"]]: ... async def async_execute( @@ -63,7 +64,7 @@ async def async_execute( input_channel: "ReadableChannel[tuple[TagProtocol, PacketProtocol]]", output: "WritableChannel[tuple[TagProtocol, PacketProtocol]]", *, - observer: "ExecutionObserver | None" = None, + observer: "ExecutionObserverProtocol | None" = None, ) -> None: ... @@ -76,7 +77,7 @@ class OperatorNodeProtocol(Protocol): def execute( self, *input_streams: "StreamProtocol", - observer: "ExecutionObserver | None" = None, + observer: "ExecutionObserverProtocol | None" = None, ) -> list[tuple["TagProtocol", "PacketProtocol"]]: ... async def async_execute( @@ -84,7 +85,7 @@ async def async_execute( inputs: "Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]]", output: "WritableChannel[tuple[TagProtocol, PacketProtocol]]", *, - observer: "ExecutionObserver | None" = None, + observer: "ExecutionObserverProtocol | None" = None, ) -> None: ... diff --git a/src/orcapod/protocols/observability_protocols.py b/src/orcapod/protocols/observability_protocols.py new file mode 100644 index 00000000..dad3d0fd --- /dev/null +++ b/src/orcapod/protocols/observability_protocols.py @@ -0,0 +1,140 @@ +"""Observability protocols for pipeline execution tracking and logging. + +Defines: + +* :class:`PacketExecutionLoggerProtocol` — receives captured I/O from a single + packet execution and persists it to a configured sink. +* :class:`ExecutionObserverProtocol` — lifecycle hooks for pipeline/node/packet + events, plus a factory method for creating context-bound loggers. + +Both follow the same runtime-checkable Protocol pattern used throughout the +rest of the orcapod codebase. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +if TYPE_CHECKING: + from orcapod.core.nodes import GraphNode + from orcapod.pipeline.logging_capture import CapturedLogs + from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol + + +@runtime_checkable +class PacketExecutionLoggerProtocol(Protocol): + """Receives captured execution output and persists it. + + A logger is *bound* to a specific packet execution context (node, tag, + packet) when created by the Observer. It knows the destination (e.g. a + Delta Lake table) but does not know how the logs were collected — that is + the executor's responsibility. + """ + + def record(self, captured: "CapturedLogs") -> None: + """Persist the captured logs from a packet function execution. + + Called after every packet execution (success or failure), except for + cache hits when ``log_cache_hits=False`` (the default). + """ + ... + + +@runtime_checkable +class ExecutionObserverProtocol(Protocol): + """Observer protocol for pipeline execution lifecycle events. + + Instantiated once outside the pipeline and injected into the orchestrator. + Provides hooks for lifecycle events at the run, node, and packet level, and + acts as a factory for context-specific loggers. + + ``on_packet_start`` / ``on_packet_end`` / ``on_packet_crash`` are invoked + only for function nodes. ``on_node_start`` / ``on_node_end`` are invoked + for all node types. + """ + + def on_run_start(self, run_id: str) -> None: + """Called at the very start of an orchestrator ``run()`` call. + + Args: + run_id: A UUID string unique to this execution run. All loggers + created during the run will be stamped with this ID. + """ + ... + + def on_run_end(self, run_id: str) -> None: + """Called at the very end of an orchestrator ``run()`` call. + + Args: + run_id: The same UUID passed to ``on_run_start``. + """ + ... + + def on_node_start(self, node: "GraphNode") -> None: + """Called before a node begins processing its packets.""" + ... + + def on_node_end(self, node: "GraphNode") -> None: + """Called after a node finishes processing all packets.""" + ... + + def on_packet_start( + self, + node: "GraphNode", + tag: "TagProtocol", + packet: "PacketProtocol", + ) -> None: + """Called before a packet is processed by a function node.""" + ... + + def on_packet_end( + self, + node: "GraphNode", + tag: "TagProtocol", + input_packet: "PacketProtocol", + output_packet: "PacketProtocol | None", + cached: bool, + ) -> None: + """Called after a packet is successfully processed (or served from cache). + + Args: + cached: ``True`` when the result came from a database cache and + the user function was not executed. + """ + ... + + def on_packet_crash( + self, + node: "GraphNode", + tag: "TagProtocol", + packet: "PacketProtocol", + error: Exception, + ) -> None: + """Called when a packet's execution fails. + + Covers both user-function exceptions (captured on the worker) and + system-level crashes (e.g. ``WorkerCrashedError`` from Ray). The + pipeline continues processing remaining packets rather than aborting. + """ + ... + + def create_packet_logger( + self, + node: "GraphNode", + tag: "TagProtocol", + packet: "PacketProtocol", + pipeline_path: tuple[str, ...] = (), + ) -> PacketExecutionLoggerProtocol: + """Create a context-bound logger for a single packet execution. + + The returned logger is pre-stamped with the node label, run ID, and + packet identity so every ``record()`` call writes the correct context + without the executor needing to know anything about the pipeline. + + Args: + node: The graph node being executed. + tag: The tag for the packet being processed. + packet: The input packet being processed. + pipeline_path: The node's pipeline path for log storage scoping. + """ + ... diff --git a/tests/test_channels/test_async_execute.py b/tests/test_channels/test_async_execute.py index 6aa4124d..c1c73036 100644 --- a/tests/test_channels/test_async_execute.py +++ b/tests/test_channels/test_async_execute.py @@ -155,7 +155,8 @@ def add(x: int, y: int) -> int: pf = PythonPacketFunction(add, output_keys="result") packet = Packet({"x": 3, "y": 5}) - result = await pf.direct_async_call(packet) + result, captured = await pf.direct_async_call(packet) + assert captured.success is True assert result is not None assert result.as_dict()["result"] == 8 @@ -165,11 +166,12 @@ def double(x: int) -> int: return x * 2 pf = PythonPacketFunction(double, output_keys="result") - results = await asyncio.gather( + raw_results = await asyncio.gather( pf.async_call(Packet({"x": 1})), pf.async_call(Packet({"x": 2})), pf.async_call(Packet({"x": 3})), ) + results = [r for r, _captured in raw_results] assert all(r is not None for r in results) values = [r.as_dict()["result"] for r in results if r is not None] assert values == [2, 4, 6] @@ -186,7 +188,9 @@ def record_thread(x: int) -> int: return x pf = PythonPacketFunction(record_thread, output_keys="result") - await pf.direct_async_call(Packet({"x": 42})) + result, captured = await pf.direct_async_call(Packet({"x": 42})) + assert captured.success is True + assert result is not None assert len(call_threads) == 1 @@ -727,8 +731,8 @@ async def push(stream, ch): class TestErrorPropagation: @pytest.mark.asyncio - async def test_function_exception_propagates(self): - """An exception in the packet function should propagate out.""" + async def test_function_exception_returns_none(self): + """An exception in the packet function returns (None, captured) — no raise.""" def failing(x: int) -> int: if x == 2: @@ -743,13 +747,25 @@ def failing(x: int) -> int: output_ch = Channel(buffer_size=16) await feed_stream_to_channel(stream, input_ch) + await pod.async_execute([input_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + # The failing packet (x=2) is silently dropped; 4 of 5 succeed + assert len(results) == 4 + values = sorted(pkt.as_dict()["result"] for _, pkt in results) + assert values == [0, 1, 3, 4] - with pytest.raises(ExceptionGroup) as exc_info: - await pod.async_execute([input_ch.reader], output_ch.writer) + @pytest.mark.asyncio + async def test_direct_async_call_captures_failure(self): + """direct_async_call returns (None, captured) with success=False on error.""" - # Should contain the ValueError from the function - causes = exc_info.value.exceptions - assert any(isinstance(e, ValueError) and "boom" in str(e) for e in causes) + def failing(x: int) -> int: + raise ValueError("boom") + + pf = PythonPacketFunction(failing, output_keys="result") + result, captured = await pf.direct_async_call(Packet({"x": 1})) + assert result is None + assert captured.success is False # --------------------------------------------------------------------------- diff --git a/tests/test_channels/test_copilot_review_issues.py b/tests/test_channels/test_copilot_review_issues.py index 3c786f09..f9891945 100644 --- a/tests/test_channels/test_copilot_review_issues.py +++ b/tests/test_channels/test_copilot_review_issues.py @@ -142,7 +142,7 @@ async def run_in_loop(): nonlocal creation_count with patch.object(ThreadPoolExecutor, "__init__", counting_init): for _ in range(3): - pf.direct_call(Packet({"x": 1, "y": 2})) + _result, _captured = pf.direct_call(Packet({"x": 1, "y": 2})) asyncio.run(run_in_loop()) # Current code creates a new executor per call, so creation_count == 3. @@ -213,7 +213,7 @@ def instrumented_call(packet): # Run inside an event loop to trigger the ThreadPoolExecutor path async def run_in_loop(): - pf.direct_call(Packet({"x": 5})) + _result, _captured = pf.direct_call(Packet({"x": 5})) asyncio.run(run_in_loop()) diff --git a/tests/test_channels/test_node_async_execute.py b/tests/test_channels/test_node_async_execute.py index f8d2f185..7d29c1a7 100644 --- a/tests/test_channels/test_node_async_execute.py +++ b/tests/test_channels/test_node_async_execute.py @@ -91,7 +91,7 @@ def double(x: int) -> int: cpf = CachedPacketFunction(pf, result_database=db) packet = Packet({"x": 5}) - result = await cpf.async_call(packet) + result, _captured = await cpf.async_call(packet) assert result is not None assert result.as_dict()["result"] == 10 @@ -111,13 +111,13 @@ def double(x: int) -> int: packet = Packet({"x": 5}) # First call — computes - result1 = await cpf.async_call(packet) + result1, _captured1 = await cpf.async_call(packet) assert result1 is not None # Has RESULT_COMPUTED_FLAG assert result1.get_meta_value(cpf.RESULT_COMPUTED_FLAG, False) is True # Second call — should hit cache (no RESULT_COMPUTED_FLAG set to True) - result2 = await cpf.async_call(packet) + result2, _captured2 = await cpf.async_call(packet) assert result2 is not None assert result2.as_dict()["result"] == 10 # Cache hit should NOT have RESULT_COMPUTED_FLAG=True @@ -138,11 +138,11 @@ def counting_double(x: int) -> int: cpf = CachedPacketFunction(pf, result_database=db) packet = Packet({"x": 5}) - await cpf.async_call(packet) + _result1, _captured1 = await cpf.async_call(packet) assert call_count == 1 # With skip_cache_lookup, should recompute - await cpf.async_call(packet, skip_cache_lookup=True) + _result2, _captured2 = await cpf.async_call(packet, skip_cache_lookup=True) assert call_count == 2 @pytest.mark.asyncio @@ -155,7 +155,7 @@ def double(x: int) -> int: cpf = CachedPacketFunction(pf, result_database=db) packet = Packet({"x": 5}) - result = await cpf.async_call(packet, skip_cache_insert=True) + result, _captured = await cpf.async_call(packet, skip_cache_insert=True) assert result is not None assert result.as_dict()["result"] == 10 diff --git a/tests/test_core/packet_function/test_cached_packet_function.py b/tests/test_core/packet_function/test_cached_packet_function.py index 9a4218a1..5547cbfa 100644 --- a/tests/test_core/packet_function/test_cached_packet_function.py +++ b/tests/test_core/packet_function/test_cached_packet_function.py @@ -139,11 +139,11 @@ def test_returns_none_when_no_records(self, cached_pf): class TestCallCacheMiss: def test_returns_non_none_result(self, cached_pf, input_packet): - result = cached_pf.call(input_packet) + result, _captured = cached_pf.call(input_packet) assert result is not None def test_result_has_correct_value(self, cached_pf, input_packet): - result = cached_pf.call(input_packet) + result, _captured = cached_pf.call(input_packet) assert result["result"] == 7 # 3 + 4 def test_result_stored_in_database(self, cached_pf, input_packet, db): @@ -167,7 +167,7 @@ def test_get_all_cached_outputs_non_empty_after_call(self, cached_pf, input_pack class TestCallCacheHit: def test_second_call_returns_result(self, cached_pf, input_packet): cached_pf.call(input_packet) - result = cached_pf.call(input_packet) + result, _captured = cached_pf.call(input_packet) assert result is not None assert result["result"] == 7 @@ -204,7 +204,7 @@ def counting_add(x: int, y: int) -> int: class TestSkipCacheLookup: def test_skip_cache_lookup_still_returns_result(self, cached_pf, input_packet): cached_pf.call(input_packet) # populate cache - result = cached_pf.call(input_packet, skip_cache_lookup=True) + result, _captured = cached_pf.call(input_packet, skip_cache_lookup=True) assert result is not None assert result["result"] == 7 @@ -226,7 +226,7 @@ def test_skip_cache_lookup_adds_another_record(self, cached_pf, input_packet, db class TestSkipCacheInsert: def test_skip_cache_insert_returns_result(self, cached_pf, input_packet): - result = cached_pf.call(input_packet, skip_cache_insert=True) + result, _captured = cached_pf.call(input_packet, skip_cache_insert=True) assert result is not None assert result["result"] == 7 @@ -398,12 +398,12 @@ def test_get_execution_data_delegates(self, wrapper, inner_pf): assert wrapper.get_execution_data() == inner_pf.get_execution_data() def test_call_delegates(self, wrapper, input_packet): - result = wrapper.call(input_packet) + result, _captured = wrapper.call(input_packet) assert result is not None assert result["result"] == 7 # 3 + 4 def test_async_call_delegates_through_wrapper(self, wrapper, input_packet): - result = asyncio.run(wrapper.async_call(input_packet)) + result, _captured = asyncio.run(wrapper.async_call(input_packet)) assert result is not None assert result["result"] == 7 # 3 + 4 @@ -493,7 +493,7 @@ def test_inactive_inner_returns_none_and_does_not_store( ): inner_pf.set_active(False) cpf = CachedPacketFunction(inner_pf, result_database=db) - result = cpf.call(input_packet) + result, _captured = cpf.call(input_packet) assert result is None assert db.get_all_records(cpf.record_path) is None @@ -535,27 +535,27 @@ class TestResultComputedFlag: """Verify the meta flag that distinguishes fresh computation from cache hits.""" def test_cache_miss_sets_computed_true(self, cached_pf, input_packet): - result = cached_pf.call(input_packet) + result, _captured = cached_pf.call(input_packet) assert result is not None flag = result.get_meta_value(CachedPacketFunction.RESULT_COMPUTED_FLAG) assert flag is True def test_cache_hit_sets_computed_false(self, cached_pf, input_packet): cached_pf.call(input_packet) # first call — populates cache - result = cached_pf.call(input_packet) # second call — cache hit + result, _captured = cached_pf.call(input_packet) # second call — cache hit assert result is not None flag = result.get_meta_value(CachedPacketFunction.RESULT_COMPUTED_FLAG) assert flag is False def test_skip_cache_lookup_sets_computed_true(self, cached_pf, input_packet): cached_pf.call(input_packet) # populate cache - result = cached_pf.call(input_packet, skip_cache_lookup=True) + result, _captured = cached_pf.call(input_packet, skip_cache_lookup=True) assert result is not None flag = result.get_meta_value(CachedPacketFunction.RESULT_COMPUTED_FLAG) assert flag is True def test_skip_cache_insert_sets_computed_true(self, cached_pf, input_packet): - result = cached_pf.call(input_packet, skip_cache_insert=True) + result, _captured = cached_pf.call(input_packet, skip_cache_insert=True) assert result is not None flag = result.get_meta_value(CachedPacketFunction.RESULT_COMPUTED_FLAG) assert flag is True diff --git a/tests/test_core/packet_function/test_executor.py b/tests/test_core/packet_function/test_executor.py index 317a2cb1..fe7548df 100644 --- a/tests/test_core/packet_function/test_executor.py +++ b/tests/test_core/packet_function/test_executor.py @@ -60,13 +60,15 @@ def execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> PacketProtocol | None: + ) -> "tuple[PacketProtocol | None, CapturedLogs]": self.calls.append((packet_function, packet)) return packet_function.direct_call(packet) def execute_callable(self, fn, kwargs, executor_options=None): + from orcapod.pipeline.logging_capture import CapturedLogs + self.calls.append((fn, kwargs)) - return fn(**kwargs) + return fn(**kwargs), CapturedLogs(success=True) class PythonOnlyExecutor(PacketFunctionExecutorBase): @@ -83,7 +85,7 @@ def execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> PacketProtocol | None: + ) -> "tuple[PacketProtocol | None, CapturedLogs]": return packet_function.direct_call(packet) @@ -101,7 +103,7 @@ def execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> PacketProtocol | None: + ) -> "tuple[PacketProtocol | None, CapturedLogs]": return packet_function.direct_call(packet) @@ -185,7 +187,7 @@ def test_execute_delegates_to_direct_call( add_pf: PythonPacketFunction, add_packet: Packet, ): - result = local_executor.execute(add_pf, add_packet) + result, _captured = local_executor.execute(add_pf, add_packet) assert result is not None assert result.as_dict()["result"] == 3 @@ -246,7 +248,7 @@ class TestExecutorRouting: def test_call_without_executor_uses_direct_call( self, add_pf: PythonPacketFunction, add_packet: Packet ): - result = add_pf.call(add_packet) + result, _captured = add_pf.call(add_packet) assert result is not None assert result.as_dict()["result"] == 3 @@ -257,7 +259,7 @@ def test_call_with_executor_routes_through_executor( spy_executor: SpyExecutor, ): add_pf.executor = spy_executor - result = add_pf.call(add_packet) + result, _captured = add_pf.call(add_packet) assert result is not None assert result.as_dict()["result"] == 3 assert len(spy_executor.calls) == 1 @@ -269,7 +271,7 @@ def test_direct_call_bypasses_executor( spy_executor: SpyExecutor, ): add_pf.executor = spy_executor - result = add_pf.direct_call(add_packet) + result, _captured = add_pf.direct_call(add_packet) assert result is not None assert result.as_dict()["result"] == 3 # Executor was NOT called @@ -343,7 +345,7 @@ class SimpleWrapper(PacketFunctionWrapper): pass wrapper = SimpleWrapper(add_pf, version="v0.0") - result = wrapper.call(add_packet) + result, _captured = wrapper.call(add_packet) assert result is not None assert result.as_dict()["result"] == 3 assert len(spy.calls) == 1 @@ -599,7 +601,9 @@ def execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> PacketProtocol | None: + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + from orcapod.pipeline.logging_capture import CapturedLogs + self.sync_calls.append(packet) return packet_function.direct_call(packet) @@ -607,17 +611,21 @@ async def async_execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> PacketProtocol | None: + ) -> "tuple[PacketProtocol | None, CapturedLogs]": self.async_calls.append(packet) return packet_function.direct_call(packet) def execute_callable(self, fn, kwargs, executor_options=None): + from orcapod.pipeline.logging_capture import CapturedLogs + self.sync_calls.append(kwargs) - return fn(**kwargs) + return fn(**kwargs), CapturedLogs(success=True) async def async_execute_callable(self, fn, kwargs, executor_options=None): + from orcapod.pipeline.logging_capture import CapturedLogs + self.async_calls.append(kwargs) - return fn(**kwargs) + return fn(**kwargs), CapturedLogs(success=True) class TestConcurrentIteration: @@ -736,15 +744,17 @@ def test_spy_executor_satisfies_protocol(self): def test_execute_callable_runs_function(self): executor = LocalExecutor() - result = executor.execute_callable(add, {"x": 3, "y": 4}) + result, captured = executor.execute_callable(add, {"x": 3, "y": 4}) assert result == 7 + assert captured.success is True def test_execute_callable_with_executor_options(self): executor = LocalExecutor() - result = executor.execute_callable( + result, captured = executor.execute_callable( add, {"x": 1, "y": 2}, executor_options={"num_cpus": 1} ) assert result == 3 + assert captured.success is True # --------------------------------------------------------------------------- @@ -801,7 +811,7 @@ def test_call_routes_through_execute_callable(self): spy = SpyExecutor() pf = PythonPacketFunction(add, output_keys="result", executor=spy) packet = Packet({"x": 1, "y": 2}) - result = pf.call(packet) + result, _captured = pf.call(packet) assert result is not None assert result.as_dict()["result"] == 3 assert len(spy.calls) == 1 @@ -812,7 +822,7 @@ def test_call_with_inactive_function_returns_none(self): pf = PythonPacketFunction(add, output_keys="result", executor=spy) pf.set_active(False) packet = Packet({"x": 1, "y": 2}) - result = pf.call(packet) + result, _captured = pf.call(packet) assert result is None assert len(spy.calls) == 0 @@ -823,7 +833,7 @@ def test_async_call_routes_through_async_execute_callable(self): spy = ConcurrentSpyExecutor() pf = PythonPacketFunction(add, output_keys="result", executor=spy) packet = Packet({"x": 1, "y": 2}) - result = asyncio.run(pf.async_call(packet)) + result, _captured = asyncio.run(pf.async_call(packet)) assert result is not None assert result.as_dict()["result"] == 3 assert len(spy.async_calls) == 1 @@ -844,16 +854,18 @@ async def async_add(x: int, y: int) -> int: return x + y executor = LocalExecutor() - result = executor.execute_callable(async_add, {"x": 5, "y": 3}) + result, captured = executor.execute_callable(async_add, {"x": 5, "y": 3}) assert result == 8 + assert captured.success is True def test_async_execute_callable_with_sync_fn(self): """LocalExecutor.async_execute_callable handles sync fns via run_in_executor.""" import asyncio executor = LocalExecutor() - result = asyncio.run(executor.async_execute_callable(add, {"x": 10, "y": 20})) + result, captured = asyncio.run(executor.async_execute_callable(add, {"x": 10, "y": 20})) assert result == 30 + assert captured.success is True def test_async_execute_callable_with_async_fn(self): """LocalExecutor.async_execute_callable awaits async functions directly.""" @@ -863,7 +875,8 @@ async def async_add(x: int, y: int) -> int: return x + y executor = LocalExecutor() - result = asyncio.run( + result, captured = asyncio.run( executor.async_execute_callable(async_add, {"x": 7, "y": 8}) ) assert result == 15 + assert captured.success is True diff --git a/tests/test_core/packet_function/test_packet_function.py b/tests/test_core/packet_function/test_packet_function.py index 4ab5cf1c..b06157cb 100644 --- a/tests/test_core/packet_function/test_packet_function.py +++ b/tests/test_core/packet_function/test_packet_function.py @@ -420,29 +420,29 @@ def test_set_active_true_re_enables(self, add_pf): class TestCall: def test_returns_packet_when_active(self, add_pf, add_packet): - result = add_pf.call(add_packet) + result, _captured = add_pf.call(add_packet) assert result is not None def test_output_has_correct_key(self, add_pf, add_packet): - result = add_pf.call(add_packet) + result, _captured = add_pf.call(add_packet) assert "result" in result.keys() def test_output_has_correct_value(self, add_pf, add_packet): - result = add_pf.call(add_packet) + result, _captured = add_pf.call(add_packet) assert result["result"] == 3 # 1 + 2 def test_source_info_contains_result_key(self, add_pf, add_packet): - result = add_pf.call(add_packet) + result, _captured = add_pf.call(add_packet) source = result.source_info() assert "result" in source def test_source_info_ends_with_key_name(self, add_pf, add_packet): - result = add_pf.call(add_packet) + result, _captured = add_pf.call(add_packet) source_str = result.source_info()["result"] assert source_str.endswith("::result") def test_source_info_contains_uri_components(self, add_pf, add_packet): - result = add_pf.call(add_packet) + result, _captured = add_pf.call(add_packet) source_str = result.source_info()["result"] for component in add_pf.uri: assert component in source_str @@ -450,7 +450,7 @@ def test_source_info_contains_uri_components(self, add_pf, add_packet): def test_source_info_record_id_is_uuid(self, add_pf, add_packet): import re - result = add_pf.call(add_packet) + result, _captured = add_pf.call(add_packet) source_str = result.source_info()["result"] # The record_id segment is between the URI components and the key name # Format: uri_part1:uri_part2:..::record_id::key @@ -461,17 +461,18 @@ def test_source_info_record_id_is_uuid(self, add_pf, add_packet): def test_inactive_returns_none(self, add_pf, add_packet): add_pf.set_active(False) - assert add_pf.call(add_packet) is None + result, _captured = add_pf.call(add_packet) + assert result is None def test_multiple_output_keys(self, multi_pf): packet = Packet({"a": 3, "b": 4}) - result = multi_pf.call(packet) + result, _captured = multi_pf.call(packet) assert result["sum"] == 7 # 3 + 4 assert result["product"] == 12 # 3 * 4 def test_multiple_output_keys_source_info(self, multi_pf): packet = Packet({"a": 3, "b": 4}) - result = multi_pf.call(packet) + result, _captured = multi_pf.call(packet) source = result.source_info() assert "sum" in source assert "product" in source @@ -479,7 +480,7 @@ def test_multiple_output_keys_source_info(self, multi_pf): assert source["product"].endswith("::product") def test_output_packet_schema_applied(self, add_pf, add_packet): - result = add_pf.call(add_packet) + result, _captured = add_pf.call(add_packet) assert result is not None # schema from the packet function should carry through schema = result.schema() @@ -530,7 +531,7 @@ def returns_one(a, b): class TestAsyncCall: def test_async_call_returns_correct_result(self, add_pf, add_packet): - result = asyncio.run(add_pf.async_call(add_packet)) + result, _captured = asyncio.run(add_pf.async_call(add_packet)) assert result is not None assert result.as_dict()["result"] == 3 # 1 + 2 @@ -596,27 +597,28 @@ async def bad(*args: int) -> int: class TestAsyncFunctionSyncCall: def test_direct_call_returns_correct_result(self, async_add_pf, add_packet): - result = async_add_pf.direct_call(add_packet) + result, _captured = async_add_pf.direct_call(add_packet) assert result is not None assert result["result"] == 3 def test_call_returns_correct_result(self, async_add_pf, add_packet): - result = async_add_pf.call(add_packet) + result, _captured = async_add_pf.call(add_packet) assert result is not None assert result["result"] == 3 def test_inactive_returns_none(self, async_add_pf, add_packet): async_add_pf.set_active(False) - assert async_add_pf.call(add_packet) is None + result, _captured = async_add_pf.call(add_packet) + assert result is None def test_multiple_outputs(self, async_multi_pf): packet = Packet({"a": 3, "b": 4}) - result = async_multi_pf.call(packet) + result, _captured = async_multi_pf.call(packet) assert result["sum"] == 7 assert result["product"] == 12 def test_source_info_present(self, async_add_pf, add_packet): - result = async_add_pf.call(add_packet) + result, _captured = async_add_pf.call(add_packet) source = result.source_info() assert "result" in source assert source["result"].endswith("::result") @@ -629,22 +631,22 @@ def test_source_info_present(self, async_add_pf, add_packet): class TestAsyncFunctionAsyncCall: def test_direct_async_call_awaits_directly(self, async_add_pf, add_packet): - result = asyncio.run(async_add_pf.direct_async_call(add_packet)) + result, _captured = asyncio.run(async_add_pf.direct_async_call(add_packet)) assert result is not None assert result["result"] == 3 def test_async_call_returns_correct_result(self, async_add_pf, add_packet): - result = asyncio.run(async_add_pf.async_call(add_packet)) + result, _captured = asyncio.run(async_add_pf.async_call(add_packet)) assert result is not None assert result["result"] == 3 def test_inactive_returns_none(self, async_add_pf, add_packet): async_add_pf.set_active(False) - result = asyncio.run(async_add_pf.async_call(add_packet)) + result, _captured = asyncio.run(async_add_pf.async_call(add_packet)) assert result is None def test_multiple_outputs(self, async_multi_pf): packet = Packet({"a": 3, "b": 4}) - result = asyncio.run(async_multi_pf.async_call(packet)) + result, _captured = asyncio.run(async_multi_pf.async_call(packet)) assert result["sum"] == 7 assert result["product"] == 12 diff --git a/tests/test_core/test_regression_fixes.py b/tests/test_core/test_regression_fixes.py index 793f0efe..c5d3a298 100644 --- a/tests/test_core/test_regression_fixes.py +++ b/tests/test_core/test_regression_fixes.py @@ -78,13 +78,15 @@ def execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> PacketProtocol | None: + ) -> tuple[PacketProtocol | None, Any]: self.calls.append((packet_function, packet)) return packet_function.direct_call(packet) def execute_callable(self, fn, kwargs, executor_options=None): + from orcapod.pipeline.logging_capture import CapturedLogs + self.calls.append((fn, kwargs)) - return fn(**kwargs) + return fn(**kwargs), CapturedLogs(success=True) # =========================================================================== @@ -97,8 +99,10 @@ class TestAsyncExecuteChannelCloseOnError: @pytest.mark.asyncio async def test_unary_operator_closes_channel_on_error(self): - """SelectPacketColumns with a column that doesn't exist should fail, - but the output channel must still be closed.""" + """When a packet function raises, direct_call catches the exception + and returns (None, captured_failure). process_packet discards the + captured logs and returns (tag, None). The output channel is closed + normally and no exception propagates.""" def failing(x: int) -> int: raise ValueError("boom") @@ -112,15 +116,14 @@ def failing(x: int) -> int: await feed_stream_to_channel(stream, input_ch) - with pytest.raises(ExceptionGroup): - await pod.async_execute([input_ch.reader], output_ch.writer) + # No exception raised — failures are caught inside direct_call + await pod.async_execute([input_ch.reader], output_ch.writer) - # The output channel should be closed despite the error. - # Attempting to collect should return whatever was sent before error, - # and not hang forever. + # The output channel should be closed. results = await output_ch.reader.collect() - # We don't assert content, just that it doesn't hang. + # All packets failed so no results were sent assert isinstance(results, list) + assert len(results) == 0 @pytest.mark.asyncio async def test_operator_closes_channel_on_static_process_error(self): @@ -192,7 +195,7 @@ def test_direct_call_does_not_invoke_executor(self): _, spy, wrapper = self._make_add_pf_with_spy() packet = Packet({"x": 3, "y": 4}) - result = wrapper.direct_call(packet) + result, _captured = wrapper.direct_call(packet) assert result is not None assert result.as_dict()["result"] == 7 @@ -204,7 +207,7 @@ async def test_direct_async_call_does_not_invoke_executor(self): _, spy, wrapper = self._make_add_pf_with_spy() packet = Packet({"x": 3, "y": 4}) - result = await wrapper.direct_async_call(packet) + result, _captured = await wrapper.direct_async_call(packet) assert result is not None assert result.as_dict()["result"] == 7 @@ -215,7 +218,7 @@ def test_call_still_routes_through_executor(self): _, spy, wrapper = self._make_add_pf_with_spy() packet = Packet({"x": 3, "y": 4}) - result = wrapper.call(packet) + result, _captured = wrapper.call(packet) assert result is not None assert result.as_dict()["result"] == 7 @@ -308,7 +311,7 @@ def double(x: int) -> int: # Patch async_call to use our concurrency-tracking function original_async_call = pf.async_call - async def tracked_async_call(packet: PacketProtocol) -> PacketProtocol | None: + async def tracked_async_call(packet: PacketProtocol) -> tuple: nonlocal concurrent_count, max_concurrent concurrent_count += 1 max_concurrent = max(max_concurrent, concurrent_count) @@ -458,20 +461,22 @@ def test_ensure_ray_initialized_skips_when_already_initialized(self): mock_ray.init.assert_not_called() def test_async_execute_uses_wrap_future(self): - """async_execute should use ref.future() + asyncio.wrap_future, - not bare 'await ref'.""" + """async_execute_callable should use ref.future() + asyncio.wrap_future, + not bare 'await ref'. async_execute delegates to async_execute_callable.""" import inspect from orcapod.core.executors.ray import RayExecutor - source = inspect.getsource(RayExecutor.async_execute) + source = inspect.getsource(RayExecutor.async_execute_callable) assert "ref.future()" in source, ( - "async_execute should use ref.future() for asyncio compatibility" + "async_execute_callable should use ref.future() for asyncio compatibility" + ) + assert "wrap_future" in source, ( + "async_execute_callable should use asyncio.wrap_future" ) - assert "wrap_future" in source, "async_execute should use asyncio.wrap_future" # Should NOT do bare 'await ref' assert "return await ref\n" not in source, ( - "async_execute should not use bare 'await ref'" + "async_execute_callable should not use bare 'await ref'" ) diff --git a/tests/test_core/test_result_cache.py b/tests/test_core/test_result_cache.py index cda78270..16d6a404 100644 --- a/tests/test_core/test_result_cache.py +++ b/tests/test_core/test_result_cache.py @@ -48,7 +48,7 @@ def _compute_and_store( cache: ResultCache, pf: PythonPacketFunction, input_packet: Packet ): """Helper: compute output and store in cache.""" - output = pf.direct_call(input_packet) + output, _captured = pf.direct_call(input_packet) assert output is not None cache.store( input_packet, @@ -105,7 +105,7 @@ def test_same_packet_different_record_path_is_miss(self): pf = _make_pf() input_pkt = Packet({"x": 10}) - output = pf.direct_call(input_pkt) + output, _captured = pf.direct_call(input_pkt) cache_a.store( input_pkt, output, @@ -130,7 +130,7 @@ def test_most_recent_wins(self): time.sleep(0.01) # ensure different timestamp # Store a second result for the same input (simulating recomputation) - output2 = pf.direct_call(input_pkt) + output2, _captured = pf.direct_call(input_pkt) cache.store( input_pkt, output2, diff --git a/tests/test_pipeline/test_logging_capture.py b/tests/test_pipeline/test_logging_capture.py new file mode 100644 index 00000000..1b9ab9bf --- /dev/null +++ b/tests/test_pipeline/test_logging_capture.py @@ -0,0 +1,332 @@ +"""Tests for logging_capture — CapturedLogs, tee streams, LocalCaptureContext, +and the Ray worker-side wrapper.""" + +from __future__ import annotations + +import asyncio +import concurrent.futures +import io +import logging +import sys + +import pytest + +from orcapod.pipeline.logging_capture import ( + CapturedLogs, + ContextLocalTeeStream, + ContextVarLoggingHandler, + LocalCaptureContext, + _log_capture, + _stderr_capture, + _stdout_capture, + install_capture_streams, +) + + +# --------------------------------------------------------------------------- +# CapturedLogs +# --------------------------------------------------------------------------- + + +class TestCapturedLogs: + def test_defaults(self): + c = CapturedLogs() + assert c.stdout == "" + assert c.stderr == "" + assert c.python_logs == "" + assert c.traceback is None + assert c.success is True + + def test_fields(self): + c = CapturedLogs(stdout="out", stderr="err", traceback="tb", success=False) + assert c.stdout == "out" + assert c.stderr == "err" + assert c.traceback == "tb" + assert c.success is False + + +# --------------------------------------------------------------------------- +# ContextLocalTeeStream +# --------------------------------------------------------------------------- + + +class TestContextLocalTeeStream: + def test_tees_to_buffer_and_original(self): + buf = io.StringIO() + original = io.StringIO() + token = _stdout_capture.set(buf) + tee = ContextLocalTeeStream(original, _stdout_capture) + try: + tee.write("hello") + tee.flush() + finally: + _stdout_capture.reset(token) + + assert buf.getvalue() == "hello" + assert original.getvalue() == "hello" + + def test_no_buffer_writes_only_original(self): + original = io.StringIO() + tee = ContextLocalTeeStream(original, _stdout_capture) + # No ContextVar buffer set — should just write to original + tee.write("world") + assert original.getvalue() == "world" + + def test_isolation_between_contexts(self): + """Two concurrent contexts each capture only their own writes.""" + buf_a = io.StringIO() + buf_b = io.StringIO() + original = io.StringIO() + tee = ContextLocalTeeStream(original, _stdout_capture) + + token_a = _stdout_capture.set(buf_a) + tee.write("from-A") + _stdout_capture.reset(token_a) + + token_b = _stdout_capture.set(buf_b) + tee.write("from-B") + _stdout_capture.reset(token_b) + + assert buf_a.getvalue() == "from-A" + assert buf_b.getvalue() == "from-B" + assert original.getvalue() == "from-Afrom-B" + + def test_proxy_attributes(self): + original = io.StringIO() + tee = ContextLocalTeeStream(original, _stdout_capture) + # Attribute delegation should work + assert tee.writable() == original.writable() + + +# --------------------------------------------------------------------------- +# ContextVarLoggingHandler +# --------------------------------------------------------------------------- + + +class TestContextVarLoggingHandler: + def test_captures_to_active_buffer(self): + handler = ContextVarLoggingHandler() + handler.setFormatter(logging.Formatter("%(message)s")) + buf: list[str] = [] + token = _log_capture.set(buf) + try: + record = logging.LogRecord( + name="test", level=logging.INFO, + pathname="", lineno=0, msg="hello log", args=(), exc_info=None + ) + handler.emit(record) + finally: + _log_capture.reset(token) + + assert buf == ["hello log"] + + def test_discards_when_no_buffer(self): + handler = ContextVarLoggingHandler() + handler.setFormatter(logging.Formatter("%(message)s")) + # No ContextVar buffer — emit should be a no-op + record = logging.LogRecord( + name="test", level=logging.INFO, + pathname="", lineno=0, msg="ignored", args=(), exc_info=None + ) + handler.emit(record) # should not raise + + +# --------------------------------------------------------------------------- +# LocalCaptureContext +# --------------------------------------------------------------------------- + + +class TestLocalCaptureContext: + def test_captures_nothing_without_install(self): + """Without install_capture_streams(), capture returns empty strings.""" + ctx = LocalCaptureContext() + with ctx: + print("not captured") + captured = ctx.get_captured(success=True) + # Since streams are not installed, buffer is empty + assert captured.stdout == "" + assert captured.success is True + + def test_captures_after_install(self): + install_capture_streams() + ctx = LocalCaptureContext() + with ctx: + print("captured output") + captured = ctx.get_captured(success=True) + assert "captured output" in captured.stdout + assert captured.success is True + + def test_captures_stderr_after_install(self): + install_capture_streams() + ctx = LocalCaptureContext() + with ctx: + print("error output", file=sys.stderr) + captured = ctx.get_captured(success=True) + assert "error output" in captured.stderr + + def test_exception_does_not_suppress(self): + install_capture_streams() + ctx = LocalCaptureContext() + with pytest.raises(ValueError, match="boom"): + with ctx: + print("before error") + raise ValueError("boom") + + def test_captures_partial_output_on_exception(self): + install_capture_streams() + ctx = LocalCaptureContext() + try: + with ctx: + print("before error") + raise ValueError("boom") + except ValueError: + pass + captured = ctx.get_captured(success=False, tb="traceback text") + assert "before error" in captured.stdout + assert captured.success is False + assert captured.traceback == "traceback text" + + def test_isolation_between_concurrent_asyncio_tasks(self): + """Each asyncio task captures only its own output.""" + install_capture_streams() + + async def run_task(label: str) -> str: + ctx = LocalCaptureContext() + with ctx: + await asyncio.sleep(0) # yield to other tasks + print(label) + await asyncio.sleep(0) + return ctx.get_captured(success=True).stdout + + async def main(): + results = await asyncio.gather( + run_task("task-A"), run_task("task-B") + ) + return results + + a_out, b_out = asyncio.run(main()) + assert "task-A" in a_out + assert "task-B" not in a_out + assert "task-B" in b_out + assert "task-A" not in b_out + + def test_isolation_between_threads(self): + """Each thread captures only its own output.""" + install_capture_streams() + results: dict[str, str] = {} + + def worker(label: str) -> None: + ctx = LocalCaptureContext() + with ctx: + print(label) + results[label] = ctx.get_captured(success=True).stdout + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + futs = [pool.submit(worker, lbl) for lbl in ("thread-X", "thread-Y")] + for f in futs: + f.result() + + assert "thread-X" in results["thread-X"] + assert "thread-Y" not in results["thread-X"] + assert "thread-Y" in results["thread-Y"] + assert "thread-X" not in results["thread-Y"] + + def test_python_logging_captured(self): + install_capture_streams() + test_logger = logging.getLogger("test.capture") + ctx = LocalCaptureContext() + with ctx: + test_logger.warning("log message") + captured = ctx.get_captured(success=True) + assert "log message" in captured.python_logs + + def test_resets_contextvar_after_exit(self): + install_capture_streams() + ctx = LocalCaptureContext() + with ctx: + pass + # ContextVars should be None after context exits + assert _stdout_capture.get() is None + assert _stderr_capture.get() is None + assert _log_capture.get() is None + + +# --------------------------------------------------------------------------- +# RayExecutor._make_capture_wrapper (local process — no Ray cluster needed) +# --------------------------------------------------------------------------- + + +class TestRayCaptureWrapper: + """Tests exercise :meth:`RayExecutor._make_capture_wrapper` directly. + + The wrapper runs in the same process (no Ray cluster needed) and + returns a plain 6-tuple that the driver reassembles into CapturedLogs. + """ + + @staticmethod + def _call_wrapper(fn, kwargs): + """Helper: call the Ray capture wrapper and return (raw, CapturedLogs).""" + from orcapod.core.executors.ray import RayExecutor + + wrapper = RayExecutor._make_capture_wrapper() + raw, stdout, stderr, python_logs, tb, success = wrapper(fn, kwargs) + return raw, CapturedLogs( + stdout=stdout, stderr=stderr, python_logs=python_logs, + traceback=tb, success=success, + ) + + def test_captures_stdout(self): + def fn(x): + print(f"result={x}") + return x * 2 + + raw, captured = self._call_wrapper(fn, {"x": 3}) + assert raw == 6 + assert "result=3" in captured.stdout + assert captured.success is True + assert captured.traceback is None + + def test_captures_stderr(self): + def fn(): + import sys + print("err line", file=sys.stderr) + return 1 + + raw, captured = self._call_wrapper(fn, {}) + assert raw == 1 + assert "err line" in captured.stderr + + def test_captures_exception(self): + def fn(): + raise RuntimeError("worker blew up") + + raw, captured = self._call_wrapper(fn, {}) + assert raw is None + assert captured.success is False + assert "RuntimeError" in captured.traceback + assert "worker blew up" in captured.traceback + + def test_captures_python_logging(self): + def fn(): + logging.getLogger("ray_wrapper_test").error("log from worker") + return True + + raw, captured = self._call_wrapper(fn, {}) + assert raw is True + assert "log from worker" in captured.python_logs + + def test_restores_fds_after_exception(self): + """File descriptors 1/2 must be restored even when fn raises.""" + import os + + original_stdout_fd = os.dup(1) + try: + def fn(): + raise ValueError("oops") + + self._call_wrapper(fn, {}) + + # Write to fd 1 — should succeed (not be pointing at a closed temp file) + os.write(1, b"") + finally: + os.close(original_stdout_fd) diff --git a/tests/test_pipeline/test_logging_observer_integration.py b/tests/test_pipeline/test_logging_observer_integration.py new file mode 100644 index 00000000..5d298c66 --- /dev/null +++ b/tests/test_pipeline/test_logging_observer_integration.py @@ -0,0 +1,392 @@ +"""Integration tests for LoggingObserver with real pipelines. + +Exercises the full logging pipeline: capture → CapturedLogs return type → +FunctionNode → observer → PacketLogger → database, using InMemoryArrowDatabase +and real Pipeline objects. +""" + +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.core.function_pod import FunctionPod +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.core.sources.arrow_table_source import ArrowTableSource +from orcapod.databases import InMemoryArrowDatabase +from orcapod.pipeline import ( + AsyncPipelineOrchestrator, + Pipeline, + SyncPipelineOrchestrator, +) +from orcapod.pipeline.logging_observer import LoggingObserver + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_source(n: int = 3) -> ArrowTableSource: + table = pa.table({ + "id": pa.array([str(i) for i in range(n)], type=pa.large_string()), + "x": pa.array([10 * (i + 1) for i in range(n)], type=pa.int64()), + }) + return ArrowTableSource(table, tag_columns=["id"]) + + +def _get_function_node(pipeline: Pipeline): + """Return the first function node from the pipeline graph.""" + import networkx as nx + + for node in nx.topological_sort(pipeline._node_graph): + if node.node_type == "function": + return node + raise RuntimeError("No function node found") + + +# --------------------------------------------------------------------------- +# 1. Sync pipeline — succeeding packets → log rows with stdout captured +# --------------------------------------------------------------------------- + + +class TestSyncPipelineSuccessLogs: + def test_success_logs_captured(self): + db = InMemoryArrowDatabase() + source = _make_source() + + def double(x: int) -> int: + print(f"doubling {x}") + return x * 2 + + pf = PythonPacketFunction(double, output_keys="result") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_logs", pipeline_database=db) + with pipeline: + pod(source, label="doubler") + + obs = LoggingObserver(log_database=db) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + fn_node = _get_function_node(pipeline) + logs = obs.get_logs(pipeline_path=fn_node.pipeline_path) + + assert logs is not None + assert logs.num_rows == 3 + success_col = logs.column("success").to_pylist() + assert all(success_col) + + +# --------------------------------------------------------------------------- +# 2. Failing packets → success=False, traceback populated +# --------------------------------------------------------------------------- + + +class TestFailingPacketsLogged: + def test_failure_logged_with_traceback(self): + db = InMemoryArrowDatabase() + source = _make_source(2) + + def failing(x: int) -> int: + raise ValueError("boom") + + pf = PythonPacketFunction(failing, output_keys="result") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_fail", pipeline_database=db) + with pipeline: + pod(source, label="failing") + + obs = LoggingObserver(log_database=db) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + fn_node = _get_function_node(pipeline) + logs = obs.get_logs(pipeline_path=fn_node.pipeline_path) + + assert logs is not None + assert logs.num_rows == 2 + for row_idx in range(logs.num_rows): + assert logs.column("success").to_pylist()[row_idx] is False + tb = logs.column("traceback").to_pylist()[row_idx] + assert tb is not None + assert "ValueError" in tb + assert "boom" in tb + + +# --------------------------------------------------------------------------- +# 3. Pipeline-path-mirrored storage +# --------------------------------------------------------------------------- + + +class TestPipelinePathMirroredStorage: + def test_log_path_mirrors_pipeline_path(self): + db = InMemoryArrowDatabase() + source = _make_source(1) + + def identity(x: int) -> int: + return x + + pf = PythonPacketFunction(identity, output_keys="result") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_mirror", pipeline_database=db) + with pipeline: + pod(source, label="ident") + + obs = LoggingObserver(log_database=db) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + fn_node = _get_function_node(pipeline) + pp = fn_node.pipeline_path + expected_log_path = pp[:1] + ("logs",) + pp[1:] + + # Verify the log path is correct by reading directly from the DB + raw = db.get_all_records(expected_log_path) + assert raw is not None + assert raw.num_rows == 1 + + +# --------------------------------------------------------------------------- +# 4. Queryable tag columns (not JSON) +# --------------------------------------------------------------------------- + + +class TestQueryableTagColumns: + def test_tag_columns_in_log_table(self): + db = InMemoryArrowDatabase() + source = _make_source(2) + + def identity(x: int) -> int: + return x + + pf = PythonPacketFunction(identity, output_keys="result") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_tags", pipeline_database=db) + with pipeline: + pod(source, label="ident") + + obs = LoggingObserver(log_database=db) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + fn_node = _get_function_node(pipeline) + logs = obs.get_logs(pipeline_path=fn_node.pipeline_path) + + assert logs is not None + # "id" tag column should be a separate column, not JSON + assert "id" in logs.column_names + assert "tags" not in logs.column_names + id_values = sorted(logs.column("id").to_pylist()) + assert id_values == ["0", "1"] + + +# --------------------------------------------------------------------------- +# 5. Async orchestrator logs +# --------------------------------------------------------------------------- + + +class TestAsyncOrchestratorLogs: + def test_async_pipeline_captures_logs(self): + db = InMemoryArrowDatabase() + source = _make_source(2) + + def double(x: int) -> int: + return x * 2 + + pf = PythonPacketFunction(double, output_keys="result") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_async_logs", pipeline_database=db) + with pipeline: + pod(source, label="doubler") + + obs = LoggingObserver(log_database=db) + orch = AsyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + fn_node = _get_function_node(pipeline) + logs = obs.get_logs(pipeline_path=fn_node.pipeline_path) + + assert logs is not None + assert logs.num_rows == 2 + assert all(logs.column("success").to_pylist()) + + +# --------------------------------------------------------------------------- +# 6. fail_fast error policy +# --------------------------------------------------------------------------- + + +class TestFailFastErrorPolicy: + def test_fail_fast_aborts_and_logs(self): + db = InMemoryArrowDatabase() + source = _make_source(3) + + def failing(x: int) -> int: + raise RuntimeError("crash") + + pf = PythonPacketFunction(failing, output_keys="result") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_failfast", pipeline_database=db) + with pipeline: + pod(source, label="failing") + + obs = LoggingObserver(log_database=db) + orch = SyncPipelineOrchestrator(observer=obs, error_policy="fail_fast") + + with pytest.raises(RuntimeError, match="crash"): + pipeline.run(orchestrator=orch) + + fn_node = _get_function_node(pipeline) + logs = obs.get_logs(pipeline_path=fn_node.pipeline_path) + + # At least the first crash should be logged before abort + assert logs is not None + assert logs.num_rows >= 1 + assert logs.column("success").to_pylist()[0] is False + + +# --------------------------------------------------------------------------- +# 7. Mixed success/failure — correct log row per packet +# --------------------------------------------------------------------------- + + +class TestMixedSuccessFailure: + def test_mixed_results_logged_correctly(self): + db = InMemoryArrowDatabase() + source = ArrowTableSource( + pa.table({ + "id": pa.array(["0", "1", "2"], type=pa.large_string()), + "x": pa.array([10, -1, 30], type=pa.int64()), + }), + tag_columns=["id"], + ) + + def safe_div(x: int) -> float: + return 100 / x + + pf = PythonPacketFunction(safe_div, output_keys="result") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="test_mixed", pipeline_database=db) + with pipeline: + pod(source, label="divider") + + obs = LoggingObserver(log_database=db) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + fn_node = _get_function_node(pipeline) + logs = obs.get_logs(pipeline_path=fn_node.pipeline_path) + + assert logs is not None + # x=10 succeeds, x=-1 succeeds (100/-1 = -100), x=30 succeeds + # All three should succeed since division by these values is fine + assert logs.num_rows == 3 + assert all(logs.column("success").to_pylist()) + + +# --------------------------------------------------------------------------- +# 8. Multiple function nodes — each gets own log table +# --------------------------------------------------------------------------- + + +class TestMultipleFunctionNodesSeparateLogs: + def test_two_nodes_separate_log_tables(self): + db = InMemoryArrowDatabase() + source = _make_source(2) + + def double(x: int) -> int: + return x * 2 + + def triple(result: int) -> int: + return result * 3 + + pf1 = PythonPacketFunction(double, output_keys="result") + pod1 = FunctionPod(pf1) + pf2 = PythonPacketFunction(triple, output_keys="final") + pod2 = FunctionPod(pf2) + + pipeline = Pipeline(name="test_multi", pipeline_database=db) + with pipeline: + s1 = pod1(source, label="doubler") + pod2(s1, label="tripler") + + obs = LoggingObserver(log_database=db) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + import networkx as nx + + fn_nodes = [ + n + for n in nx.topological_sort(pipeline._node_graph) + if n.node_type == "function" + ] + assert len(fn_nodes) == 2 + + logs1 = obs.get_logs(pipeline_path=fn_nodes[0].pipeline_path) + logs2 = obs.get_logs(pipeline_path=fn_nodes[1].pipeline_path) + + assert logs1 is not None + assert logs2 is not None + assert logs1.num_rows == 2 + assert logs2.num_rows == 2 + + # Verify they are at different paths + assert fn_nodes[0].pipeline_path != fn_nodes[1].pipeline_path + + +# --------------------------------------------------------------------------- +# 9. get_logs(pipeline_path) retrieves node-specific logs +# --------------------------------------------------------------------------- + + +class TestGetLogsNodeSpecific: + def test_get_logs_filters_by_node(self): + db = InMemoryArrowDatabase() + source = _make_source(2) + + def double(x: int) -> int: + return x * 2 + + def triple(result: int) -> int: + return result * 3 + + pf1 = PythonPacketFunction(double, output_keys="result") + pod1 = FunctionPod(pf1) + pf2 = PythonPacketFunction(triple, output_keys="final") + pod2 = FunctionPod(pf2) + + pipeline = Pipeline(name="test_filter", pipeline_database=db) + with pipeline: + s1 = pod1(source, label="doubler") + pod2(s1, label="tripler") + + obs = LoggingObserver(log_database=db) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + import networkx as nx + + fn_nodes = [ + n + for n in nx.topological_sort(pipeline._node_graph) + if n.node_type == "function" + ] + + # Each node's logs contain only that node's label + logs1 = obs.get_logs(pipeline_path=fn_nodes[0].pipeline_path) + logs2 = obs.get_logs(pipeline_path=fn_nodes[1].pipeline_path) + + labels1 = set(logs1.column("node_label").to_pylist()) + labels2 = set(logs2.column("node_label").to_pylist()) + + assert labels1 == {"doubler"} + assert labels2 == {"tripler"} diff --git a/tests/test_pipeline/test_node_protocols.py b/tests/test_pipeline/test_node_protocols.py index e05c743c..65681222 100644 --- a/tests/test_pipeline/test_node_protocols.py +++ b/tests/test_pipeline/test_node_protocols.py @@ -257,6 +257,11 @@ def on_packet_start(self, n, t, p): events.append(("packet_start",)) def on_packet_end(self, n, t, ip, op, cached): events.append(("packet_end", cached)) + def on_packet_crash(self, n, t, p, exc): + pass + def create_packet_logger(self, n, t, p, **kwargs): + from orcapod.pipeline.observer import _NOOP_LOGGER + return _NOOP_LOGGER input_stream = node._input_stream result = node.execute(input_stream, observer=Obs()) @@ -329,6 +334,11 @@ def on_packet_start(self, n, t, p): events.append("pkt_start") def on_packet_end(self, n, t, ip, op, cached): events.append("pkt_end") + def on_packet_crash(self, n, t, p, exc): + pass + def create_packet_logger(self, n, t, p, **kwargs): + from orcapod.pipeline.observer import _NOOP_LOGGER + return _NOOP_LOGGER input_ch = Channel(buffer_size=16) output_ch = Channel(buffer_size=16) diff --git a/tests/test_pipeline/test_observer.py b/tests/test_pipeline/test_observer.py index 91237ca5..91a50745 100644 --- a/tests/test_pipeline/test_observer.py +++ b/tests/test_pipeline/test_observer.py @@ -1,8 +1,12 @@ -"""Tests for ExecutionObserver protocol and NoOpObserver.""" +"""Tests for ExecutionObserverProtocol, NoOpObserver, and NoOpLogger.""" from __future__ import annotations -from orcapod.pipeline.observer import ExecutionObserver, NoOpObserver +from orcapod.pipeline.observer import NoOpLogger, NoOpObserver, _NOOP_LOGGER +from orcapod.protocols.observability_protocols import ( + ExecutionObserverProtocol, + PacketExecutionLoggerProtocol, +) class TestNoOpObserver: @@ -10,20 +14,52 @@ class TestNoOpObserver: def test_satisfies_protocol(self): observer = NoOpObserver() - assert isinstance(observer, ExecutionObserver) + assert isinstance(observer, ExecutionObserverProtocol) + + def test_on_run_start_noop(self): + NoOpObserver().on_run_start("run-123") + + def test_on_run_end_noop(self): + NoOpObserver().on_run_end("run-123") def test_on_node_start_noop(self): - observer = NoOpObserver() - observer.on_node_start(None) # type: ignore[arg-type] + NoOpObserver().on_node_start(None) # type: ignore[arg-type] def test_on_node_end_noop(self): - observer = NoOpObserver() - observer.on_node_end(None) # type: ignore[arg-type] + NoOpObserver().on_node_end(None) # type: ignore[arg-type] def test_on_packet_start_noop(self): - observer = NoOpObserver() - observer.on_packet_start(None, None, None) # type: ignore[arg-type] + NoOpObserver().on_packet_start(None, None, None) # type: ignore[arg-type] def test_on_packet_end_noop(self): - observer = NoOpObserver() - observer.on_packet_end(None, None, None, None, cached=False) # type: ignore[arg-type] + NoOpObserver().on_packet_end(None, None, None, None, cached=False) # type: ignore[arg-type] + + def test_on_packet_crash_noop(self): + NoOpObserver().on_packet_crash(None, None, None, RuntimeError("boom")) # type: ignore[arg-type] + + def test_create_packet_logger_returns_noop(self): + logger = NoOpObserver().create_packet_logger(None, None, None) # type: ignore[arg-type] + assert logger is _NOOP_LOGGER + + def test_create_packet_logger_satisfies_protocol(self): + logger = NoOpObserver().create_packet_logger(None, None, None) # type: ignore[arg-type] + assert isinstance(logger, PacketExecutionLoggerProtocol) + + +class TestNoOpLogger: + """NoOpLogger satisfies the protocol and discards everything.""" + + def test_satisfies_protocol(self): + assert isinstance(NoOpLogger(), PacketExecutionLoggerProtocol) + + def test_record_noop(self): + from orcapod.pipeline.logging_capture import CapturedLogs + + NoOpLogger().record(CapturedLogs(stdout="hello", success=True)) + + def test_noop_logger_singleton_identity(self): + # The singleton returned by create_packet_logger is always the same object. + obs = NoOpObserver() + l1 = obs.create_packet_logger(None, None, None) # type: ignore[arg-type] + l2 = obs.create_packet_logger(None, None, None) # type: ignore[arg-type] + assert l1 is l2 diff --git a/tests/test_pipeline/test_orchestrator.py b/tests/test_pipeline/test_orchestrator.py index 6c6d39f5..40459810 100644 --- a/tests/test_pipeline/test_orchestrator.py +++ b/tests/test_pipeline/test_orchestrator.py @@ -428,9 +428,10 @@ def test_single_terminal_source(self): class TestAsyncOrchestratorErrorPropagation: - """Node failures should propagate correctly.""" + """Failed packets do not abort the pipeline; they are handled per-packet.""" - def test_node_failure_propagates(self): + def test_node_failure_does_not_abort_pipeline(self): + """A crashing packet function is skipped; the pipeline completes normally.""" def failing_fn(value: int) -> int: raise ValueError("intentional failure") @@ -445,8 +446,36 @@ def failing_fn(value: int) -> int: pipeline.compile() orch = AsyncPipelineOrchestrator() - with pytest.raises(ExceptionGroup): - orch.run(pipeline._node_graph) + # Pipeline must complete without raising; failing packet is silently dropped. + orch.run(pipeline._node_graph) + + def test_node_failure_calls_on_packet_crash(self): + """When an observer is set, on_packet_crash is called for the failing packet.""" + from orcapod.pipeline.observer import NoOpObserver + + def failing_fn(value: int) -> int: + raise ValueError("intentional failure") + + src = _make_source("key", "value", {"key": ["a"], "value": [1]}) + pf = PythonPacketFunction(failing_fn, output_keys="result") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="error2", pipeline_database=InMemoryArrowDatabase()) + with pipeline: + pod(src, label="failer") + + crashes = [] + + class CrashRecorder(NoOpObserver): + def on_packet_crash(self, node, tag, packet, error): + crashes.append(error) + + pipeline.compile() + orch = AsyncPipelineOrchestrator(observer=CrashRecorder()) + orch.run(pipeline._node_graph) + + assert len(crashes) == 1 + assert isinstance(crashes[0], (ValueError, RuntimeError)) class TestAsyncOrchestratorObserverInjection: @@ -465,17 +494,20 @@ def test_linear_pipeline_observer_hooks(self): events = [] class RecordingObserver: + def on_run_start(self, run_id): pass + def on_run_end(self, run_id): pass def on_node_start(self, node): events.append(("node_start", node.node_type)) - def on_node_end(self, node): events.append(("node_end", node.node_type)) - def on_packet_start(self, node, tag, packet): events.append(("packet_start", node.node_type)) - def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): events.append(("packet_end", node.node_type, cached)) + def on_packet_crash(self, node, tag, packet, exc): pass + def create_packet_logger(self, node, tag, packet, **kwargs): + from orcapod.pipeline.observer import _NOOP_LOGGER + return _NOOP_LOGGER pipeline.compile() orch = AsyncPipelineOrchestrator(observer=RecordingObserver()) @@ -517,17 +549,20 @@ def double_val(val: int) -> int: events = [] class RecordingObserver: + def on_run_start(self, run_id): pass + def on_run_end(self, run_id): pass def on_node_start(self, node): events.append(("node_start", node.node_type)) - def on_node_end(self, node): events.append(("node_end", node.node_type)) - def on_packet_start(self, node, tag, packet): events.append(("packet_start", node.node_type)) - def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): events.append(("packet_end", node.node_type)) + def on_packet_crash(self, node, tag, packet, exc): pass + def create_packet_logger(self, node, tag, packet, **kwargs): + from orcapod.pipeline.observer import _NOOP_LOGGER + return _NOOP_LOGGER pipeline.compile() orch = AsyncPipelineOrchestrator(observer=RecordingObserver()) diff --git a/tests/test_pipeline/test_pipeline.py b/tests/test_pipeline/test_pipeline.py index 99ffd956..458680d9 100644 --- a/tests/test_pipeline/test_pipeline.py +++ b/tests/test_pipeline/test_pipeline.py @@ -1164,7 +1164,7 @@ def execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> PacketProtocol | None: + ) -> "tuple[PacketProtocol | None, Any]": self.sync_calls.append(packet) return packet_function.direct_call(packet) @@ -1172,17 +1172,21 @@ async def async_execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> PacketProtocol | None: + ) -> "tuple[PacketProtocol | None, Any]": self.async_calls.append(packet) return packet_function.direct_call(packet) def execute_callable(self, fn, kwargs, executor_options=None): + from orcapod.pipeline.logging_capture import CapturedLogs + self.sync_calls.append(kwargs) - return fn(**kwargs) + return fn(**kwargs), CapturedLogs(success=True) async def async_execute_callable(self, fn, kwargs, executor_options=None): + from orcapod.pipeline.logging_capture import CapturedLogs + self.async_calls.append(kwargs) - return fn(**kwargs) + return fn(**kwargs), CapturedLogs(success=True) def with_options(self, **opts: Any) -> "_MockExecutor": return _MockExecutor(opts={**self.opts, **opts}) diff --git a/tests/test_pipeline/test_sync_orchestrator.py b/tests/test_pipeline/test_sync_orchestrator.py index cd5c7974..0bdcba10 100644 --- a/tests/test_pipeline/test_sync_orchestrator.py +++ b/tests/test_pipeline/test_sync_orchestrator.py @@ -130,17 +130,20 @@ def test_observer_hooks_fire(self): events = [] class RecordingObserver: + def on_run_start(self, run_id): pass + def on_run_end(self, run_id): pass def on_node_start(self, node): events.append(("node_start", node.node_type)) - def on_node_end(self, node): events.append(("node_end", node.node_type)) - def on_packet_start(self, node, tag, packet): events.append(("packet_start",)) - def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): events.append(("packet_end", cached)) + def on_packet_crash(self, node, tag, packet, exc): pass + def create_packet_logger(self, node, tag, packet, **kwargs): + from orcapod.pipeline.observer import _NOOP_LOGGER + return _NOOP_LOGGER orch = SyncPipelineOrchestrator(observer=RecordingObserver()) orch.run(pipeline._node_graph) @@ -204,17 +207,20 @@ def test_run_with_explicit_orchestrator(self): events = [] class RecordingObserver: + def on_run_start(self, run_id): pass + def on_run_end(self, run_id): pass def on_node_start(self, node): events.append(("node_start", node.node_type)) - def on_node_end(self, node): events.append(("node_end", node.node_type)) - def on_packet_start(self, node, tag, packet): events.append(("packet_start",)) - def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): events.append(("packet_end",)) + def on_packet_crash(self, node, tag, packet, exc): pass + def create_packet_logger(self, node, tag, packet, **kwargs): + from orcapod.pipeline.observer import _NOOP_LOGGER + return _NOOP_LOGGER orch = SyncPipelineOrchestrator(observer=RecordingObserver()) pipeline.run(orchestrator=orch) @@ -409,17 +415,20 @@ def double_val(val: int) -> int: events = [] class RecordingObserver: + def on_run_start(self, run_id): pass + def on_run_end(self, run_id): pass def on_node_start(self, node): events.append(("node_start", node.node_type)) - def on_node_end(self, node): events.append(("node_end", node.node_type)) - def on_packet_start(self, node, tag, packet): events.append(("packet_start", node.node_type)) - def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): events.append(("packet_end", node.node_type, cached)) + def on_packet_crash(self, node, tag, packet, exc): pass + def create_packet_logger(self, node, tag, packet, **kwargs): + from orcapod.pipeline.observer import _NOOP_LOGGER + return _NOOP_LOGGER orch = SyncPipelineOrchestrator(observer=RecordingObserver()) orch.run(pipeline._node_graph) @@ -455,12 +464,18 @@ def test_function_node_cached_flag(self): events1 = [] class Obs1: + def on_run_start(self, run_id): pass + def on_run_end(self, run_id): pass def on_node_start(self, node): pass def on_node_end(self, node): pass def on_packet_start(self, node, tag, packet): pass def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): if node.node_type == "function": events1.append(cached) + def on_packet_crash(self, node, tag, packet, exc): pass + def create_packet_logger(self, node, tag, packet, **kwargs): + from orcapod.pipeline.observer import _NOOP_LOGGER + return _NOOP_LOGGER SyncPipelineOrchestrator(observer=Obs1()).run(pipeline._node_graph) assert events1 == [False] @@ -469,12 +484,18 @@ def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): events2 = [] class Obs2: + def on_run_start(self, run_id): pass + def on_run_end(self, run_id): pass def on_node_start(self, node): pass def on_node_end(self, node): pass def on_packet_start(self, node, tag, packet): pass def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): if node.node_type == "function": events2.append(cached) + def on_packet_crash(self, node, tag, packet, exc): pass + def create_packet_logger(self, node, tag, packet, **kwargs): + from orcapod.pipeline.observer import _NOOP_LOGGER + return _NOOP_LOGGER SyncPipelineOrchestrator(observer=Obs2()).run(pipeline._node_graph) assert events2 == [True] @@ -494,12 +515,18 @@ def test_diamond_dag_observer_event_order(self): node_order = [] class OrderObserver: + def on_run_start(self, run_id): pass + def on_run_end(self, run_id): pass def on_node_start(self, node): node_order.append(("start", node.node_type)) def on_node_end(self, node): node_order.append(("end", node.node_type)) def on_packet_start(self, node, tag, packet): pass def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): pass + def on_packet_crash(self, node, tag, packet, exc): pass + def create_packet_logger(self, node, tag, packet, **kwargs): + from orcapod.pipeline.observer import _NOOP_LOGGER + return _NOOP_LOGGER SyncPipelineOrchestrator(observer=OrderObserver()).run(pipeline._node_graph) From 7b5dd434456fcc6a5b2f6487b43d68b3f9382b77 Mon Sep 17 00:00:00 2001 From: Brian Arnold Date: Wed, 18 Mar 2026 17:10:05 +0000 Subject: [PATCH 2/2] refactor(logging): replace return-path CapturedLogs with dependency-injected logger --- src/orcapod/core/cached_function_pod.py | 82 ++----- src/orcapod/core/executors/base.py | 51 ++-- src/orcapod/core/executors/local.py | 65 +++-- src/orcapod/core/executors/ray.py | 66 ++++-- src/orcapod/core/function_pod.py | 45 ++-- src/orcapod/core/nodes/function_node.py | 199 ++++++++-------- src/orcapod/core/nodes/operator_node.py | 18 +- src/orcapod/core/nodes/source_node.py | 12 +- src/orcapod/core/packet_function.py | 222 +++++++++--------- src/orcapod/core/packet_function_proxy.py | 27 ++- src/orcapod/pipeline/logging_observer.py | 122 +++++++++- src/orcapod/pipeline/observer.py | 17 +- .../protocols/core_protocols/executor.py | 37 +-- .../core_protocols/packet_function.py | 28 ++- .../protocols/observability_protocols.py | 32 ++- tests/test_channels/test_async_execute.py | 20 +- .../test_copilot_review_issues.py | 4 +- .../test_channels/test_node_async_execute.py | 16 +- .../test_cached_packet_function.py | 24 +- .../packet_function/test_executor.py | 81 +++---- .../packet_function/test_packet_function.py | 42 ++-- tests/test_core/test_regression_fixes.py | 25 +- tests/test_core/test_result_cache.py | 6 +- tests/test_pipeline/test_node_protocols.py | 82 ++++--- tests/test_pipeline/test_observer.py | 22 +- tests/test_pipeline/test_orchestrator.py | 72 +++--- tests/test_pipeline/test_pipeline.py | 20 +- tests/test_pipeline/test_sync_orchestrator.py | 155 ++++++------ 28 files changed, 874 insertions(+), 718 deletions(-) diff --git a/src/orcapod/core/cached_function_pod.py b/src/orcapod/core/cached_function_pod.py index 771009dc..c6d9a00c 100644 --- a/src/orcapod/core/cached_function_pod.py +++ b/src/orcapod/core/cached_function_pod.py @@ -18,8 +18,6 @@ if TYPE_CHECKING: import pyarrow as pa - from orcapod.pipeline.logging_capture import CapturedLogs - logger = logging.getLogger(__name__) @@ -67,7 +65,11 @@ def record_path(self) -> tuple[str, ...]: return self._cache.record_path def process_packet( - self, tag: TagProtocol, packet: PacketProtocol + self, + tag: TagProtocol, + packet: PacketProtocol, + *, + logger: Any = None, ) -> tuple[TagProtocol, PacketProtocol | None]: """Process a packet with pod-level caching. @@ -79,6 +81,7 @@ def process_packet( Args: tag: The tag associated with the packet. packet: The input packet to process. + logger: Optional packet execution logger. Returns: A ``(tag, output_packet)`` tuple; output_packet is ``None`` @@ -86,10 +89,10 @@ def process_packet( """ cached = self._cache.lookup(packet) if cached is not None: - logger.info("Pod-level cache hit") + _logger.info("Pod-level cache hit") return tag, cached - tag, output = self._function_pod.process_packet(tag, packet) + tag, output = self._function_pod.process_packet(tag, packet, logger=logger) if output is not None: pf = self._function_pod.packet_function self._cache.store( @@ -102,7 +105,11 @@ def process_packet( return tag, output async def async_process_packet( - self, tag: TagProtocol, packet: PacketProtocol + self, + tag: TagProtocol, + packet: PacketProtocol, + *, + logger: Any = None, ) -> tuple[TagProtocol, PacketProtocol | None]: """Async counterpart of ``process_packet``. @@ -112,10 +119,12 @@ async def async_process_packet( """ cached = self._cache.lookup(packet) if cached is not None: - logger.info("Pod-level cache hit") + _logger.info("Pod-level cache hit") return tag, cached - tag, output = await self._function_pod.async_process_packet(tag, packet) + tag, output = await self._function_pod.async_process_packet( + tag, packet, logger=logger + ) if output is not None: pf = self._function_pod.packet_function self._cache.store( @@ -127,59 +136,6 @@ async def async_process_packet( output = output.with_meta_columns(**{self.RESULT_COMPUTED_FLAG: True}) return tag, output - def process_packet_with_capture( - self, tag: TagProtocol, packet: PacketProtocol - ) -> "tuple[TagProtocol, PacketProtocol | None, CapturedLogs]": - """Process with pod-level caching, returning CapturedLogs alongside. - - On cache hit, returns empty CapturedLogs (no function was executed). - """ - from orcapod.pipeline.logging_capture import CapturedLogs - - cached = self._cache.lookup(packet) - if cached is not None: - logger.info("Pod-level cache hit") - return tag, cached, CapturedLogs(success=True) - - tag, output, captured = self._function_pod.process_packet_with_capture( - tag, packet - ) - if output is not None and captured.success: - pf = self._function_pod.packet_function - self._cache.store( - packet, - output, - variation_data=pf.get_function_variation_data(), - execution_data=pf.get_execution_data(), - ) - output = output.with_meta_columns(**{self.RESULT_COMPUTED_FLAG: True}) - return tag, output, captured - - async def async_process_packet_with_capture( - self, tag: TagProtocol, packet: PacketProtocol - ) -> "tuple[TagProtocol, PacketProtocol | None, CapturedLogs]": - """Async counterpart of ``process_packet_with_capture``.""" - from orcapod.pipeline.logging_capture import CapturedLogs - - cached = self._cache.lookup(packet) - if cached is not None: - logger.info("Pod-level cache hit") - return tag, cached, CapturedLogs(success=True) - - tag, output, captured = await self._function_pod.async_process_packet_with_capture( - tag, packet - ) - if output is not None and captured.success: - pf = self._function_pod.packet_function - self._cache.store( - packet, - output, - variation_data=pf.get_function_variation_data(), - execution_data=pf.get_execution_data(), - ) - output = output.with_meta_columns(**{self.RESULT_COMPUTED_FLAG: True}) - return tag, output, captured - def get_all_cached_outputs( self, include_system_columns: bool = False ) -> "pa.Table | None": @@ -207,3 +163,7 @@ def process( input_stream=input_stream, label=label, ) + + +# Module-level logger alias to avoid conflict with `logger` kwarg +_logger = logger diff --git a/src/orcapod/core/executors/base.py b/src/orcapod/core/executors/base.py index 1a7d0d6b..513fce4b 100644 --- a/src/orcapod/core/executors/base.py +++ b/src/orcapod/core/executors/base.py @@ -6,8 +6,8 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from orcapod.pipeline.logging_capture import CapturedLogs from orcapod.protocols.core_protocols import PacketFunctionProtocol, PacketProtocol + from orcapod.protocols.observability_protocols import PacketExecutionLoggerProtocol class PacketFunctionExecutorBase(ABC): @@ -47,28 +47,33 @@ def supports(self, packet_function_type_id: str) -> bool: @abstractmethod def execute( self, - packet_function: PacketFunctionProtocol, - packet: PacketProtocol, - ) -> "tuple[PacketProtocol | None, CapturedLogs]": + packet_function: "PacketFunctionProtocol", + packet: "PacketProtocol", + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> "PacketProtocol | None": """Synchronously execute *packet_function* on *packet*. Implementations should call ``packet_function.direct_call(packet)`` to invoke the function's native computation, bypassing executor - routing, and pass through the ``(result, CapturedLogs)`` tuple. + routing. If a logger is provided, captured I/O is recorded to it. + On failure, the original exception is re-raised after recording. """ ... async def async_execute( self, - packet_function: PacketFunctionProtocol, - packet: PacketProtocol, - ) -> "tuple[PacketProtocol | None, CapturedLogs]": + packet_function: "PacketFunctionProtocol", + packet: "PacketProtocol", + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> "PacketProtocol | None": """Asynchronous counterpart of ``execute``. The default implementation delegates to ``execute`` synchronously. Subclasses should override for truly async execution. """ - return self.execute(packet_function, packet) + return self.execute(packet_function, packet, logger=logger) @property def supports_concurrent_execution(self) -> bool: @@ -97,39 +102,41 @@ def execute_callable( fn: Callable[..., Any], kwargs: dict[str, Any], executor_options: dict[str, Any] | None = None, - ) -> "tuple[Any, CapturedLogs]": - """Synchronously execute *fn* with *kwargs*, returning captured I/O. + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> Any: + """Synchronously execute *fn* with *kwargs*, returning the raw result. - Default implementation calls ``fn(**kwargs)`` with no capture and - returns empty :class:`~orcapod.pipeline.logging_capture.CapturedLogs`. + Default implementation calls ``fn(**kwargs)`` with no capture. Exceptions propagate to the caller. Subclasses (e.g. - ``LocalExecutor``, ``RayExecutor``) override to add I/O capture and - exception swallowing. + ``LocalExecutor``, ``RayExecutor``) override to add I/O capture + and logger recording. Args: fn: The Python callable to execute. kwargs: Keyword arguments to pass to *fn*. executor_options: Optional per-call options. + logger: Optional logger to record captured I/O. Returns: - ``(raw_result, CapturedLogs)`` + The raw return value of *fn*. """ - from orcapod.pipeline.logging_capture import CapturedLogs - - return fn(**kwargs), CapturedLogs() + return fn(**kwargs) async def async_execute_callable( self, fn: Callable[..., Any], kwargs: dict[str, Any], executor_options: dict[str, Any] | None = None, - ) -> "tuple[Any, CapturedLogs]": - """Asynchronously execute *fn* with *kwargs*, returning captured I/O. + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> Any: + """Asynchronously execute *fn* with *kwargs*, returning the raw result. Default implementation delegates to ``execute_callable`` synchronously. Subclasses should override for truly async execution. """ - return self.execute_callable(fn, kwargs, executor_options) + return self.execute_callable(fn, kwargs, executor_options, logger=logger) def get_execution_data(self) -> dict[str, Any]: """Return metadata describing the execution environment. diff --git a/src/orcapod/core/executors/local.py b/src/orcapod/core/executors/local.py index ec05a086..ce13be5e 100644 --- a/src/orcapod/core/executors/local.py +++ b/src/orcapod/core/executors/local.py @@ -9,8 +9,8 @@ from orcapod.core.executors.base import PacketFunctionExecutorBase if TYPE_CHECKING: - from orcapod.pipeline.logging_capture import CapturedLogs from orcapod.protocols.core_protocols import PacketFunctionProtocol, PacketProtocol + from orcapod.protocols.observability_protocols import PacketExecutionLoggerProtocol class LocalExecutor(PacketFunctionExecutorBase): @@ -29,16 +29,20 @@ def supported_function_type_ids(self) -> frozenset[str]: def execute( self, - packet_function: PacketFunctionProtocol, - packet: PacketProtocol, - ) -> "tuple[PacketProtocol | None, CapturedLogs]": + packet_function: "PacketFunctionProtocol", + packet: "PacketProtocol", + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> "PacketProtocol | None": return packet_function.direct_call(packet) async def async_execute( self, - packet_function: PacketFunctionProtocol, - packet: PacketProtocol, - ) -> "tuple[PacketProtocol | None, CapturedLogs]": + packet_function: "PacketFunctionProtocol", + packet: "PacketProtocol", + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> "PacketProtocol | None": return await packet_function.direct_async_call(packet) # -- PythonFunctionExecutorProtocol -- @@ -48,13 +52,12 @@ def execute_callable( fn: Callable[..., Any], kwargs: dict[str, Any], executor_options: dict[str, Any] | None = None, - ) -> "tuple[Any, CapturedLogs]": - from orcapod.pipeline.logging_capture import CapturedLogs, LocalCaptureContext + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> Any: + from orcapod.pipeline.logging_capture import LocalCaptureContext ctx = LocalCaptureContext() - raw_result = None - success = True - tb: str | None = None with ctx: try: if inspect.iscoroutinefunction(fn): @@ -62,9 +65,15 @@ def execute_callable( else: raw_result = fn(**kwargs) except Exception: - success = False tb = _traceback_module.format_exc() - return raw_result, ctx.get_captured(success=success, tb=tb) + captured = ctx.get_captured(success=False, tb=tb) + if logger is not None: + logger.record(captured) + raise + captured = ctx.get_captured(success=True) + if logger is not None: + logger.record(captured) + return raw_result @staticmethod def _run_async_sync(fn: Callable[..., Any], kwargs: dict[str, Any]) -> Any: @@ -84,24 +93,36 @@ async def async_execute_callable( fn: Callable[..., Any], kwargs: dict[str, Any], executor_options: dict[str, Any] | None = None, - ) -> "tuple[Any, CapturedLogs]": - from orcapod.pipeline.logging_capture import CapturedLogs, LocalCaptureContext + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> Any: + from orcapod.pipeline.logging_capture import LocalCaptureContext ctx = LocalCaptureContext() - raw_result = None - success = True - tb: str | None = None with ctx: try: if inspect.iscoroutinefunction(fn): raw_result = await fn(**kwargs) else: + import contextvars + import functools + loop = asyncio.get_running_loop() - raw_result = await loop.run_in_executor(None, lambda: fn(**kwargs)) + task_ctx = contextvars.copy_context() + raw_result = await loop.run_in_executor( + None, + functools.partial(task_ctx.run, fn, **kwargs), + ) except Exception: - success = False tb = _traceback_module.format_exc() - return raw_result, ctx.get_captured(success=success, tb=tb) + captured = ctx.get_captured(success=False, tb=tb) + if logger is not None: + logger.record(captured) + raise + captured = ctx.get_captured(success=True) + if logger is not None: + logger.record(captured) + return raw_result def with_options(self, **opts: Any) -> "LocalExecutor": """Return a new ``LocalExecutor``. diff --git a/src/orcapod/core/executors/ray.py b/src/orcapod/core/executors/ray.py index b1236fbc..89dd9755 100644 --- a/src/orcapod/core/executors/ray.py +++ b/src/orcapod/core/executors/ray.py @@ -8,8 +8,8 @@ if TYPE_CHECKING: from orcapod.core.packet_function import PythonPacketFunction - from orcapod.pipeline.logging_capture import CapturedLogs from orcapod.protocols.core_protocols import PacketFunctionProtocol, PacketProtocol + from orcapod.protocols.observability_protocols import PacketExecutionLoggerProtocol class RayExecutor(PacketFunctionExecutorBase): @@ -132,7 +132,7 @@ def _build_remote_opts(self) -> dict[str, Any]: return dict(self._remote_opts) def _as_python_packet_function( - self, packet_function: PacketFunctionProtocol + self, packet_function: "PacketFunctionProtocol" ) -> "PythonPacketFunction": """Return *packet_function* cast to ``PythonPacketFunction``, or raise. @@ -152,35 +152,33 @@ def _as_python_packet_function( def execute( self, - packet_function: PacketFunctionProtocol, - packet: PacketProtocol, - ) -> "tuple[PacketProtocol | None, CapturedLogs]": - from orcapod.pipeline.logging_capture import CapturedLogs - + packet_function: "PacketFunctionProtocol", + packet: "PacketProtocol", + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> "PacketProtocol | None": pf = self._as_python_packet_function(packet_function) if not pf.is_active(): - return None, CapturedLogs(success=True) + return None - raw, captured = self.execute_callable(pf._function, packet.as_dict()) - if not captured.success: - return None, captured - return pf._build_output_packet(raw), captured + raw = self.execute_callable(pf._function, packet.as_dict(), logger=logger) + return pf._build_output_packet(raw) async def async_execute( self, - packet_function: PacketFunctionProtocol, - packet: PacketProtocol, - ) -> "tuple[PacketProtocol | None, CapturedLogs]": - from orcapod.pipeline.logging_capture import CapturedLogs - + packet_function: "PacketFunctionProtocol", + packet: "PacketProtocol", + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> "PacketProtocol | None": pf = self._as_python_packet_function(packet_function) if not pf.is_active(): - return None, CapturedLogs(success=True) + return None - raw, captured = await self.async_execute_callable(pf._function, packet.as_dict()) - if not captured.success: - return None, captured - return pf._build_output_packet(raw), captured + raw = await self.async_execute_callable( + pf._function, packet.as_dict(), logger=logger + ) + return pf._build_output_packet(raw) # -- PythonFunctionExecutorProtocol -- @@ -274,7 +272,9 @@ def execute_callable( fn: Callable[..., Any], kwargs: dict[str, Any], executor_options: dict[str, Any] | None = None, - ) -> "tuple[Any, CapturedLogs]": + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> Any: """Execute *fn* on the Ray cluster with fd-level I/O capture. The capture wrapper is serialized by bytecode (not module reference) so @@ -291,17 +291,25 @@ def execute_callable( remote_fn = ray.remote(**self._build_remote_opts())(wrapper) ref = remote_fn.remote(fn, kwargs) raw, stdout, stderr, python_logs, tb, success = ray.get(ref) - return raw, CapturedLogs( + + captured = CapturedLogs( stdout=stdout, stderr=stderr, python_logs=python_logs, traceback=tb, success=success, ) + if logger is not None: + logger.record(captured) + if not success: + raise RuntimeError(tb or "Ray worker execution failed") + return raw async def async_execute_callable( self, fn: Callable[..., Any], kwargs: dict[str, Any], executor_options: dict[str, Any] | None = None, - ) -> "tuple[Any, CapturedLogs]": + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> Any: """Async counterpart of :meth:`execute_callable`.""" import ray @@ -316,10 +324,16 @@ async def async_execute_callable( raw, stdout, stderr, python_logs, tb, success = await asyncio.wrap_future( ref.future() ) - return raw, CapturedLogs( + + captured = CapturedLogs( stdout=stdout, stderr=stderr, python_logs=python_logs, traceback=tb, success=success, ) + if logger is not None: + logger.record(captured) + if not success: + raise RuntimeError(tb or "Ray worker execution failed") + return raw def with_options(self, **opts: Any) -> "RayExecutor": """Return a new ``RayExecutor`` with the given options merged in. diff --git a/src/orcapod/core/function_pod.py b/src/orcapod/core/function_pod.py index 4de7caf5..f68e83ba 100644 --- a/src/orcapod/core/function_pod.py +++ b/src/orcapod/core/function_pod.py @@ -142,50 +142,38 @@ def _validate_input_schema(self, input_schema: Schema) -> None: ) def process_packet( - self, tag: TagProtocol, packet: PacketProtocol + self, + tag: TagProtocol, + packet: PacketProtocol, + *, + logger: Any = None, ) -> tuple[TagProtocol, PacketProtocol | None]: """Process a single packet using the pod's packet function. Args: tag: The tag associated with the packet. packet: The input packet to process. + logger: Optional :class:`PacketExecutionLoggerProtocol` for + recording captured I/O. Returns: A ``(tag, output_packet)`` tuple; output_packet is ``None`` if - the function filters the packet out. CapturedLogs are discarded - (only relevant for node-level execution with observers). + the function filters the packet out. """ - result, _captured = self.packet_function.call(packet) + result = self.packet_function.call(packet, logger=logger) return tag, result async def async_process_packet( - self, tag: TagProtocol, packet: PacketProtocol + self, + tag: TagProtocol, + packet: PacketProtocol, + *, + logger: Any = None, ) -> tuple[TagProtocol, PacketProtocol | None]: """Async counterpart of ``process_packet``.""" - result, _captured = await self.packet_function.async_call(packet) + result = await self.packet_function.async_call(packet, logger=logger) return tag, result - def process_packet_with_capture( - self, tag: TagProtocol, packet: PacketProtocol - ) -> tuple[TagProtocol, PacketProtocol | None, "CapturedLogs"]: - """Process a single packet and return CapturedLogs alongside the result. - - Used by FunctionNode to get logs without a ContextVar side-channel. - """ - from orcapod.pipeline.logging_capture import CapturedLogs - - result, captured = self.packet_function.call(packet) - return tag, result, captured - - async def async_process_packet_with_capture( - self, tag: TagProtocol, packet: PacketProtocol - ) -> tuple[TagProtocol, PacketProtocol | None, "CapturedLogs"]: - """Async counterpart of ``process_packet_with_capture``.""" - from orcapod.pipeline.logging_capture import CapturedLogs - - result, captured = await self.packet_function.async_call(packet) - return tag, result, captured - def handle_input_streams(self, *streams: StreamProtocol) -> StreamProtocol: """Handle multiple input streams by joining them if necessary. @@ -378,6 +366,9 @@ async def process_one(tag: TagProtocol, packet: PacketProtocol) -> None: tag, result_packet = await self.async_process_packet(tag, packet) if result_packet is not None: await output.send((tag, result_packet)) + except Exception: + # Swallow packet-level errors so remaining packets continue. + logger.debug("Packet processing failed, skipping", exc_info=True) finally: if sem is not None: sem.release() diff --git a/src/orcapod/core/nodes/function_node.py b/src/orcapod/core/nodes/function_node.py index 4e1adc09..56539430 100644 --- a/src/orcapod/core/nodes/function_node.py +++ b/src/orcapod/core/nodes/function_node.py @@ -39,7 +39,6 @@ import polars as pl import pyarrow as pa - from orcapod.pipeline.logging_capture import CapturedLogs from orcapod.protocols.observability_protocols import ExecutionObserverProtocol else: pa = LazyModule("pyarrow") @@ -488,9 +487,9 @@ def execute_packet( packet: The input packet to process. Returns: - A ``(tag, output_packet)`` tuple. CapturedLogs are discarded. + A ``(tag, output_packet)`` tuple. """ - tag_out, result, _captured = self._process_packet_internal(tag, packet) + tag_out, result = self._process_packet_internal(tag, packet) return tag_out, result def execute( @@ -512,10 +511,13 @@ def execute( Materialized list of (tag, output_packet) pairs, excluding None outputs and failed packets. """ - from orcapod.pipeline.observer import _NOOP_LOGGER + node_label = self.label or "unknown" + node_hash = self.pipeline_hash().to_string() if self._pipeline_database is not None else "" - if observer is not None: - observer.on_node_start(self) + ctx_observer = observer.contextualize(node_hash, node_label) if observer is not None else None + + if ctx_observer is not None: + ctx_observer.on_node_start(node_label, node_hash) # Gather entry IDs and check cache upstream_entries = [ @@ -529,48 +531,43 @@ def execute( output: list[tuple[TagProtocol, PacketProtocol]] = [] for tag, packet, entry_id in upstream_entries: - if observer is not None: - observer.on_packet_start(self, tag, packet) - pkt_logger = observer.create_packet_logger( - self, tag, packet, pipeline_path=pp - ) - else: - pkt_logger = _NOOP_LOGGER + if ctx_observer is not None: + ctx_observer.on_packet_start(node_label, tag, packet) if entry_id in cached: tag_out, result = cached[entry_id] - if observer is not None: - observer.on_packet_end(self, tag, packet, result, cached=True) + if ctx_observer is not None: + ctx_observer.on_packet_end( + node_label, tag, packet, result, cached=True + ) output.append((tag_out, result)) else: - tag_out, result, captured = self._process_packet_internal(tag, packet) - pkt_logger.record(captured) - if not captured.success: - if observer is not None: - observer.on_packet_crash( - self, - tag, - packet, - RuntimeError( - captured.traceback or "packet function failed" - ), - ) + pkt_logger = ( + ctx_observer.create_packet_logger(tag, packet, pipeline_path=pp) + if ctx_observer is not None + else None + ) + try: + tag_out, result = self._process_packet_internal( + tag, packet, logger=pkt_logger + ) + except Exception as exc: + if ctx_observer is not None: + ctx_observer.on_packet_crash(node_label, tag, packet, exc) if error_policy == "fail_fast": - if observer is not None: - observer.on_node_end(self) - raise RuntimeError( - captured.traceback or "packet function failed" - ) + if ctx_observer is not None: + ctx_observer.on_node_end(node_label, node_hash) + raise else: - if observer is not None: - observer.on_packet_end( - self, tag, packet, result, cached=False + if ctx_observer is not None: + ctx_observer.on_packet_end( + node_label, tag, packet, result, cached=False ) if result is not None: output.append((tag_out, result)) - if observer is not None: - observer.on_node_end(self) + if ctx_observer is not None: + ctx_observer.on_node_end(node_label, node_hash) return output def _process_packet_internal( @@ -578,27 +575,31 @@ def _process_packet_internal( tag: TagProtocol, packet: PacketProtocol, cache_index: int | None = None, - ) -> "tuple[TagProtocol, PacketProtocol | None, CapturedLogs]": + *, + logger: Any = None, + ) -> tuple[TagProtocol, PacketProtocol | None]: """Core compute + persist + cache. Used by ``execute_packet``, ``execute``, and ``iter_packets``. No input validation is performed — the caller guarantees correctness. + Exceptions propagate to the caller. Returns: - A ``(tag, output_packet, captured_logs)`` 3-tuple. + A ``(tag, output_packet)`` 2-tuple. Args: tag: The input tag. packet: The input packet. cache_index: Optional explicit index for the internal cache. When ``None``, auto-assigns at ``len(_cached_output_packets)``. + logger: Optional packet execution logger. """ if self._cached_function_pod is not None: - tag_out, output_packet, captured = ( - self._cached_function_pod.process_packet_with_capture(tag, packet) + tag_out, output_packet = self._cached_function_pod.process_packet( + tag, packet, logger=logger ) - if output_packet is not None and captured.success: + if output_packet is not None: result_computed = bool( output_packet.get_meta_value( self._cached_function_pod.RESULT_COMPUTED_FLAG, False @@ -611,8 +612,8 @@ def _process_packet_internal( computed=result_computed, ) else: - tag_out, output_packet, captured = ( - self._function_pod.process_packet_with_capture(tag, packet) + tag_out, output_packet = self._function_pod.process_packet( + tag, packet, logger=logger ) # Cache internally and invalidate derived caches @@ -625,7 +626,7 @@ def _process_packet_internal( self._cached_output_table = None self._cached_content_hash_column = None - return tag_out, output_packet, captured + return tag_out, output_packet def get_cached_results( self, entry_ids: list[str] @@ -715,29 +716,32 @@ async def _async_process_packet_internal( tag: TagProtocol, packet: PacketProtocol, cache_index: int | None = None, - ) -> "tuple[TagProtocol, PacketProtocol | None, CapturedLogs]": + *, + logger: Any = None, + ) -> tuple[TagProtocol, PacketProtocol | None]: """Async counterpart of ``_process_packet_internal``. Computes via async path, writes pipeline provenance, and caches - internally — no schema validation. + internally — no schema validation. Exceptions propagate. Returns: - A ``(tag, output_packet, captured_logs)`` 3-tuple. + A ``(tag, output_packet)`` 2-tuple. Args: tag: The input tag. packet: The input packet. cache_index: Optional explicit index for the internal cache. When ``None``, auto-assigns at ``len(_cached_output_packets)``. + logger: Optional packet execution logger. """ if self._cached_function_pod is not None: - tag_out, output_packet, captured = ( - await self._cached_function_pod.async_process_packet_with_capture( - tag, packet + tag_out, output_packet = ( + await self._cached_function_pod.async_process_packet( + tag, packet, logger=logger ) ) - if output_packet is not None and captured.success: + if output_packet is not None: result_computed = bool( output_packet.get_meta_value( self._cached_function_pod.RESULT_COMPUTED_FLAG, False @@ -750,8 +754,10 @@ async def _async_process_packet_internal( computed=result_computed, ) else: - tag_out, output_packet, captured = ( - await self._function_pod.async_process_packet_with_capture(tag, packet) + tag_out, output_packet = ( + await self._function_pod.async_process_packet( + tag, packet, logger=logger + ) ) # Cache internally and invalidate derived caches @@ -764,7 +770,7 @@ async def _async_process_packet_internal( self._cached_output_table = None self._cached_content_hash_column = None - return tag_out, output_packet, captured + return tag_out, output_packet def compute_pipeline_entry_id( self, tag: TagProtocol, input_packet: PacketProtocol @@ -1011,9 +1017,7 @@ def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: entry_id = self.compute_pipeline_entry_id(tag, packet) if entry_id in existing_entry_ids: continue - tag, output_packet, _captured = self._process_packet_internal( - tag, packet - ) + tag, output_packet = self._process_packet_internal(tag, packet) if output_packet is not None: yield tag, output_packet @@ -1050,9 +1054,7 @@ def _iter_packets_sequential( if packet is not None: yield tag, packet else: - tag, output_packet, _captured = self._process_packet_internal( - tag, packet - ) + tag, output_packet = self._process_packet_internal(tag, packet) if output_packet is not None: yield tag, output_packet self._cached_input_iterator = None @@ -1234,9 +1236,14 @@ async def async_execute( """ # TODO(PLT-930): Restore concurrency limiting (semaphore) via node-level config. # Currently all packets are processed sequentially in async_execute. + node_label = self.label or "unknown" + node_hash = self.pipeline_hash().to_string() if self._pipeline_database is not None else "" + + ctx_observer = observer.contextualize(node_hash, node_label) if observer is not None else None + try: - if observer is not None: - observer.on_node_start(self) + if ctx_observer is not None: + ctx_observer.on_node_start(node_label, node_hash) if self._cached_function_pod is not None: # DB-backed async execution: @@ -1292,25 +1299,31 @@ async def async_execute( entry_id = self.compute_pipeline_entry_id(tag, packet) if entry_id in cached_by_entry_id: tag_out, result_packet = cached_by_entry_id[entry_id] - if observer is not None: - observer.on_packet_start(self, tag, packet) - observer.on_packet_end( - self, tag, packet, result_packet, cached=True + if ctx_observer is not None: + ctx_observer.on_packet_start(node_label, tag, packet) + ctx_observer.on_packet_end( + node_label, tag, packet, result_packet, cached=True ) await output.send((tag_out, result_packet)) else: await self._async_execute_one_packet( - tag, packet, output, observer=observer + tag, packet, output, + ctx_observer=ctx_observer, + node_label=node_label, + node_hash=node_hash, ) else: # Simple async execution without DB async for tag, packet in input_channel: await self._async_execute_one_packet( - tag, packet, output, observer=observer + tag, packet, output, + ctx_observer=ctx_observer, + node_label=node_label, + node_hash=node_hash, ) - if observer is not None: - observer.on_node_end(self) + if ctx_observer is not None: + ctx_observer.on_node_end(node_label, node_hash) finally: await output.close() @@ -1320,37 +1333,33 @@ async def _async_execute_one_packet( packet: PacketProtocol, output: WritableChannel[tuple[TagProtocol, PacketProtocol]], *, - observer: "ExecutionObserverProtocol | None" = None, + ctx_observer: "ExecutionObserverProtocol | None" = None, + node_label: str = "unknown", + node_hash: str = "", ) -> None: """Process one non-cached packet in the async execute path.""" - from orcapod.pipeline.observer import _NOOP_LOGGER - pp = self.pipeline_path if self._pipeline_database is not None else () - if observer is not None: - observer.on_packet_start(self, tag, packet) - pkt_logger = observer.create_packet_logger( - self, tag, packet, pipeline_path=pp - ) - else: - pkt_logger = _NOOP_LOGGER + if ctx_observer is not None: + ctx_observer.on_packet_start(node_label, tag, packet) - tag_out, result_packet, captured = await self._async_process_packet_internal( - tag, packet + pkt_logger = ( + ctx_observer.create_packet_logger(tag, packet, pipeline_path=pp) + if ctx_observer is not None + else None ) - pkt_logger.record(captured) - if not captured.success: - if observer is not None: - observer.on_packet_crash( - self, - tag, - packet, - RuntimeError(captured.traceback or "packet function failed"), - ) + + try: + tag_out, result_packet = await self._async_process_packet_internal( + tag, packet, logger=pkt_logger + ) + except Exception as exc: + if ctx_observer is not None: + ctx_observer.on_packet_crash(node_label, tag, packet, exc) else: - if observer is not None: - observer.on_packet_end( - self, tag, packet, result_packet, cached=False + if ctx_observer is not None: + ctx_observer.on_packet_end( + node_label, tag, packet, result_packet, cached=False ) if result_packet is not None: await output.send((tag_out, result_packet)) diff --git a/src/orcapod/core/nodes/operator_node.py b/src/orcapod/core/nodes/operator_node.py index e16da95a..9286f09f 100644 --- a/src/orcapod/core/nodes/operator_node.py +++ b/src/orcapod/core/nodes/operator_node.py @@ -445,15 +445,17 @@ def execute( Returns: Materialized list of (tag, packet) pairs. """ + node_label = self.label or "operator" + node_hash = "" if observer is not None: - observer.on_node_start(self) + observer.on_node_start(node_label, node_hash) # Check REPLAY cache first cached_output = self.get_cached_output() if cached_output is not None: output = list(cached_output.iter_packets()) if observer is not None: - observer.on_node_end(self) + observer.on_node_end(node_label, node_hash) return output # Compute @@ -481,7 +483,7 @@ def execute( self._store_output_stream(self._cached_output_stream) if observer is not None: - observer.on_node_end(self) + observer.on_node_end(node_label, node_hash) return output def _compute_and_store(self) -> None: @@ -658,21 +660,23 @@ async def async_execute( output: Writable channel for output (tag, packet) pairs. observer: Optional execution observer for hooks. """ + node_label = self.label or "operator" + node_hash = "" if self._pipeline_database is None: # Simple delegation without DB if observer is not None: - observer.on_node_start(self) + observer.on_node_start(node_label, node_hash) hashes = [s.pipeline_hash() for s in self._input_streams] await self._operator.async_execute( inputs, output, input_pipeline_hashes=hashes ) if observer is not None: - observer.on_node_end(self) + observer.on_node_end(node_label, node_hash) return try: if observer is not None: - observer.on_node_start(self) + observer.on_node_start(node_label, node_hash) if self._cache_mode == CacheMode.REPLAY: self._replay_from_cache() @@ -712,7 +716,7 @@ async def forward() -> None: self._update_modified_time() if observer is not None: - observer.on_node_end(self) + observer.on_node_end(node_label, node_hash) finally: await output.close() diff --git a/src/orcapod/core/nodes/source_node.py b/src/orcapod/core/nodes/source_node.py index 504f29f7..e0757c06 100644 --- a/src/orcapod/core/nodes/source_node.py +++ b/src/orcapod/core/nodes/source_node.py @@ -253,12 +253,14 @@ def execute( raise RuntimeError( "SourceNode in read-only mode has no stream data available" ) + node_label = self.label or "source" + node_hash = "" if observer is not None: - observer.on_node_start(self) + observer.on_node_start(node_label, node_hash) result = list(self.stream.iter_packets()) self._cached_results = result if observer is not None: - observer.on_node_end(self) + observer.on_node_end(node_label, node_hash) return result def run(self) -> None: @@ -280,12 +282,14 @@ async def async_execute( raise RuntimeError( "SourceNode in read-only mode has no stream data available" ) + node_label = self.label or "source" + node_hash = "" try: if observer is not None: - observer.on_node_start(self) + observer.on_node_start(node_label, node_hash) for tag, packet in self.stream.iter_packets(): await output.send((tag, packet)) if observer is not None: - observer.on_node_end(self) + observer.on_node_end(node_label, node_hash) finally: await output.close() diff --git a/src/orcapod/core/packet_function.py b/src/orcapod/core/packet_function.py index ee0d8f0b..03deccde 100644 --- a/src/orcapod/core/packet_function.py +++ b/src/orcapod/core/packet_function.py @@ -37,12 +37,12 @@ import pyarrow as pa import pyarrow.compute as pc - from orcapod.pipeline.logging_capture import CapturedLogs + from orcapod.protocols.observability_protocols import PacketExecutionLoggerProtocol else: pa = LazyModule("pyarrow") pc = LazyModule("pyarrow.compute") -logger = logging.getLogger(__name__) +logger = _logger = logging.getLogger(__name__) error_handling_options = Literal["raise", "ignore", "warn"] @@ -132,6 +132,8 @@ def __init_subclass__(cls, **kwargs: Any) -> None: cls._resolved_executor_protocol = args[0] return + _SKIP_AUTO_EXECUTOR = object() + def __init__( self, version: str = "v0.0", @@ -139,6 +141,7 @@ def __init__( data_context: str | DataContext | None = None, config: Config | None = None, executor: PacketFunctionExecutorProtocol | None = None, + _skip_auto_executor: bool = False, ): super().__init__(label=label, data_context=data_context, config=config) self._active = True @@ -164,6 +167,12 @@ def __init__( # *after* super().__init__(). if executor is not None: self.executor = executor + elif not _skip_auto_executor: + # Auto-assign LocalExecutor so all execution routes through + # the executor layer (ensuring capture is always available). + from orcapod.core.executors.local import LocalExecutor + + self.executor = LocalExecutor() def computed_label(self) -> str | None: """Return the canonical function name as the label if no explicit label is given.""" @@ -288,50 +297,49 @@ def set_executor(self, executor: PacketFunctionExecutorProtocol | None) -> None: # ==================== Execution ==================== def call( - self, packet: PacketProtocol - ) -> "tuple[PacketProtocol | None, CapturedLogs]": + self, + packet: PacketProtocol, + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> PacketProtocol | None: """Process a single packet, routing through the executor if one is set. Subclasses should override ``direct_call`` instead of this method. - - Returns: - A ``(output_packet, captured_logs)`` tuple. """ if self._executor is not None: - return self._executor.execute(self, packet) + return self._executor.execute(self, packet, logger=logger) return self.direct_call(packet) async def async_call( - self, packet: PacketProtocol - ) -> "tuple[PacketProtocol | None, CapturedLogs]": + self, + packet: PacketProtocol, + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> PacketProtocol | None: """Asynchronously process a single packet, routing through the executor if set. Subclasses should override ``direct_async_call`` instead of this method. - - Returns: - A ``(output_packet, captured_logs)`` tuple. """ if self._executor is not None: - return await self._executor.async_execute(self, packet) + return await self._executor.async_execute(self, packet, logger=logger) return await self.direct_async_call(packet) @abstractmethod def direct_call( self, packet: PacketProtocol - ) -> "tuple[PacketProtocol | None, CapturedLogs]": + ) -> PacketProtocol | None: """Execute the function's native computation on *packet*. This is the method executors invoke. It bypasses executor routing and runs the computation directly. On user-function failure the - exception is caught internally and ``(None, captured_failure)`` - is returned — no re-raise. Subclasses must implement this. + exception is re-raised. Subclasses must implement this. """ ... @abstractmethod async def direct_async_call( self, packet: PacketProtocol - ) -> "tuple[PacketProtocol | None, CapturedLogs]": + ) -> PacketProtocol | None: """Asynchronous counterpart of ``direct_call``.""" ... @@ -527,115 +535,91 @@ def _call_async_function_sync(self, packet: PacketProtocol) -> Any: ) def call( - self, packet: PacketProtocol - ) -> "tuple[PacketProtocol | None, CapturedLogs]": + self, + packet: PacketProtocol, + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> PacketProtocol | None: """Process a single packet, routing through the executor if one is set. When an executor implementing ``PythonFunctionExecutorProtocol`` is set, the raw callable and kwargs are handed to ``execute_callable`` - which returns ``(raw_result, CapturedLogs)``. The output packet is - built from ``raw_result``. + which captures I/O and records to the logger. """ - from orcapod.pipeline.logging_capture import CapturedLogs - if self._executor is not None: if not self._active: - return None, CapturedLogs(success=True) - raw, captured = self._executor.execute_callable( - self._function, packet.as_dict() + return None + raw = self._executor.execute_callable( + self._function, packet.as_dict(), logger=logger ) - if not captured.success: - return None, captured - return self._build_output_packet(raw), captured + return self._build_output_packet(raw) return self.direct_call(packet) async def async_call( - self, packet: PacketProtocol - ) -> "tuple[PacketProtocol | None, CapturedLogs]": + self, + packet: PacketProtocol, + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> PacketProtocol | None: """Async counterpart of ``call``.""" - from orcapod.pipeline.logging_capture import CapturedLogs - if self._executor is not None: if not self._active: - return None, CapturedLogs(success=True) - raw, captured = await self._executor.async_execute_callable( - self._function, packet.as_dict() + return None + raw = await self._executor.async_execute_callable( + self._function, packet.as_dict(), logger=logger ) - if not captured.success: - return None, captured - return self._build_output_packet(raw), captured + return self._build_output_packet(raw) return await self.direct_async_call(packet) def direct_call( self, packet: PacketProtocol - ) -> "tuple[PacketProtocol | None, CapturedLogs]": + ) -> PacketProtocol | None: """Execute the function on *packet* synchronously (no executor path). - Uses :class:`~orcapod.pipeline.logging_capture.LocalCaptureContext` - for I/O capture. On user-function failure the exception is caught - internally and ``(None, captured_failure)`` is returned — no re-raise. + On user-function failure the exception is re-raised. For async functions, the coroutine is driven to completion via ``asyncio.run()`` (or a helper thread when already inside an event loop). """ - import traceback as _tb - - from orcapod.pipeline.logging_capture import CapturedLogs, LocalCaptureContext - if not self._active: - return None, CapturedLogs(success=True) - - ctx = LocalCaptureContext() - raw_result = None - with ctx: - try: - if self._is_async: - raw_result = self._call_async_function_sync(packet) - else: - raw_result = self._function(**packet.as_dict()) - except Exception: - return None, ctx.get_captured(success=False, tb=_tb.format_exc()) - return self._build_output_packet(raw_result), ctx.get_captured(success=True) + return None + + if self._is_async: + raw_result = self._call_async_function_sync(packet) + else: + raw_result = self._function(**packet.as_dict()) + return self._build_output_packet(raw_result) async def direct_async_call( self, packet: PacketProtocol - ) -> "tuple[PacketProtocol | None, CapturedLogs]": + ) -> PacketProtocol | None: """Execute the function on *packet* asynchronously (no executor path). Async functions are ``await``-ed directly. Sync functions are offloaded to a thread pool via ``run_in_executor``. On failure, - ``(None, captured_failure)`` is returned — no re-raise. + the exception is re-raised. """ import asyncio - import traceback as _tb - - from orcapod.pipeline.logging_capture import CapturedLogs, LocalCaptureContext if not self._active: - return None, CapturedLogs(success=True) - - ctx = LocalCaptureContext() - raw_result = None - with ctx: - try: - if self._is_async: - raw_result = await self._function(**packet.as_dict()) - else: - import contextvars - import functools - - loop = asyncio.get_running_loop() - task_ctx = contextvars.copy_context() - raw_result = await loop.run_in_executor( - None, - functools.partial( - task_ctx.run, - self._function, - **packet.as_dict(), - ), - ) - except Exception: - return None, ctx.get_captured(success=False, tb=_tb.format_exc()) - return self._build_output_packet(raw_result), ctx.get_captured(success=True) + return None + + if self._is_async: + raw_result = await self._function(**packet.as_dict()) + else: + import contextvars + import functools + + loop = asyncio.get_running_loop() + task_ctx = contextvars.copy_context() + raw_result = await loop.run_in_executor( + None, + functools.partial( + task_ctx.run, + self._function, + **packet.as_dict(), + ), + ) + return self._build_output_packet(raw_result) def to_config(self) -> dict[str, Any]: """Serialize this packet function to a JSON-compatible config dict. @@ -697,8 +681,10 @@ class PacketFunctionWrapper(PacketFunctionBase[E]): """ def __init__(self, packet_function: PacketFunctionProtocol, **kwargs) -> None: - super().__init__(**kwargs) self._packet_function = packet_function + # Skip auto-executor assignment — wrappers delegate executor + # to the wrapped packet function which already has one. + super().__init__(_skip_auto_executor=True, **kwargs) def computed_label(self) -> str | None: return self._packet_function.label @@ -786,23 +772,29 @@ def executor(self, executor: PacketFunctionExecutorProtocol | None) -> None: # direct_async_call bypass executor routing as their names imply. def call( - self, packet: PacketProtocol - ) -> "tuple[PacketProtocol | None, CapturedLogs]": - return self._packet_function.call(packet) + self, + packet: PacketProtocol, + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> PacketProtocol | None: + return self._packet_function.call(packet, logger=logger) async def async_call( - self, packet: PacketProtocol - ) -> "tuple[PacketProtocol | None, CapturedLogs]": - return await self._packet_function.async_call(packet) + self, + packet: PacketProtocol, + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> PacketProtocol | None: + return await self._packet_function.async_call(packet, logger=logger) def direct_call( self, packet: PacketProtocol - ) -> "tuple[PacketProtocol | None, CapturedLogs]": + ) -> PacketProtocol | None: return self._packet_function.direct_call(packet) async def direct_async_call( self, packet: PacketProtocol - ) -> "tuple[PacketProtocol | None, CapturedLogs]": + ) -> PacketProtocol | None: return await self._packet_function.direct_async_call(packet) @@ -845,19 +837,18 @@ def call( self, packet: PacketProtocol, *, + logger: "PacketExecutionLoggerProtocol | None" = None, skip_cache_lookup: bool = False, skip_cache_insert: bool = False, - ) -> "tuple[PacketProtocol | None, CapturedLogs]": - from orcapod.pipeline.logging_capture import CapturedLogs - + ) -> PacketProtocol | None: output_packet = None if not skip_cache_lookup: - logger.info("Checking for cache...") + _logger.info("Checking for cache...") output_packet = self._cache.lookup(packet) if output_packet is not None: - logger.info(f"Cache hit for {packet}!") - return output_packet, CapturedLogs(success=True) - output_packet, captured = self._packet_function.call(packet) + _logger.info(f"Cache hit for {packet}!") + return output_packet + output_packet = self._packet_function.call(packet, logger=logger) if output_packet is not None: if not skip_cache_insert: self._cache.store( @@ -869,26 +860,25 @@ def call( output_packet = output_packet.with_meta_columns( **{self.RESULT_COMPUTED_FLAG: True} ) - return output_packet, captured + return output_packet async def async_call( self, packet: PacketProtocol, *, + logger: "PacketExecutionLoggerProtocol | None" = None, skip_cache_lookup: bool = False, skip_cache_insert: bool = False, - ) -> "tuple[PacketProtocol | None, CapturedLogs]": + ) -> PacketProtocol | None: """Async counterpart of ``call`` with cache check and recording.""" - from orcapod.pipeline.logging_capture import CapturedLogs - output_packet = None if not skip_cache_lookup: - logger.info("Checking for cache...") + _logger.info("Checking for cache...") output_packet = self._cache.lookup(packet) if output_packet is not None: - logger.info(f"Cache hit for {packet}!") - return output_packet, CapturedLogs(success=True) - output_packet, captured = await self._packet_function.async_call(packet) + _logger.info(f"Cache hit for {packet}!") + return output_packet + output_packet = await self._packet_function.async_call(packet, logger=logger) if output_packet is not None: if not skip_cache_insert: self._cache.store( @@ -900,7 +890,7 @@ async def async_call( output_packet = output_packet.with_meta_columns( **{self.RESULT_COMPUTED_FLAG: True} ) - return output_packet, captured + return output_packet def get_cached_output_for_packet( self, input_packet: PacketProtocol diff --git a/src/orcapod/core/packet_function_proxy.py b/src/orcapod/core/packet_function_proxy.py index 615b94a9..2827ea55 100644 --- a/src/orcapod/core/packet_function_proxy.py +++ b/src/orcapod/core/packet_function_proxy.py @@ -10,13 +10,16 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any from orcapod.core.packet_function import PacketFunctionBase from orcapod.errors import PacketFunctionUnavailableError from orcapod.protocols.core_protocols import PacketFunctionProtocol from orcapod.types import ContentHash, Schema +if TYPE_CHECKING: + from orcapod.protocols.observability_protocols import PacketExecutionLoggerProtocol + class PacketFunctionProxy(PacketFunctionBase): """Stand-in for an unavailable packet function. @@ -55,8 +58,10 @@ def __init__( # Call super().__init__ so that major_version and # output_packet_schema_hash are available for URI fallback. + # Skip auto-executor: the proxy delegates executor to the bound + # function (when one is bound). version = inner["version"] - super().__init__(version=version) + super().__init__(version=version, _skip_auto_executor=True) # URI: read from config if present, otherwise compute from metadata. uri_list = config.get("uri") @@ -129,16 +134,26 @@ def _raise_unavailable(self) -> None: f"Use bind() to attach a real function, or access cached results only." ) - def call(self, packet: "PacketProtocol") -> "PacketProtocol | None": + def call( + self, + packet: "PacketProtocol", + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> "PacketProtocol | None": """Process a single packet; delegates to bound function or raises.""" if self._bound_function is not None: - return self._bound_function.call(packet) + return self._bound_function.call(packet, logger=logger) self._raise_unavailable() - async def async_call(self, packet: "PacketProtocol") -> "PacketProtocol | None": + async def async_call( + self, + packet: "PacketProtocol", + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> "PacketProtocol | None": """Async counterpart of ``call``.""" if self._bound_function is not None: - return await self._bound_function.async_call(packet) + return await self._bound_function.async_call(packet, logger=logger) self._raise_unavailable() def direct_call(self, packet: "PacketProtocol") -> "PacketProtocol | None": diff --git a/src/orcapod/pipeline/logging_observer.py b/src/orcapod/pipeline/logging_observer.py index 6db61312..af4054dd 100644 --- a/src/orcapod/pipeline/logging_observer.py +++ b/src/orcapod/pipeline/logging_observer.py @@ -35,6 +35,9 @@ * - ``node_label`` - ``large_utf8`` - Label of the function node + * - ``node_hash`` + - ``large_utf8`` + - Pipeline hash of the function node * - ``stdout`` - ``large_utf8`` - Captured standard output @@ -91,8 +94,8 @@ class PacketLogger: """Context-bound logger created by :class:`LoggingObserver` per packet. Holds all context needed to write a structured log row - (run_id, node_label, tag data) so the caller only needs to pass the - :class:`~orcapod.pipeline.logging_capture.CapturedLogs` payload. + (run_id, node_label, node_hash, tag data) so the caller only needs to + pass the :class:`~orcapod.pipeline.logging_capture.CapturedLogs` payload. Tag data is stored as individual queryable columns (not JSON) alongside the fixed log columns. @@ -107,12 +110,14 @@ def __init__( log_path: tuple[str, ...], run_id: str, node_label: str, + node_hash: str, tag_data: dict[str, Any], ) -> None: self._db = db self._log_path = log_path self._run_id = run_id self._node_label = node_label + self._node_hash = node_hash self._tag_data = tag_data def record(self, captured: CapturedLogs) -> None: @@ -127,6 +132,7 @@ def record(self, captured: CapturedLogs) -> None: "log_id": pa.array([log_id], type=pa.large_utf8()), "run_id": pa.array([self._run_id], type=pa.large_utf8()), "node_label": pa.array([self._node_label], type=pa.large_utf8()), + "node_hash": pa.array([self._node_hash], type=pa.large_utf8()), "stdout": pa.array([captured.stdout], type=pa.large_utf8()), "stderr": pa.array([captured.stderr], type=pa.large_utf8()), "python_logs": pa.array([captured.python_logs], type=pa.large_utf8()), @@ -149,6 +155,86 @@ def record(self, captured: CapturedLogs) -> None: ) +class _ContextualizedLoggingObserver: + """Lightweight wrapper holding parent observer + node identity context. + + Created by :meth:`LoggingObserver.contextualize`. All lifecycle hooks + and logger creation use the stamped ``node_hash`` and ``node_label``. + """ + + def __init__( + self, + parent: "LoggingObserver", + node_hash: str, + node_label: str, + ) -> None: + self._parent = parent + self._node_hash = node_hash + self._node_label = node_label + + def contextualize( + self, node_hash: str, node_label: str + ) -> "_ContextualizedLoggingObserver": + """Re-contextualize (returns a new wrapper with updated identity).""" + return _ContextualizedLoggingObserver(self._parent, node_hash, node_label) + + def on_run_start(self, run_id: str) -> None: + self._parent.on_run_start(run_id) + + def on_run_end(self, run_id: str) -> None: + self._parent.on_run_end(run_id) + + def on_node_start(self, node_label: str, node_hash: str) -> None: + self._parent.on_node_start(node_label, node_hash) + + def on_node_end(self, node_label: str, node_hash: str) -> None: + self._parent.on_node_end(node_label, node_hash) + + def on_packet_start( + self, node_label: str, tag: Any, packet: Any + ) -> None: + self._parent.on_packet_start(node_label, tag, packet) + + def on_packet_end( + self, + node_label: str, + tag: Any, + input_packet: Any, + output_packet: Any, + cached: bool, + ) -> None: + self._parent.on_packet_end(node_label, tag, input_packet, output_packet, cached) + + def on_packet_crash( + self, node_label: str, tag: Any, packet: Any, error: Exception + ) -> None: + self._parent.on_packet_crash(node_label, tag, packet, error) + + def create_packet_logger( + self, + tag: Any, + packet: Any, + pipeline_path: tuple[str, ...] = (), + ) -> PacketLogger: + """Create a logger using context from this wrapper.""" + tag_data = dict(tag) + + # Compute mirrored log path + if pipeline_path: + log_path = pipeline_path[:1] + ("logs",) + pipeline_path[1:] + else: + log_path = self._parent._log_path + + return PacketLogger( + db=self._parent._db, + log_path=log_path, + run_id=self._parent._current_run_id, + node_label=self._node_label, + node_hash=self._node_hash, + tag_data=tag_data, + ) + + class LoggingObserver: """Concrete observer that writes packet execution logs to a database. @@ -191,6 +277,14 @@ def __init__( # Activate tee-capture as soon as the observer is created. install_capture_streams() + # -- contextualize -- + + def contextualize( + self, node_hash: str, node_label: str + ) -> _ContextualizedLoggingObserver: + """Return a contextualized wrapper stamped with node identity.""" + return _ContextualizedLoggingObserver(self, node_hash, node_label) + # -- lifecycle hooks -- def on_run_start(self, run_id: str) -> None: @@ -199,18 +293,18 @@ def on_run_start(self, run_id: str) -> None: def on_run_end(self, run_id: str) -> None: pass - def on_node_start(self, node: Any) -> None: + def on_node_start(self, node_label: str, node_hash: str) -> None: pass - def on_node_end(self, node: Any) -> None: + def on_node_end(self, node_label: str, node_hash: str) -> None: pass - def on_packet_start(self, node: Any, tag: Any, packet: Any) -> None: + def on_packet_start(self, node_label: str, tag: Any, packet: Any) -> None: pass def on_packet_end( self, - node: Any, + node_label: str, tag: Any, input_packet: Any, output_packet: Any, @@ -218,23 +312,28 @@ def on_packet_end( ) -> None: pass - def on_packet_crash(self, node: Any, tag: Any, packet: Any, error: Exception) -> None: + def on_packet_crash( + self, node_label: str, tag: Any, packet: Any, error: Exception + ) -> None: pass def create_packet_logger( self, - node: Any, tag: Any, packet: Any, pipeline_path: tuple[str, ...] = (), ) -> PacketLogger: - """Return a :class:`PacketLogger` bound to *node* + *tag* context. + """Return a :class:`PacketLogger` bound to *tag* context. Log rows are stored at a pipeline-path-mirrored location: ``pipeline_path[:1] + ("logs",) + pipeline_path[1:]``. This gives each function node its own log table in the database. + + Note: + When called directly on ``LoggingObserver`` (not a contextualized + wrapper), node_label and node_hash default to "unknown". Prefer + calling via a contextualized observer. """ - node_label = getattr(node, "label", None) or getattr(node, "node_type", "unknown") tag_data = dict(tag) # Compute mirrored log path @@ -247,7 +346,8 @@ def create_packet_logger( db=self._db, log_path=log_path, run_id=self._current_run_id, - node_label=node_label, + node_label="unknown", + node_hash="unknown", tag_data=tag_data, ) diff --git a/src/orcapod/pipeline/observer.py b/src/orcapod/pipeline/observer.py index e9df178c..c19ba366 100644 --- a/src/orcapod/pipeline/observer.py +++ b/src/orcapod/pipeline/observer.py @@ -14,7 +14,6 @@ ) if TYPE_CHECKING: - from orcapod.core.nodes import GraphNode from orcapod.pipeline.logging_capture import CapturedLogs from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol @@ -51,21 +50,26 @@ class NoOpObserver: ``create_packet_logger`` returns the shared :data:`_NOOP_LOGGER` singleton. """ + def contextualize( + self, node_hash: str, node_label: str + ) -> "NoOpObserver": + return self + def on_run_start(self, run_id: str) -> None: pass def on_run_end(self, run_id: str) -> None: pass - def on_node_start(self, node: "GraphNode") -> None: + def on_node_start(self, node_label: str, node_hash: str) -> None: pass - def on_node_end(self, node: "GraphNode") -> None: + def on_node_end(self, node_label: str, node_hash: str) -> None: pass def on_packet_start( self, - node: "GraphNode", + node_label: str, tag: "TagProtocol", packet: "PacketProtocol", ) -> None: @@ -73,7 +77,7 @@ def on_packet_start( def on_packet_end( self, - node: "GraphNode", + node_label: str, tag: "TagProtocol", input_packet: "PacketProtocol", output_packet: "PacketProtocol | None", @@ -83,7 +87,7 @@ def on_packet_end( def on_packet_crash( self, - node: "GraphNode", + node_label: str, tag: "TagProtocol", packet: "PacketProtocol", error: Exception, @@ -92,7 +96,6 @@ def on_packet_crash( def create_packet_logger( self, - node: "GraphNode", tag: "TagProtocol", packet: "PacketProtocol", pipeline_path: tuple[str, ...] = (), diff --git a/src/orcapod/protocols/core_protocols/executor.py b/src/orcapod/protocols/core_protocols/executor.py index b56d4734..8290cc26 100644 --- a/src/orcapod/protocols/core_protocols/executor.py +++ b/src/orcapod/protocols/core_protocols/executor.py @@ -6,8 +6,8 @@ from orcapod.protocols.core_protocols.datagrams import PacketProtocol if TYPE_CHECKING: - from orcapod.pipeline.logging_capture import CapturedLogs from orcapod.protocols.core_protocols.packet_function import PacketFunctionProtocol + from orcapod.protocols.observability_protocols import PacketExecutionLoggerProtocol @runtime_checkable @@ -37,22 +37,26 @@ def supports(self, packet_function_type_id: str) -> bool: def execute( self, - packet_function: PacketFunctionProtocol, + packet_function: "PacketFunctionProtocol", packet: PacketProtocol, - ) -> "tuple[PacketProtocol | None, CapturedLogs]": + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> PacketProtocol | None: """Synchronously execute *packet_function* on *packet*. The executor should invoke ``packet_function.direct_call(packet)`` - in the appropriate execution environment and pass through its - ``(result, CapturedLogs)`` tuple. + in the appropriate execution environment and return the result. + If a logger is provided, the executor records captured I/O to it. """ ... async def async_execute( self, - packet_function: PacketFunctionProtocol, + packet_function: "PacketFunctionProtocol", packet: PacketProtocol, - ) -> "tuple[PacketProtocol | None, CapturedLogs]": + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> PacketProtocol | None: """Asynchronous counterpart of ``execute``.""" ... @@ -100,7 +104,9 @@ def execute_callable( fn: Callable[..., Any], kwargs: dict[str, Any], executor_options: dict[str, Any] | None = None, - ) -> "tuple[Any, CapturedLogs]": + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> Any: """Synchronously execute *fn* with *kwargs*, capturing I/O. Args: @@ -108,12 +114,12 @@ def execute_callable( kwargs: Keyword arguments to pass to *fn*. executor_options: Optional per-call options (e.g. resource overrides). + logger: Optional logger to record captured I/O. Returns: - A ``(raw_result, CapturedLogs)`` tuple. ``raw_result`` is the - return value of *fn* (or ``None`` on failure). - ``CapturedLogs.success`` is ``False`` when the function raised; - the traceback is stored in ``CapturedLogs.traceback``. + The raw return value of *fn* (or ``None`` on failure). + On failure, the executor re-raises the original exception + after recording captured logs. """ ... @@ -122,15 +128,18 @@ async def async_execute_callable( fn: Callable[..., Any], kwargs: dict[str, Any], executor_options: dict[str, Any] | None = None, - ) -> "tuple[Any, CapturedLogs]": + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> Any: """Asynchronously execute *fn* with *kwargs*, capturing I/O. Args: fn: The Python callable to execute. kwargs: Keyword arguments to pass to *fn*. executor_options: Optional per-call options. + logger: Optional logger to record captured I/O. Returns: - A ``(raw_result, CapturedLogs)`` tuple. + The raw return value of *fn*. """ ... diff --git a/src/orcapod/protocols/core_protocols/packet_function.py b/src/orcapod/protocols/core_protocols/packet_function.py index c806597a..7792e873 100644 --- a/src/orcapod/protocols/core_protocols/packet_function.py +++ b/src/orcapod/protocols/core_protocols/packet_function.py @@ -12,7 +12,7 @@ from orcapod.types import Schema if TYPE_CHECKING: - from orcapod.pipeline.logging_capture import CapturedLogs + from orcapod.protocols.observability_protocols import PacketExecutionLoggerProtocol @runtime_checkable @@ -81,55 +81,59 @@ def executor(self, executor: PacketFunctionExecutorProtocol | None) -> None: def call( self, packet: PacketProtocol, - ) -> "tuple[PacketProtocol | None, CapturedLogs]": + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> PacketProtocol | None: """Process a single packet, routing through the executor if one is set. Args: packet: The data payload to process. + logger: Optional logger for recording captured I/O. Returns: - A ``(output_packet, captured_logs)`` tuple. ``output_packet`` - is ``None`` when the function filters the packet out or when - the execution failed (check ``captured_logs.success``). + The output packet, or ``None`` when the function filters the + packet out or when execution failed (exception is re-raised). """ ... async def async_call( self, packet: PacketProtocol, - ) -> "tuple[PacketProtocol | None, CapturedLogs]": + *, + logger: "PacketExecutionLoggerProtocol | None" = None, + ) -> PacketProtocol | None: """Asynchronously process a single packet, routing through the executor if set. Args: packet: The data payload to process. + logger: Optional logger for recording captured I/O. Returns: - A ``(output_packet, captured_logs)`` tuple. + The output packet, or ``None``. """ ... def direct_call( self, packet: PacketProtocol, - ) -> "tuple[PacketProtocol | None, CapturedLogs]": + ) -> PacketProtocol | None: """Execute the function's native computation on *packet*. This is the method executors invoke, bypassing executor routing. - On user-function failure the exception is caught internally and - ``(None, captured_with_success=False)`` is returned — no re-raise. + On user-function failure the exception is re-raised. Args: packet: The data payload to process. Returns: - A ``(output_packet, captured_logs)`` tuple. + The output packet, or ``None`` if filtered. """ ... async def direct_async_call( self, packet: PacketProtocol, - ) -> "tuple[PacketProtocol | None, CapturedLogs]": + ) -> PacketProtocol | None: """Asynchronous counterpart of ``direct_call``.""" ... diff --git a/src/orcapod/protocols/observability_protocols.py b/src/orcapod/protocols/observability_protocols.py index dad3d0fd..3b4a740f 100644 --- a/src/orcapod/protocols/observability_protocols.py +++ b/src/orcapod/protocols/observability_protocols.py @@ -16,7 +16,6 @@ from typing import TYPE_CHECKING, Protocol, runtime_checkable if TYPE_CHECKING: - from orcapod.core.nodes import GraphNode from orcapod.pipeline.logging_capture import CapturedLogs from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol @@ -51,8 +50,27 @@ class ExecutionObserverProtocol(Protocol): ``on_packet_start`` / ``on_packet_end`` / ``on_packet_crash`` are invoked only for function nodes. ``on_node_start`` / ``on_node_end`` are invoked for all node types. + + Observers are *contextualized* per node via :meth:`contextualize`, which + returns a lightweight wrapper stamped with node identity. The contextualized + observer is used for all hooks and logger creation within that node. """ + def contextualize( + self, node_hash: str, node_label: str + ) -> "ExecutionObserverProtocol": + """Return a contextualized copy stamped with node identity. + + Args: + node_hash: The pipeline hash of the node (stable identity). + node_label: Human-readable label of the node. + + Returns: + An observer (possibly a lightweight wrapper) that carries + node_hash and node_label context for all subsequent calls. + """ + ... + def on_run_start(self, run_id: str) -> None: """Called at the very start of an orchestrator ``run()`` call. @@ -70,17 +88,17 @@ def on_run_end(self, run_id: str) -> None: """ ... - def on_node_start(self, node: "GraphNode") -> None: + def on_node_start(self, node_label: str, node_hash: str) -> None: """Called before a node begins processing its packets.""" ... - def on_node_end(self, node: "GraphNode") -> None: + def on_node_end(self, node_label: str, node_hash: str) -> None: """Called after a node finishes processing all packets.""" ... def on_packet_start( self, - node: "GraphNode", + node_label: str, tag: "TagProtocol", packet: "PacketProtocol", ) -> None: @@ -89,7 +107,7 @@ def on_packet_start( def on_packet_end( self, - node: "GraphNode", + node_label: str, tag: "TagProtocol", input_packet: "PacketProtocol", output_packet: "PacketProtocol | None", @@ -105,7 +123,7 @@ def on_packet_end( def on_packet_crash( self, - node: "GraphNode", + node_label: str, tag: "TagProtocol", packet: "PacketProtocol", error: Exception, @@ -120,7 +138,6 @@ def on_packet_crash( def create_packet_logger( self, - node: "GraphNode", tag: "TagProtocol", packet: "PacketProtocol", pipeline_path: tuple[str, ...] = (), @@ -132,7 +149,6 @@ def create_packet_logger( without the executor needing to know anything about the pipeline. Args: - node: The graph node being executed. tag: The tag for the packet being processed. packet: The input packet being processed. pipeline_path: The node's pipeline path for log storage scoping. diff --git a/tests/test_channels/test_async_execute.py b/tests/test_channels/test_async_execute.py index c1c73036..a92a76ed 100644 --- a/tests/test_channels/test_async_execute.py +++ b/tests/test_channels/test_async_execute.py @@ -155,8 +155,7 @@ def add(x: int, y: int) -> int: pf = PythonPacketFunction(add, output_keys="result") packet = Packet({"x": 3, "y": 5}) - result, captured = await pf.direct_async_call(packet) - assert captured.success is True + result = await pf.direct_async_call(packet) assert result is not None assert result.as_dict()["result"] == 8 @@ -166,12 +165,11 @@ def double(x: int) -> int: return x * 2 pf = PythonPacketFunction(double, output_keys="result") - raw_results = await asyncio.gather( + results = await asyncio.gather( pf.async_call(Packet({"x": 1})), pf.async_call(Packet({"x": 2})), pf.async_call(Packet({"x": 3})), ) - results = [r for r, _captured in raw_results] assert all(r is not None for r in results) values = [r.as_dict()["result"] for r in results if r is not None] assert values == [2, 4, 6] @@ -188,8 +186,7 @@ def record_thread(x: int) -> int: return x pf = PythonPacketFunction(record_thread, output_keys="result") - result, captured = await pf.direct_async_call(Packet({"x": 42})) - assert captured.success is True + result = await pf.direct_async_call(Packet({"x": 42})) assert result is not None assert len(call_threads) == 1 @@ -732,7 +729,7 @@ async def push(stream, ch): class TestErrorPropagation: @pytest.mark.asyncio async def test_function_exception_returns_none(self): - """An exception in the packet function returns (None, captured) — no raise.""" + """An exception in the packet function is caught by process_packet — no raise.""" def failing(x: int) -> int: if x == 2: @@ -756,16 +753,15 @@ def failing(x: int) -> int: assert values == [0, 1, 3, 4] @pytest.mark.asyncio - async def test_direct_async_call_captures_failure(self): - """direct_async_call returns (None, captured) with success=False on error.""" + async def test_direct_async_call_raises_on_failure(self): + """direct_async_call re-raises the exception on error.""" def failing(x: int) -> int: raise ValueError("boom") pf = PythonPacketFunction(failing, output_keys="result") - result, captured = await pf.direct_async_call(Packet({"x": 1})) - assert result is None - assert captured.success is False + with pytest.raises(ValueError, match="boom"): + await pf.direct_async_call(Packet({"x": 1})) # --------------------------------------------------------------------------- diff --git a/tests/test_channels/test_copilot_review_issues.py b/tests/test_channels/test_copilot_review_issues.py index f9891945..4a1bd915 100644 --- a/tests/test_channels/test_copilot_review_issues.py +++ b/tests/test_channels/test_copilot_review_issues.py @@ -142,7 +142,7 @@ async def run_in_loop(): nonlocal creation_count with patch.object(ThreadPoolExecutor, "__init__", counting_init): for _ in range(3): - _result, _captured = pf.direct_call(Packet({"x": 1, "y": 2})) + _result = pf.direct_call(Packet({"x": 1, "y": 2})) asyncio.run(run_in_loop()) # Current code creates a new executor per call, so creation_count == 3. @@ -213,7 +213,7 @@ def instrumented_call(packet): # Run inside an event loop to trigger the ThreadPoolExecutor path async def run_in_loop(): - _result, _captured = pf.direct_call(Packet({"x": 5})) + _result = pf.direct_call(Packet({"x": 5})) asyncio.run(run_in_loop()) diff --git a/tests/test_channels/test_node_async_execute.py b/tests/test_channels/test_node_async_execute.py index 7d29c1a7..1173e6fa 100644 --- a/tests/test_channels/test_node_async_execute.py +++ b/tests/test_channels/test_node_async_execute.py @@ -91,7 +91,7 @@ def double(x: int) -> int: cpf = CachedPacketFunction(pf, result_database=db) packet = Packet({"x": 5}) - result, _captured = await cpf.async_call(packet) + result = await cpf.async_call(packet) assert result is not None assert result.as_dict()["result"] == 10 @@ -111,13 +111,13 @@ def double(x: int) -> int: packet = Packet({"x": 5}) # First call — computes - result1, _captured1 = await cpf.async_call(packet) + result1 = await cpf.async_call(packet) assert result1 is not None # Has RESULT_COMPUTED_FLAG assert result1.get_meta_value(cpf.RESULT_COMPUTED_FLAG, False) is True # Second call — should hit cache (no RESULT_COMPUTED_FLAG set to True) - result2, _captured2 = await cpf.async_call(packet) + result2 = await cpf.async_call(packet) assert result2 is not None assert result2.as_dict()["result"] == 10 # Cache hit should NOT have RESULT_COMPUTED_FLAG=True @@ -138,11 +138,11 @@ def counting_double(x: int) -> int: cpf = CachedPacketFunction(pf, result_database=db) packet = Packet({"x": 5}) - _result1, _captured1 = await cpf.async_call(packet) + _result1 = await cpf.async_call(packet) assert call_count == 1 # With skip_cache_lookup, should recompute - _result2, _captured2 = await cpf.async_call(packet, skip_cache_lookup=True) + _result2 = await cpf.async_call(packet, skip_cache_lookup=True) assert call_count == 2 @pytest.mark.asyncio @@ -155,7 +155,7 @@ def double(x: int) -> int: cpf = CachedPacketFunction(pf, result_database=db) packet = Packet({"x": 5}) - result, _captured = await cpf.async_call(packet, skip_cache_insert=True) + result = await cpf.async_call(packet, skip_cache_insert=True) assert result is not None assert result.as_dict()["result"] == 10 @@ -613,9 +613,9 @@ async def test_function_node_async_uses_async_process_packet_internal(self): original = node._async_process_packet_internal - async def patched(tag, packet): + async def patched(tag, packet, **kwargs): call_log.append("_async_process_packet_internal") - return await original(tag, packet) + return await original(tag, packet, **kwargs) node._async_process_packet_internal = patched diff --git a/tests/test_core/packet_function/test_cached_packet_function.py b/tests/test_core/packet_function/test_cached_packet_function.py index 5547cbfa..9a4218a1 100644 --- a/tests/test_core/packet_function/test_cached_packet_function.py +++ b/tests/test_core/packet_function/test_cached_packet_function.py @@ -139,11 +139,11 @@ def test_returns_none_when_no_records(self, cached_pf): class TestCallCacheMiss: def test_returns_non_none_result(self, cached_pf, input_packet): - result, _captured = cached_pf.call(input_packet) + result = cached_pf.call(input_packet) assert result is not None def test_result_has_correct_value(self, cached_pf, input_packet): - result, _captured = cached_pf.call(input_packet) + result = cached_pf.call(input_packet) assert result["result"] == 7 # 3 + 4 def test_result_stored_in_database(self, cached_pf, input_packet, db): @@ -167,7 +167,7 @@ def test_get_all_cached_outputs_non_empty_after_call(self, cached_pf, input_pack class TestCallCacheHit: def test_second_call_returns_result(self, cached_pf, input_packet): cached_pf.call(input_packet) - result, _captured = cached_pf.call(input_packet) + result = cached_pf.call(input_packet) assert result is not None assert result["result"] == 7 @@ -204,7 +204,7 @@ def counting_add(x: int, y: int) -> int: class TestSkipCacheLookup: def test_skip_cache_lookup_still_returns_result(self, cached_pf, input_packet): cached_pf.call(input_packet) # populate cache - result, _captured = cached_pf.call(input_packet, skip_cache_lookup=True) + result = cached_pf.call(input_packet, skip_cache_lookup=True) assert result is not None assert result["result"] == 7 @@ -226,7 +226,7 @@ def test_skip_cache_lookup_adds_another_record(self, cached_pf, input_packet, db class TestSkipCacheInsert: def test_skip_cache_insert_returns_result(self, cached_pf, input_packet): - result, _captured = cached_pf.call(input_packet, skip_cache_insert=True) + result = cached_pf.call(input_packet, skip_cache_insert=True) assert result is not None assert result["result"] == 7 @@ -398,12 +398,12 @@ def test_get_execution_data_delegates(self, wrapper, inner_pf): assert wrapper.get_execution_data() == inner_pf.get_execution_data() def test_call_delegates(self, wrapper, input_packet): - result, _captured = wrapper.call(input_packet) + result = wrapper.call(input_packet) assert result is not None assert result["result"] == 7 # 3 + 4 def test_async_call_delegates_through_wrapper(self, wrapper, input_packet): - result, _captured = asyncio.run(wrapper.async_call(input_packet)) + result = asyncio.run(wrapper.async_call(input_packet)) assert result is not None assert result["result"] == 7 # 3 + 4 @@ -493,7 +493,7 @@ def test_inactive_inner_returns_none_and_does_not_store( ): inner_pf.set_active(False) cpf = CachedPacketFunction(inner_pf, result_database=db) - result, _captured = cpf.call(input_packet) + result = cpf.call(input_packet) assert result is None assert db.get_all_records(cpf.record_path) is None @@ -535,27 +535,27 @@ class TestResultComputedFlag: """Verify the meta flag that distinguishes fresh computation from cache hits.""" def test_cache_miss_sets_computed_true(self, cached_pf, input_packet): - result, _captured = cached_pf.call(input_packet) + result = cached_pf.call(input_packet) assert result is not None flag = result.get_meta_value(CachedPacketFunction.RESULT_COMPUTED_FLAG) assert flag is True def test_cache_hit_sets_computed_false(self, cached_pf, input_packet): cached_pf.call(input_packet) # first call — populates cache - result, _captured = cached_pf.call(input_packet) # second call — cache hit + result = cached_pf.call(input_packet) # second call — cache hit assert result is not None flag = result.get_meta_value(CachedPacketFunction.RESULT_COMPUTED_FLAG) assert flag is False def test_skip_cache_lookup_sets_computed_true(self, cached_pf, input_packet): cached_pf.call(input_packet) # populate cache - result, _captured = cached_pf.call(input_packet, skip_cache_lookup=True) + result = cached_pf.call(input_packet, skip_cache_lookup=True) assert result is not None flag = result.get_meta_value(CachedPacketFunction.RESULT_COMPUTED_FLAG) assert flag is True def test_skip_cache_insert_sets_computed_true(self, cached_pf, input_packet): - result, _captured = cached_pf.call(input_packet, skip_cache_insert=True) + result = cached_pf.call(input_packet, skip_cache_insert=True) assert result is not None flag = result.get_meta_value(CachedPacketFunction.RESULT_COMPUTED_FLAG) assert flag is True diff --git a/tests/test_core/packet_function/test_executor.py b/tests/test_core/packet_function/test_executor.py index fe7548df..78ee2d2b 100644 --- a/tests/test_core/packet_function/test_executor.py +++ b/tests/test_core/packet_function/test_executor.py @@ -60,15 +60,13 @@ def execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> "tuple[PacketProtocol | None, CapturedLogs]": + ) -> "PacketProtocol | None": self.calls.append((packet_function, packet)) return packet_function.direct_call(packet) - def execute_callable(self, fn, kwargs, executor_options=None): - from orcapod.pipeline.logging_capture import CapturedLogs - + def execute_callable(self, fn, kwargs, executor_options=None, **kw): self.calls.append((fn, kwargs)) - return fn(**kwargs), CapturedLogs(success=True) + return fn(**kwargs) class PythonOnlyExecutor(PacketFunctionExecutorBase): @@ -85,7 +83,7 @@ def execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> "tuple[PacketProtocol | None, CapturedLogs]": + ) -> "PacketProtocol | None": return packet_function.direct_call(packet) @@ -103,7 +101,7 @@ def execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> "tuple[PacketProtocol | None, CapturedLogs]": + ) -> "PacketProtocol | None": return packet_function.direct_call(packet) @@ -187,7 +185,7 @@ def test_execute_delegates_to_direct_call( add_pf: PythonPacketFunction, add_packet: Packet, ): - result, _captured = local_executor.execute(add_pf, add_packet) + result = local_executor.execute(add_pf, add_packet) assert result is not None assert result.as_dict()["result"] == 3 @@ -207,8 +205,8 @@ def test_with_options_returns_new_instance(self, local_executor: LocalExecutor): class TestExecutorProperty: - def test_default_executor_is_none(self, add_pf: PythonPacketFunction): - assert add_pf.executor is None + def test_default_executor_is_local(self, add_pf: PythonPacketFunction): + assert isinstance(add_pf.executor, LocalExecutor) def test_set_executor( self, add_pf: PythonPacketFunction, spy_executor: SpyExecutor @@ -216,7 +214,7 @@ def test_set_executor( add_pf.executor = spy_executor assert add_pf.executor is spy_executor - def test_unset_executor( + def test_unset_executor_to_none( self, add_pf: PythonPacketFunction, spy_executor: SpyExecutor ): add_pf.executor = spy_executor @@ -248,7 +246,7 @@ class TestExecutorRouting: def test_call_without_executor_uses_direct_call( self, add_pf: PythonPacketFunction, add_packet: Packet ): - result, _captured = add_pf.call(add_packet) + result = add_pf.call(add_packet) assert result is not None assert result.as_dict()["result"] == 3 @@ -259,7 +257,7 @@ def test_call_with_executor_routes_through_executor( spy_executor: SpyExecutor, ): add_pf.executor = spy_executor - result, _captured = add_pf.call(add_packet) + result = add_pf.call(add_packet) assert result is not None assert result.as_dict()["result"] == 3 assert len(spy_executor.calls) == 1 @@ -271,7 +269,7 @@ def test_direct_call_bypasses_executor( spy_executor: SpyExecutor, ): add_pf.executor = spy_executor - result, _captured = add_pf.direct_call(add_packet) + result = add_pf.direct_call(add_packet) assert result is not None assert result.as_dict()["result"] == 3 # Executor was NOT called @@ -345,7 +343,7 @@ class SimpleWrapper(PacketFunctionWrapper): pass wrapper = SimpleWrapper(add_pf, version="v0.0") - result, _captured = wrapper.call(add_packet) + result = wrapper.call(add_packet) assert result is not None assert result.as_dict()["result"] == 3 assert len(spy.calls) == 1 @@ -409,15 +407,16 @@ def test_pod_executor_set_targets_packet_function(self): pod.executor = spy assert pf.executor is spy - def test_pod_executor_unset(self): + def test_pod_executor_swap(self): from orcapod.core.function_pod import FunctionPod spy = SpyExecutor() + spy2 = SpyExecutor() pf = PythonPacketFunction(add, output_keys="result") pf.executor = spy pod = FunctionPod(pf) - pod.executor = None - assert pf.executor is None + pod.executor = spy2 + assert pf.executor is spy2 def test_pod_process_uses_executor(self): from orcapod.core.function_pod import FunctionPod @@ -547,14 +546,14 @@ def test_decorator_incompatible_executor_raises(self): def my_add(x: int, y: int) -> int: return x + y - def test_decorator_without_executor_defaults_to_none(self): + def test_decorator_without_executor_defaults_to_local(self): from orcapod.core.function_pod import function_pod @function_pod(output_keys="result") def my_add(x: int, y: int) -> int: return x + y - assert my_add.pod.executor is None + assert isinstance(my_add.pod.executor, LocalExecutor) # --------------------------------------------------------------------------- @@ -601,9 +600,7 @@ def execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> "tuple[PacketProtocol | None, CapturedLogs]": - from orcapod.pipeline.logging_capture import CapturedLogs - + ) -> "PacketProtocol | None": self.sync_calls.append(packet) return packet_function.direct_call(packet) @@ -611,21 +608,17 @@ async def async_execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> "tuple[PacketProtocol | None, CapturedLogs]": + ) -> "PacketProtocol | None": self.async_calls.append(packet) return packet_function.direct_call(packet) - def execute_callable(self, fn, kwargs, executor_options=None): - from orcapod.pipeline.logging_capture import CapturedLogs - + def execute_callable(self, fn, kwargs, executor_options=None, **kw): self.sync_calls.append(kwargs) - return fn(**kwargs), CapturedLogs(success=True) - - async def async_execute_callable(self, fn, kwargs, executor_options=None): - from orcapod.pipeline.logging_capture import CapturedLogs + return fn(**kwargs) + async def async_execute_callable(self, fn, kwargs, executor_options=None, **kw): self.async_calls.append(kwargs) - return fn(**kwargs), CapturedLogs(success=True) + return fn(**kwargs) class TestConcurrentIteration: @@ -744,17 +737,15 @@ def test_spy_executor_satisfies_protocol(self): def test_execute_callable_runs_function(self): executor = LocalExecutor() - result, captured = executor.execute_callable(add, {"x": 3, "y": 4}) + result = executor.execute_callable(add, {"x": 3, "y": 4}) assert result == 7 - assert captured.success is True def test_execute_callable_with_executor_options(self): executor = LocalExecutor() - result, captured = executor.execute_callable( + result = executor.execute_callable( add, {"x": 1, "y": 2}, executor_options={"num_cpus": 1} ) assert result == 3 - assert captured.success is True # --------------------------------------------------------------------------- @@ -795,7 +786,8 @@ def test_set_executor_accepts_compatible_protocol(self): assert pf.executor is executor def test_set_executor_accepts_none(self): - pf = PythonPacketFunction(add, output_keys="result", executor=LocalExecutor()) + spy = SpyExecutor() + pf = PythonPacketFunction(add, output_keys="result", executor=spy) pf.set_executor(None) assert pf.executor is None @@ -811,7 +803,7 @@ def test_call_routes_through_execute_callable(self): spy = SpyExecutor() pf = PythonPacketFunction(add, output_keys="result", executor=spy) packet = Packet({"x": 1, "y": 2}) - result, _captured = pf.call(packet) + result = pf.call(packet) assert result is not None assert result.as_dict()["result"] == 3 assert len(spy.calls) == 1 @@ -822,7 +814,7 @@ def test_call_with_inactive_function_returns_none(self): pf = PythonPacketFunction(add, output_keys="result", executor=spy) pf.set_active(False) packet = Packet({"x": 1, "y": 2}) - result, _captured = pf.call(packet) + result = pf.call(packet) assert result is None assert len(spy.calls) == 0 @@ -833,7 +825,7 @@ def test_async_call_routes_through_async_execute_callable(self): spy = ConcurrentSpyExecutor() pf = PythonPacketFunction(add, output_keys="result", executor=spy) packet = Packet({"x": 1, "y": 2}) - result, _captured = asyncio.run(pf.async_call(packet)) + result = asyncio.run(pf.async_call(packet)) assert result is not None assert result.as_dict()["result"] == 3 assert len(spy.async_calls) == 1 @@ -854,18 +846,16 @@ async def async_add(x: int, y: int) -> int: return x + y executor = LocalExecutor() - result, captured = executor.execute_callable(async_add, {"x": 5, "y": 3}) + result = executor.execute_callable(async_add, {"x": 5, "y": 3}) assert result == 8 - assert captured.success is True def test_async_execute_callable_with_sync_fn(self): """LocalExecutor.async_execute_callable handles sync fns via run_in_executor.""" import asyncio executor = LocalExecutor() - result, captured = asyncio.run(executor.async_execute_callable(add, {"x": 10, "y": 20})) + result = asyncio.run(executor.async_execute_callable(add, {"x": 10, "y": 20})) assert result == 30 - assert captured.success is True def test_async_execute_callable_with_async_fn(self): """LocalExecutor.async_execute_callable awaits async functions directly.""" @@ -875,8 +865,7 @@ async def async_add(x: int, y: int) -> int: return x + y executor = LocalExecutor() - result, captured = asyncio.run( + result = asyncio.run( executor.async_execute_callable(async_add, {"x": 7, "y": 8}) ) assert result == 15 - assert captured.success is True diff --git a/tests/test_core/packet_function/test_packet_function.py b/tests/test_core/packet_function/test_packet_function.py index b06157cb..636dd29b 100644 --- a/tests/test_core/packet_function/test_packet_function.py +++ b/tests/test_core/packet_function/test_packet_function.py @@ -420,29 +420,29 @@ def test_set_active_true_re_enables(self, add_pf): class TestCall: def test_returns_packet_when_active(self, add_pf, add_packet): - result, _captured = add_pf.call(add_packet) + result = add_pf.call(add_packet) assert result is not None def test_output_has_correct_key(self, add_pf, add_packet): - result, _captured = add_pf.call(add_packet) + result = add_pf.call(add_packet) assert "result" in result.keys() def test_output_has_correct_value(self, add_pf, add_packet): - result, _captured = add_pf.call(add_packet) + result = add_pf.call(add_packet) assert result["result"] == 3 # 1 + 2 def test_source_info_contains_result_key(self, add_pf, add_packet): - result, _captured = add_pf.call(add_packet) + result = add_pf.call(add_packet) source = result.source_info() assert "result" in source def test_source_info_ends_with_key_name(self, add_pf, add_packet): - result, _captured = add_pf.call(add_packet) + result = add_pf.call(add_packet) source_str = result.source_info()["result"] assert source_str.endswith("::result") def test_source_info_contains_uri_components(self, add_pf, add_packet): - result, _captured = add_pf.call(add_packet) + result = add_pf.call(add_packet) source_str = result.source_info()["result"] for component in add_pf.uri: assert component in source_str @@ -450,7 +450,7 @@ def test_source_info_contains_uri_components(self, add_pf, add_packet): def test_source_info_record_id_is_uuid(self, add_pf, add_packet): import re - result, _captured = add_pf.call(add_packet) + result = add_pf.call(add_packet) source_str = result.source_info()["result"] # The record_id segment is between the URI components and the key name # Format: uri_part1:uri_part2:..::record_id::key @@ -461,18 +461,18 @@ def test_source_info_record_id_is_uuid(self, add_pf, add_packet): def test_inactive_returns_none(self, add_pf, add_packet): add_pf.set_active(False) - result, _captured = add_pf.call(add_packet) + result = add_pf.call(add_packet) assert result is None def test_multiple_output_keys(self, multi_pf): packet = Packet({"a": 3, "b": 4}) - result, _captured = multi_pf.call(packet) + result = multi_pf.call(packet) assert result["sum"] == 7 # 3 + 4 assert result["product"] == 12 # 3 * 4 def test_multiple_output_keys_source_info(self, multi_pf): packet = Packet({"a": 3, "b": 4}) - result, _captured = multi_pf.call(packet) + result = multi_pf.call(packet) source = result.source_info() assert "sum" in source assert "product" in source @@ -480,7 +480,7 @@ def test_multiple_output_keys_source_info(self, multi_pf): assert source["product"].endswith("::product") def test_output_packet_schema_applied(self, add_pf, add_packet): - result, _captured = add_pf.call(add_packet) + result = add_pf.call(add_packet) assert result is not None # schema from the packet function should carry through schema = result.schema() @@ -531,7 +531,7 @@ def returns_one(a, b): class TestAsyncCall: def test_async_call_returns_correct_result(self, add_pf, add_packet): - result, _captured = asyncio.run(add_pf.async_call(add_packet)) + result = asyncio.run(add_pf.async_call(add_packet)) assert result is not None assert result.as_dict()["result"] == 3 # 1 + 2 @@ -597,28 +597,28 @@ async def bad(*args: int) -> int: class TestAsyncFunctionSyncCall: def test_direct_call_returns_correct_result(self, async_add_pf, add_packet): - result, _captured = async_add_pf.direct_call(add_packet) + result = async_add_pf.direct_call(add_packet) assert result is not None assert result["result"] == 3 def test_call_returns_correct_result(self, async_add_pf, add_packet): - result, _captured = async_add_pf.call(add_packet) + result = async_add_pf.call(add_packet) assert result is not None assert result["result"] == 3 def test_inactive_returns_none(self, async_add_pf, add_packet): async_add_pf.set_active(False) - result, _captured = async_add_pf.call(add_packet) + result = async_add_pf.call(add_packet) assert result is None def test_multiple_outputs(self, async_multi_pf): packet = Packet({"a": 3, "b": 4}) - result, _captured = async_multi_pf.call(packet) + result = async_multi_pf.call(packet) assert result["sum"] == 7 assert result["product"] == 12 def test_source_info_present(self, async_add_pf, add_packet): - result, _captured = async_add_pf.call(add_packet) + result = async_add_pf.call(add_packet) source = result.source_info() assert "result" in source assert source["result"].endswith("::result") @@ -631,22 +631,22 @@ def test_source_info_present(self, async_add_pf, add_packet): class TestAsyncFunctionAsyncCall: def test_direct_async_call_awaits_directly(self, async_add_pf, add_packet): - result, _captured = asyncio.run(async_add_pf.direct_async_call(add_packet)) + result = asyncio.run(async_add_pf.direct_async_call(add_packet)) assert result is not None assert result["result"] == 3 def test_async_call_returns_correct_result(self, async_add_pf, add_packet): - result, _captured = asyncio.run(async_add_pf.async_call(add_packet)) + result = asyncio.run(async_add_pf.async_call(add_packet)) assert result is not None assert result["result"] == 3 def test_inactive_returns_none(self, async_add_pf, add_packet): async_add_pf.set_active(False) - result, _captured = asyncio.run(async_add_pf.async_call(add_packet)) + result = asyncio.run(async_add_pf.async_call(add_packet)) assert result is None def test_multiple_outputs(self, async_multi_pf): packet = Packet({"a": 3, "b": 4}) - result, _captured = asyncio.run(async_multi_pf.async_call(packet)) + result = asyncio.run(async_multi_pf.async_call(packet)) assert result["sum"] == 7 assert result["product"] == 12 diff --git a/tests/test_core/test_regression_fixes.py b/tests/test_core/test_regression_fixes.py index c5d3a298..824a7bcf 100644 --- a/tests/test_core/test_regression_fixes.py +++ b/tests/test_core/test_regression_fixes.py @@ -78,15 +78,15 @@ def execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> tuple[PacketProtocol | None, Any]: + *, + logger=None, + ) -> PacketProtocol | None: self.calls.append((packet_function, packet)) return packet_function.direct_call(packet) - def execute_callable(self, fn, kwargs, executor_options=None): - from orcapod.pipeline.logging_capture import CapturedLogs - + def execute_callable(self, fn, kwargs, executor_options=None, *, logger=None): self.calls.append((fn, kwargs)) - return fn(**kwargs), CapturedLogs(success=True) + return fn(**kwargs) # =========================================================================== @@ -99,9 +99,8 @@ class TestAsyncExecuteChannelCloseOnError: @pytest.mark.asyncio async def test_unary_operator_closes_channel_on_error(self): - """When a packet function raises, direct_call catches the exception - and returns (None, captured_failure). process_packet discards the - captured logs and returns (tag, None). The output channel is closed + """When a packet function raises, process_packet catches the exception + and returns (tag, None). The output channel is closed normally and no exception propagates.""" def failing(x: int) -> int: @@ -195,7 +194,7 @@ def test_direct_call_does_not_invoke_executor(self): _, spy, wrapper = self._make_add_pf_with_spy() packet = Packet({"x": 3, "y": 4}) - result, _captured = wrapper.direct_call(packet) + result = wrapper.direct_call(packet) assert result is not None assert result.as_dict()["result"] == 7 @@ -207,7 +206,7 @@ async def test_direct_async_call_does_not_invoke_executor(self): _, spy, wrapper = self._make_add_pf_with_spy() packet = Packet({"x": 3, "y": 4}) - result, _captured = await wrapper.direct_async_call(packet) + result = await wrapper.direct_async_call(packet) assert result is not None assert result.as_dict()["result"] == 7 @@ -218,7 +217,7 @@ def test_call_still_routes_through_executor(self): _, spy, wrapper = self._make_add_pf_with_spy() packet = Packet({"x": 3, "y": 4}) - result, _captured = wrapper.call(packet) + result = wrapper.call(packet) assert result is not None assert result.as_dict()["result"] == 7 @@ -311,13 +310,13 @@ def double(x: int) -> int: # Patch async_call to use our concurrency-tracking function original_async_call = pf.async_call - async def tracked_async_call(packet: PacketProtocol) -> tuple: + async def tracked_async_call(packet: PacketProtocol, **kwargs): nonlocal concurrent_count, max_concurrent concurrent_count += 1 max_concurrent = max(max_concurrent, concurrent_count) await asyncio.sleep(0.01) concurrent_count -= 1 - return await original_async_call(packet) + return await original_async_call(packet, **kwargs) pf.async_call = tracked_async_call # type: ignore diff --git a/tests/test_core/test_result_cache.py b/tests/test_core/test_result_cache.py index 16d6a404..cda78270 100644 --- a/tests/test_core/test_result_cache.py +++ b/tests/test_core/test_result_cache.py @@ -48,7 +48,7 @@ def _compute_and_store( cache: ResultCache, pf: PythonPacketFunction, input_packet: Packet ): """Helper: compute output and store in cache.""" - output, _captured = pf.direct_call(input_packet) + output = pf.direct_call(input_packet) assert output is not None cache.store( input_packet, @@ -105,7 +105,7 @@ def test_same_packet_different_record_path_is_miss(self): pf = _make_pf() input_pkt = Packet({"x": 10}) - output, _captured = pf.direct_call(input_pkt) + output = pf.direct_call(input_pkt) cache_a.store( input_pkt, output, @@ -130,7 +130,7 @@ def test_most_recent_wins(self): time.sleep(0.01) # ensure different timestamp # Store a second result for the same input (simulating recomputation) - output2, _captured = pf.direct_call(input_pkt) + output2 = pf.direct_call(input_pkt) cache.store( input_pkt, output2, diff --git a/tests/test_pipeline/test_node_protocols.py b/tests/test_pipeline/test_node_protocols.py index 65681222..20d210bc 100644 --- a/tests/test_pipeline/test_node_protocols.py +++ b/tests/test_pipeline/test_node_protocols.py @@ -157,17 +157,18 @@ def test_execute_with_observer(self): events = [] class Obs: - def on_node_start(self, n): - events.append(("start", n.node_type)) - def on_node_end(self, n): - events.append(("end", n.node_type)) - def on_packet_start(self, n, t, p): + def on_node_start(self, node_label, node_hash): + events.append(("start", node_label)) + def on_node_end(self, node_label, node_hash): + events.append(("end", node_label)) + def on_packet_start(self, node_label, t, p): pass - def on_packet_end(self, n, t, ip, op, cached): + def on_packet_end(self, node_label, t, ip, op, cached): pass node.execute(observer=Obs()) - assert events == [("start", "source"), ("end", "source")] + assert events[0][0] == "start" + assert events[1][0] == "end" def test_execute_without_observer(self): """execute() works fine with no observer.""" @@ -206,13 +207,13 @@ async def test_async_execute_with_observer(self): events = [] class Obs: - def on_node_start(self, n): + def on_node_start(self, node_label, node_hash): events.append("start") - def on_node_end(self, n): + def on_node_end(self, node_label, node_hash): events.append("end") - def on_packet_start(self, n, t, p): + def on_packet_start(self, node_label, t, p): pass - def on_packet_end(self, n, t, ip, op, cached): + def on_packet_end(self, node_label, t, ip, op, cached): pass output_ch = Channel(buffer_size=16) @@ -249,17 +250,19 @@ def test_execute_with_observer(self): events = [] class Obs: - def on_node_start(self, n): - events.append(("node_start", n.node_type)) - def on_node_end(self, n): - events.append(("node_end", n.node_type)) - def on_packet_start(self, n, t, p): + def contextualize(self, node_hash, node_label): + return self + def on_node_start(self, node_label, node_hash): + events.append(("node_start", node_label)) + def on_node_end(self, node_label, node_hash): + events.append(("node_end", node_label)) + def on_packet_start(self, node_label, t, p): events.append(("packet_start",)) - def on_packet_end(self, n, t, ip, op, cached): + def on_packet_end(self, node_label, t, ip, op, cached): events.append(("packet_end", cached)) - def on_packet_crash(self, n, t, p, exc): + def on_packet_crash(self, node_label, t, p, exc): pass - def create_packet_logger(self, n, t, p, **kwargs): + def create_packet_logger(self, t, p, **kwargs): from orcapod.pipeline.observer import _NOOP_LOGGER return _NOOP_LOGGER @@ -267,8 +270,8 @@ def create_packet_logger(self, n, t, p, **kwargs): result = node.execute(input_stream, observer=Obs()) assert len(result) == 2 - assert events[0] == ("node_start", "function") - assert events[-1] == ("node_end", "function") + assert events[0][0] == "node_start" + assert events[-1][0] == "node_end" packet_events = [e for e in events if e[0].startswith("packet")] assert len(packet_events) == 4 # 2 start + 2 end @@ -326,17 +329,19 @@ async def test_async_execute_with_observer(self): events = [] class Obs: - def on_node_start(self, n): + def contextualize(self, node_hash, node_label): + return self + def on_node_start(self, node_label, node_hash): events.append("node_start") - def on_node_end(self, n): + def on_node_end(self, node_label, node_hash): events.append("node_end") - def on_packet_start(self, n, t, p): + def on_packet_start(self, node_label, t, p): events.append("pkt_start") - def on_packet_end(self, n, t, ip, op, cached): + def on_packet_end(self, node_label, t, ip, op, cached): events.append("pkt_end") - def on_packet_crash(self, n, t, p, exc): + def on_packet_crash(self, node_label, t, p, exc): pass - def create_packet_logger(self, n, t, p, **kwargs): + def create_packet_logger(self, t, p, **kwargs): from orcapod.pipeline.observer import _NOOP_LOGGER return _NOOP_LOGGER @@ -378,18 +383,19 @@ def test_execute_with_observer(self): events = [] class Obs: - def on_node_start(self, n): - events.append(("node_start", n.node_type)) - def on_node_end(self, n): - events.append(("node_end", n.node_type)) - def on_packet_start(self, n, t, p): + def on_node_start(self, node_label, node_hash): + events.append(("node_start", node_label)) + def on_node_end(self, node_label, node_hash): + events.append(("node_end", node_label)) + def on_packet_start(self, node_label, t, p): pass - def on_packet_end(self, n, t, ip, op, cached): + def on_packet_end(self, node_label, t, ip, op, cached): pass result = node.execute(*node._input_streams, observer=Obs()) assert len(result) == 2 - assert events == [("node_start", "operator"), ("node_end", "operator")] + assert events[0][0] == "node_start" + assert events[1][0] == "node_end" def test_execute_without_observer(self): node = self._make_join_node() @@ -418,13 +424,13 @@ async def test_async_execute_with_observer(self): events = [] class Obs: - def on_node_start(self, n): + def on_node_start(self, node_label, node_hash): events.append("start") - def on_node_end(self, n): + def on_node_end(self, node_label, node_hash): events.append("end") - def on_packet_start(self, n, t, p): + def on_packet_start(self, node_label, t, p): pass - def on_packet_end(self, n, t, ip, op, cached): + def on_packet_end(self, node_label, t, ip, op, cached): pass input_ch = Channel(buffer_size=16) diff --git a/tests/test_pipeline/test_observer.py b/tests/test_pipeline/test_observer.py index 91a50745..7d0f177c 100644 --- a/tests/test_pipeline/test_observer.py +++ b/tests/test_pipeline/test_observer.py @@ -23,28 +23,32 @@ def test_on_run_end_noop(self): NoOpObserver().on_run_end("run-123") def test_on_node_start_noop(self): - NoOpObserver().on_node_start(None) # type: ignore[arg-type] + NoOpObserver().on_node_start("label", "hash") def test_on_node_end_noop(self): - NoOpObserver().on_node_end(None) # type: ignore[arg-type] + NoOpObserver().on_node_end("label", "hash") def test_on_packet_start_noop(self): - NoOpObserver().on_packet_start(None, None, None) # type: ignore[arg-type] + NoOpObserver().on_packet_start("label", None, None) # type: ignore[arg-type] def test_on_packet_end_noop(self): - NoOpObserver().on_packet_end(None, None, None, None, cached=False) # type: ignore[arg-type] + NoOpObserver().on_packet_end("label", None, None, None, cached=False) # type: ignore[arg-type] def test_on_packet_crash_noop(self): - NoOpObserver().on_packet_crash(None, None, None, RuntimeError("boom")) # type: ignore[arg-type] + NoOpObserver().on_packet_crash("label", None, None, RuntimeError("boom")) # type: ignore[arg-type] def test_create_packet_logger_returns_noop(self): - logger = NoOpObserver().create_packet_logger(None, None, None) # type: ignore[arg-type] + logger = NoOpObserver().create_packet_logger(None, None) # type: ignore[arg-type] assert logger is _NOOP_LOGGER def test_create_packet_logger_satisfies_protocol(self): - logger = NoOpObserver().create_packet_logger(None, None, None) # type: ignore[arg-type] + logger = NoOpObserver().create_packet_logger(None, None) # type: ignore[arg-type] assert isinstance(logger, PacketExecutionLoggerProtocol) + def test_contextualize_returns_self(self): + obs = NoOpObserver() + assert obs.contextualize("hash", "label") is obs + class TestNoOpLogger: """NoOpLogger satisfies the protocol and discards everything.""" @@ -60,6 +64,6 @@ def test_record_noop(self): def test_noop_logger_singleton_identity(self): # The singleton returned by create_packet_logger is always the same object. obs = NoOpObserver() - l1 = obs.create_packet_logger(None, None, None) # type: ignore[arg-type] - l2 = obs.create_packet_logger(None, None, None) # type: ignore[arg-type] + l1 = obs.create_packet_logger(None, None) # type: ignore[arg-type] + l2 = obs.create_packet_logger(None, None) # type: ignore[arg-type] assert l1 is l2 diff --git a/tests/test_pipeline/test_orchestrator.py b/tests/test_pipeline/test_orchestrator.py index 40459810..587b4d1a 100644 --- a/tests/test_pipeline/test_orchestrator.py +++ b/tests/test_pipeline/test_orchestrator.py @@ -467,7 +467,7 @@ def failing_fn(value: int) -> int: crashes = [] class CrashRecorder(NoOpObserver): - def on_packet_crash(self, node, tag, packet, error): + def on_packet_crash(self, node_label, tag, packet, error): crashes.append(error) pipeline.compile() @@ -496,33 +496,35 @@ def test_linear_pipeline_observer_hooks(self): class RecordingObserver: def on_run_start(self, run_id): pass def on_run_end(self, run_id): pass - def on_node_start(self, node): - events.append(("node_start", node.node_type)) - def on_node_end(self, node): - events.append(("node_end", node.node_type)) - def on_packet_start(self, node, tag, packet): - events.append(("packet_start", node.node_type)) - def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): - events.append(("packet_end", node.node_type, cached)) - def on_packet_crash(self, node, tag, packet, exc): pass - def create_packet_logger(self, node, tag, packet, **kwargs): + def on_node_start(self, node_label, node_hash): + events.append(("node_start", node_label)) + def on_node_end(self, node_label, node_hash): + events.append(("node_end", node_label)) + def on_packet_start(self, node_label, tag, packet): + events.append(("packet_start", node_label)) + def on_packet_end(self, node_label, tag, input_pkt, output_pkt, cached): + events.append(("packet_end", node_label, cached)) + def on_packet_crash(self, node_label, tag, packet, exc): pass + def create_packet_logger(self, tag, packet, **kwargs): from orcapod.pipeline.observer import _NOOP_LOGGER return _NOOP_LOGGER + def contextualize(self, node_hash, node_label): + return self pipeline.compile() orch = AsyncPipelineOrchestrator(observer=RecordingObserver()) orch.run(pipeline._node_graph) - # Source fires node_start/node_end - assert ("node_start", "source") in events - assert ("node_end", "source") in events + # Source fires node_start/node_end (label contains "ArrowTableSource" or similar) + source_starts = [e for e in events if e[0] == "node_start" and e[1] != "doubler"] + assert len(source_starts) >= 1 # Function fires node_start, per-packet hooks, node_end - assert ("node_start", "function") in events - assert ("node_end", "function") in events + assert ("node_start", "doubler") in events + assert ("node_end", "doubler") in events fn_packet_ends = [ e for e in events - if e[0] == "packet_end" and e[1] == "function" + if e[0] == "packet_end" and e[1] == "doubler" ] assert len(fn_packet_ends) == 2 # All should be cached=False (first run, no DB) @@ -551,32 +553,34 @@ def double_val(val: int) -> int: class RecordingObserver: def on_run_start(self, run_id): pass def on_run_end(self, run_id): pass - def on_node_start(self, node): - events.append(("node_start", node.node_type)) - def on_node_end(self, node): - events.append(("node_end", node.node_type)) - def on_packet_start(self, node, tag, packet): - events.append(("packet_start", node.node_type)) - def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): - events.append(("packet_end", node.node_type)) - def on_packet_crash(self, node, tag, packet, exc): pass - def create_packet_logger(self, node, tag, packet, **kwargs): + def on_node_start(self, node_label, node_hash): + events.append(("node_start", node_label)) + def on_node_end(self, node_label, node_hash): + events.append(("node_end", node_label)) + def on_packet_start(self, node_label, tag, packet): + events.append(("packet_start", node_label)) + def on_packet_end(self, node_label, tag, input_pkt, output_pkt, cached): + events.append(("packet_end", node_label)) + def on_packet_crash(self, node_label, tag, packet, exc): pass + def create_packet_logger(self, tag, packet, **kwargs): from orcapod.pipeline.observer import _NOOP_LOGGER return _NOOP_LOGGER + def contextualize(self, node_hash, node_label): + return self pipeline.compile() orch = AsyncPipelineOrchestrator(observer=RecordingObserver()) orch.run(pipeline._node_graph) - # All three node types fire start/end - for node_type in ("source", "operator", "function"): - assert ("node_start", node_type) in events - assert ("node_end", node_type) in events + # All labeled nodes fire start/end + assert ("node_start", "mapper") in events + assert ("node_end", "mapper") in events + assert ("node_start", "doubler") in events + assert ("node_end", "doubler") in events # Only function nodes fire packet-level hooks - assert ("packet_start", "function") in events - assert ("packet_start", "source") not in events - assert ("packet_start", "operator") not in events + assert ("packet_start", "doubler") in events + assert ("packet_start", "mapper") not in events def test_no_observer_works(self): """Async pipeline runs fine with no observer.""" diff --git a/tests/test_pipeline/test_pipeline.py b/tests/test_pipeline/test_pipeline.py index 458680d9..203d9ec9 100644 --- a/tests/test_pipeline/test_pipeline.py +++ b/tests/test_pipeline/test_pipeline.py @@ -1164,7 +1164,9 @@ def execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> "tuple[PacketProtocol | None, Any]": + *, + logger=None, + ) -> "PacketProtocol | None": self.sync_calls.append(packet) return packet_function.direct_call(packet) @@ -1172,21 +1174,19 @@ async def async_execute( self, packet_function: PacketFunctionProtocol, packet: PacketProtocol, - ) -> "tuple[PacketProtocol | None, Any]": + *, + logger=None, + ) -> "PacketProtocol | None": self.async_calls.append(packet) return packet_function.direct_call(packet) - def execute_callable(self, fn, kwargs, executor_options=None): - from orcapod.pipeline.logging_capture import CapturedLogs - + def execute_callable(self, fn, kwargs, executor_options=None, *, logger=None): self.sync_calls.append(kwargs) - return fn(**kwargs), CapturedLogs(success=True) - - async def async_execute_callable(self, fn, kwargs, executor_options=None): - from orcapod.pipeline.logging_capture import CapturedLogs + return fn(**kwargs) + async def async_execute_callable(self, fn, kwargs, executor_options=None, *, logger=None): self.async_calls.append(kwargs) - return fn(**kwargs), CapturedLogs(success=True) + return fn(**kwargs) def with_options(self, **opts: Any) -> "_MockExecutor": return _MockExecutor(opts={**self.opts, **opts}) diff --git a/tests/test_pipeline/test_sync_orchestrator.py b/tests/test_pipeline/test_sync_orchestrator.py index 0bdcba10..bdd58efb 100644 --- a/tests/test_pipeline/test_sync_orchestrator.py +++ b/tests/test_pipeline/test_sync_orchestrator.py @@ -132,28 +132,32 @@ def test_observer_hooks_fire(self): class RecordingObserver: def on_run_start(self, run_id): pass def on_run_end(self, run_id): pass - def on_node_start(self, node): - events.append(("node_start", node.node_type)) - def on_node_end(self, node): - events.append(("node_end", node.node_type)) - def on_packet_start(self, node, tag, packet): + def on_node_start(self, node_label, node_hash): + events.append(("node_start", node_label)) + def on_node_end(self, node_label, node_hash): + events.append(("node_end", node_label)) + def on_packet_start(self, node_label, tag, packet): events.append(("packet_start",)) - def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): + def on_packet_end(self, node_label, tag, input_pkt, output_pkt, cached): events.append(("packet_end", cached)) - def on_packet_crash(self, node, tag, packet, exc): pass - def create_packet_logger(self, node, tag, packet, **kwargs): + def on_packet_crash(self, node_label, tag, packet, exc): pass + def create_packet_logger(self, tag, packet, **kwargs): from orcapod.pipeline.observer import _NOOP_LOGGER return _NOOP_LOGGER + def contextualize(self, node_hash, node_label): + return self orch = SyncPipelineOrchestrator(observer=RecordingObserver()) orch.run(pipeline._node_graph) - assert events[0] == ("node_start", "source") - assert events[1] == ("node_end", "source") - assert events[2] == ("node_start", "function") + # First two events are source node start/end + assert events[0][0] == "node_start" + assert events[1][0] == "node_end" + # Then function node start, packet hooks, function node end + assert events[2] == ("node_start", "doubler") assert events[3] == ("packet_start",) assert events[4] == ("packet_end", False) - assert events[5] == ("node_end", "function") + assert events[5] == ("node_end", "doubler") class TestSyncOrchestratorUnknownNodeType: @@ -209,18 +213,20 @@ def test_run_with_explicit_orchestrator(self): class RecordingObserver: def on_run_start(self, run_id): pass def on_run_end(self, run_id): pass - def on_node_start(self, node): - events.append(("node_start", node.node_type)) - def on_node_end(self, node): - events.append(("node_end", node.node_type)) - def on_packet_start(self, node, tag, packet): + def on_node_start(self, node_label, node_hash): + events.append(("node_start", node_label)) + def on_node_end(self, node_label, node_hash): + events.append(("node_end", node_label)) + def on_packet_start(self, node_label, tag, packet): events.append(("packet_start",)) - def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): + def on_packet_end(self, node_label, tag, input_pkt, output_pkt, cached): events.append(("packet_end",)) - def on_packet_crash(self, node, tag, packet, exc): pass - def create_packet_logger(self, node, tag, packet, **kwargs): + def on_packet_crash(self, node_label, tag, packet, exc): pass + def create_packet_logger(self, tag, packet, **kwargs): from orcapod.pipeline.observer import _NOOP_LOGGER return _NOOP_LOGGER + def contextualize(self, node_hash, node_label): + return self orch = SyncPipelineOrchestrator(observer=RecordingObserver()) pipeline.run(orchestrator=orch) @@ -417,36 +423,33 @@ def double_val(val: int) -> int: class RecordingObserver: def on_run_start(self, run_id): pass def on_run_end(self, run_id): pass - def on_node_start(self, node): - events.append(("node_start", node.node_type)) - def on_node_end(self, node): - events.append(("node_end", node.node_type)) - def on_packet_start(self, node, tag, packet): - events.append(("packet_start", node.node_type)) - def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): - events.append(("packet_end", node.node_type, cached)) - def on_packet_crash(self, node, tag, packet, exc): pass - def create_packet_logger(self, node, tag, packet, **kwargs): + def on_node_start(self, node_label, node_hash): + events.append(("node_start", node_label)) + def on_node_end(self, node_label, node_hash): + events.append(("node_end", node_label)) + def on_packet_start(self, node_label, tag, packet): + events.append(("packet_start", node_label)) + def on_packet_end(self, node_label, tag, input_pkt, output_pkt, cached): + events.append(("packet_end", node_label, cached)) + def on_packet_crash(self, node_label, tag, packet, exc): pass + def create_packet_logger(self, tag, packet, **kwargs): from orcapod.pipeline.observer import _NOOP_LOGGER return _NOOP_LOGGER + def contextualize(self, node_hash, node_label): + return self orch = SyncPipelineOrchestrator(observer=RecordingObserver()) orch.run(pipeline._node_graph) - # Source fires node_start/node_end only (no packet-level hooks) - assert ("node_start", "source") in events - assert ("node_end", "source") in events - assert ("packet_start", "source") not in events - - # Operator fires node_start/node_end only - assert ("node_start", "operator") in events - assert ("node_end", "operator") in events - assert ("packet_start", "operator") not in events + # Mapper fires node_start/node_end only (no packet-level hooks) + assert ("node_start", "mapper") in events + assert ("node_end", "mapper") in events + assert ("packet_start", "mapper") not in events # Function fires node_start, per-packet hooks, node_end - assert ("node_start", "function") in events - assert ("node_end", "function") in events - fn_packet_events = [e for e in events if e[0] == "packet_start" and e[1] == "function"] + assert ("node_start", "doubler") in events + assert ("node_end", "doubler") in events + fn_packet_events = [e for e in events if e[0] == "packet_start" and e[1] == "doubler"] assert len(fn_packet_events) == 2 # 2 packets def test_function_node_cached_flag(self): @@ -466,16 +469,18 @@ def test_function_node_cached_flag(self): class Obs1: def on_run_start(self, run_id): pass def on_run_end(self, run_id): pass - def on_node_start(self, node): pass - def on_node_end(self, node): pass - def on_packet_start(self, node, tag, packet): pass - def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): - if node.node_type == "function": + def on_node_start(self, node_label, node_hash): pass + def on_node_end(self, node_label, node_hash): pass + def on_packet_start(self, node_label, tag, packet): pass + def on_packet_end(self, node_label, tag, input_pkt, output_pkt, cached): + if node_label == "doubler": events1.append(cached) - def on_packet_crash(self, node, tag, packet, exc): pass - def create_packet_logger(self, node, tag, packet, **kwargs): + def on_packet_crash(self, node_label, tag, packet, exc): pass + def create_packet_logger(self, tag, packet, **kwargs): from orcapod.pipeline.observer import _NOOP_LOGGER return _NOOP_LOGGER + def contextualize(self, node_hash, node_label): + return self SyncPipelineOrchestrator(observer=Obs1()).run(pipeline._node_graph) assert events1 == [False] @@ -486,16 +491,18 @@ def create_packet_logger(self, node, tag, packet, **kwargs): class Obs2: def on_run_start(self, run_id): pass def on_run_end(self, run_id): pass - def on_node_start(self, node): pass - def on_node_end(self, node): pass - def on_packet_start(self, node, tag, packet): pass - def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): - if node.node_type == "function": + def on_node_start(self, node_label, node_hash): pass + def on_node_end(self, node_label, node_hash): pass + def on_packet_start(self, node_label, tag, packet): pass + def on_packet_end(self, node_label, tag, input_pkt, output_pkt, cached): + if node_label == "doubler": events2.append(cached) - def on_packet_crash(self, node, tag, packet, exc): pass - def create_packet_logger(self, node, tag, packet, **kwargs): + def on_packet_crash(self, node_label, tag, packet, exc): pass + def create_packet_logger(self, tag, packet, **kwargs): from orcapod.pipeline.observer import _NOOP_LOGGER return _NOOP_LOGGER + def contextualize(self, node_hash, node_label): + return self SyncPipelineOrchestrator(observer=Obs2()).run(pipeline._node_graph) assert events2 == [True] @@ -517,27 +524,31 @@ def test_diamond_dag_observer_event_order(self): class OrderObserver: def on_run_start(self, run_id): pass def on_run_end(self, run_id): pass - def on_node_start(self, node): - node_order.append(("start", node.node_type)) - def on_node_end(self, node): - node_order.append(("end", node.node_type)) - def on_packet_start(self, node, tag, packet): pass - def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): pass - def on_packet_crash(self, node, tag, packet, exc): pass - def create_packet_logger(self, node, tag, packet, **kwargs): + def on_node_start(self, node_label, node_hash): + node_order.append(("start", node_label)) + def on_node_end(self, node_label, node_hash): + node_order.append(("end", node_label)) + def on_packet_start(self, node_label, tag, packet): pass + def on_packet_end(self, node_label, tag, input_pkt, output_pkt, cached): pass + def on_packet_crash(self, node_label, tag, packet, exc): pass + def create_packet_logger(self, tag, packet, **kwargs): from orcapod.pipeline.observer import _NOOP_LOGGER return _NOOP_LOGGER + def contextualize(self, node_hash, node_label): + return self SyncPipelineOrchestrator(observer=OrderObserver()).run(pipeline._node_graph) - # Extract just the node types in start order - starts = [nt for event, nt in node_order if event == "start"] - # Sources first (order between them doesn't matter), then operator, then function - assert starts.count("source") == 2 - assert starts.index("operator") > max( - i for i, s in enumerate(starts) if s == "source" - ) - assert starts.index("function") > starts.index("operator") + # Extract just the node labels in start order + starts = [label for event, label in node_order if event == "start"] + # Sources first, then operator ("join"), then function ("adder") + join_idx = starts.index("join") + adder_idx = starts.index("adder") + # join and adder should come after source labels + source_labels = [s for s in starts if s not in ("join", "adder")] + assert len(source_labels) == 2 + assert join_idx > max(starts.index(s) for s in source_labels) + assert adder_idx > join_idx def test_no_observer_works(self): """Pipeline runs fine with no observer (None)."""