diff --git a/.zed/rules b/.zed/rules index c5206c96..dc83f5e7 100644 --- a/.zed/rules +++ b/.zed/rules @@ -51,7 +51,16 @@ import path doesn't work, create a proper re-export package with an __init__.py Use Google style (https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings) Python docstrings everywhere. -## Linear issues +## Linear issue tracking + +All work must be linked to a Linear issue. Before starting any feature, bug fix, or +refactor: + +1. Check for an existing issue — search Linear for a corresponding issue. +2. If none exists — ask the developer whether to create one. Do not proceed without + either a linked issue or explicit approval to skip. +3. When a new issue is discovered during development (bug, design problem, deferred + work), create a corresponding Linear issue using the template below. When creating Linear issues, always use this template for the description: @@ -82,8 +91,28 @@ When creating Linear issues, always use this template for the description: Remove any optional sections that don't apply rather than leaving them empty. -When working on a Linear issue, create and checkout a git branch using the gitBranchName -returned by Linear (e.g. eywalker/plt-911-add-documentation-for-orcapod-python). +### Branches and PRs + +When working on a feature, create and checkout a git branch using the gitBranchName +returned by the primary Linear issue (e.g. eywalker/plt-911-add-documentation-for-orcapod-python). + +If a feature branch / PR corresponds to multiple Linear issues, list all of them in the +PR description body so that Linear's GitHub integration auto-tracks the PR against each +issue. Use the format "Fixes PLT-123" or "Closes PLT-123" (GitHub magic words) for issues +that the PR fully resolves, and simply mention "PLT-456" for issues that are related but +not fully resolved by the PR. + +## Responding to PR reviews + +When asked to respond to PR reviewer comments: + +1. 
Fetch and present — Read all review comments, then present a response plan as a table: + each comment, its severity, whether to fix or explain, and the proposed action. +2. Wait for approval — Let the user approve the plan before making changes. +3. Fix, then reply — Make all fixes in a single commit, then post replies to each + reviewer comment explaining what was done (or why it was declined). + +Never make fixes silently or skip the plan step. ## Git commits diff --git a/CLAUDE.md b/CLAUDE.md index 53a3c1f9..722c6d2a 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -55,7 +55,16 @@ import path doesn't work, create a proper re-export package with an `__init__.py Use [Google style](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings) Python docstrings everywhere. -## Linear issues +## Linear issue tracking + +All work must be linked to a Linear issue. Before starting any feature, bug fix, or +refactor: + +1. **Check for an existing issue** — search Linear for a corresponding issue. +2. **If none exists** — ask the developer whether to create one. Do not proceed without + either a linked issue or explicit approval to skip. +3. **When a new issue is discovered** during development (bug, design problem, deferred + work), create a corresponding Linear issue using the template below. When creating Linear issues, always use this template for the description: @@ -88,8 +97,28 @@ Out of scope: Remove any optional sections that don't apply rather than leaving them empty. -When working on a Linear issue, create and checkout a git branch using the `gitBranchName` -returned by Linear (e.g. `eywalker/plt-911-add-documentation-for-orcapod-python`). +### Branches and PRs + +When working on a feature, create and checkout a git branch using the `gitBranchName` +returned by the primary Linear issue (e.g. `eywalker/plt-911-add-documentation-for-orcapod-python`). 
+ +If a feature branch / PR corresponds to multiple Linear issues, list all of them in the +PR description body so that Linear's GitHub integration auto-tracks the PR against each +issue. Use the format `Fixes PLT-123` or `Closes PLT-123` (GitHub magic words) for issues +that the PR fully resolves, and simply mention `PLT-456` for issues that are related but +not fully resolved by the PR. + +## Responding to PR reviews + +When asked to respond to PR reviewer comments: + +1. **Fetch and present** — Read all review comments, then present a response plan as a table: + each comment, its severity, whether to fix or explain, and the proposed action. +2. **Wait for approval** — Let the user approve the plan before making changes. +3. **Fix, then reply** — Make all fixes in a single commit, then post replies to each + reviewer comment explaining what was done (or why it was declined). + +Never make fixes silently or skip the plan step. ## Git commits diff --git a/src/orcapod/core/nodes/function_node.py b/src/orcapod/core/nodes/function_node.py index e5f453b1..ca3b198c 100644 --- a/src/orcapod/core/nodes/function_node.py +++ b/src/orcapod/core/nodes/function_node.py @@ -4,7 +4,7 @@ import asyncio import logging -from collections.abc import Iterator, Sequence +from collections.abc import Iterator from typing import TYPE_CHECKING, Any, cast from orcapod import contexts @@ -28,10 +28,7 @@ from orcapod.types import ( ColumnConfig, ContentHash, - NodeConfig, - PipelineConfig, Schema, - resolve_concurrency, ) from orcapod.utils import arrow_utils, schema_utils from orcapod.utils.lazy_module import LazyModule @@ -41,6 +38,8 @@ if TYPE_CHECKING: import polars as pl import pyarrow as pa + + from orcapod.pipeline.observer import ExecutionObserver else: pa = LazyModule("pyarrow") pl = LazyModule("polars") @@ -486,29 +485,51 @@ def execute_packet( return self._process_packet_internal(tag, packet) def execute( - self, input_stream: StreamProtocol + self, + input_stream: StreamProtocol, 
+ *, + observer: "ExecutionObserver | None" = None, ) -> list[tuple[TagProtocol, PacketProtocol]]: """Execute all packets from a stream: compute, persist, and cache. - Internal method for orchestrators. The caller must guarantee that - the input stream's identity (content hash, schema) matches - ``self._input_stream``. No validation is performed. - - More efficient than calling ``execute_packet`` per-packet when - observer hooks aren't needed. - Args: input_stream: The input stream to process. + observer: Optional execution observer for hooks. Returns: Materialized list of (tag, output_packet) pairs, excluding None outputs. """ + if observer is not None: + observer.on_node_start(self) + + # Gather entry IDs and check cache + upstream_entries = [ + (tag, packet, self.compute_pipeline_entry_id(tag, packet)) + for tag, packet in input_stream.iter_packets() + ] + entry_ids = [eid for _, _, eid in upstream_entries] + cached = self.get_cached_results(entry_ids=entry_ids) + output: list[tuple[TagProtocol, PacketProtocol]] = [] - for tag, packet in input_stream.iter_packets(): - tag_out, result = self._process_packet_internal(tag, packet) - if result is not None: + for tag, packet, entry_id in upstream_entries: + if observer is not None: + observer.on_packet_start(self, tag, packet) + + if entry_id in cached: + tag_out, result = cached[entry_id] + if observer is not None: + observer.on_packet_end(self, tag, packet, result, cached=True) output.append((tag_out, result)) + else: + tag_out, result = self._process_packet_internal(tag, packet) + if observer is not None: + observer.on_packet_end(self, tag, packet, result, cached=False) + if result is not None: + output.append((tag_out, result)) + + if observer is not None: + observer.on_node_end(self) return output def _process_packet_internal( @@ -1141,27 +1162,35 @@ def as_table( async def async_execute( self, - inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + input_channel: 
ReadableChannel[tuple[TagProtocol, PacketProtocol]], output: WritableChannel[tuple[TagProtocol, PacketProtocol]], - pipeline_config: PipelineConfig | None = None, + *, + observer: "ExecutionObserver | None" = None, ) -> None: """Streaming async execution for FunctionNode. When a database is attached, uses two-phase execution: replay cached results first, then compute missing packets concurrently. Otherwise, routes each packet through ``async_process_packet`` directly. + + Args: + input_channel: Single readable channel of (tag, packet) pairs. + output: Writable channel for output (tag, packet) pairs. + observer: Optional execution observer for hooks. """ + # TODO(PLT-930): Restore concurrency limiting (semaphore) via node-level config. + # Currently all packets are processed sequentially in async_execute. try: - pipeline_config = pipeline_config or PipelineConfig() - # TODO: revisit this logic as use of accidental property is not desirable - node_config = getattr(self._function_pod, "node_config", NodeConfig()) - max_concurrency = resolve_concurrency(node_config, pipeline_config) + if observer is not None: + observer.on_node_start(self) if self._cached_function_pod is not None: - # Two-phase async execution with DB backing - # Phase 1: emit existing results from DB + # DB-backed async execution: + # Phase 1: build cache lookup from pipeline DB PIPELINE_ENTRY_ID_COL = "__pipeline_entry_id" - existing_entry_ids: set[str] = set() + cached_by_entry_id: dict[ + str, tuple[TagProtocol, PacketProtocol] + ] = {} taginfo = self._pipeline_database.get_all_records( self.pipeline_path, @@ -1184,12 +1213,9 @@ async def async_execute( ) if joined.num_rows > 0: tag_keys = self._input_stream.keys()[0] - existing_entry_ids = set( - cast( - list[str], - joined.column(PIPELINE_ENTRY_ID_COL).to_pylist(), - ) - ) + entry_ids_col = joined.column( + PIPELINE_ENTRY_ID_COL + ).to_pylist() drop_cols = [ c for c in joined.column_names @@ -1202,63 +1228,53 @@ async def async_execute( 
existing_stream = ArrowTableStream( data_table, tag_columns=tag_keys ) - for tag, packet in existing_stream.iter_packets(): - await output.send((tag, packet)) - - # Phase 2: process new packets concurrently - sem = ( - asyncio.Semaphore(max_concurrency) - if max_concurrency is not None - else None - ) + for eid, (tag_out, pkt_out) in zip( + entry_ids_col, existing_stream.iter_packets() + ): + cached_by_entry_id[eid] = (tag_out, pkt_out) - async def process_one_db( - tag: TagProtocol, packet: PacketProtocol - ) -> None: - try: + # Phase 2: drive output from input channel — cached or compute + async for tag, packet in input_channel: + entry_id = self.compute_pipeline_entry_id(tag, packet) + if entry_id in cached_by_entry_id: + tag_out, result_packet = cached_by_entry_id[entry_id] + if observer is not None: + observer.on_packet_start(self, tag, packet) + observer.on_packet_end( + self, tag, packet, result_packet, cached=True + ) + await output.send((tag_out, result_packet)) + else: + if observer is not None: + observer.on_packet_start(self, tag, packet) ( tag_out, result_packet, ) = await self._async_process_packet_internal(tag, packet) + if observer is not None: + observer.on_packet_end( + self, tag, packet, result_packet, cached=False + ) if result_packet is not None: await output.send((tag_out, result_packet)) - finally: - if sem is not None: - sem.release() - - async with asyncio.TaskGroup() as tg: - async for tag, packet in inputs[0]: - entry_id = self.compute_pipeline_entry_id(tag, packet) - if entry_id in existing_entry_ids: - continue - if sem is not None: - await sem.acquire() - tg.create_task(process_one_db(tag, packet)) else: # Simple async execution without DB - sem = ( - asyncio.Semaphore(max_concurrency) - if max_concurrency is not None - else None - ) + async for tag, packet in input_channel: + if observer is not None: + observer.on_packet_start(self, tag, packet) + ( + tag_out, + result_packet, + ) = await self._async_process_packet_internal(tag, 
packet) + if observer is not None: + observer.on_packet_end( + self, tag, packet, result_packet, cached=False + ) + if result_packet is not None: + await output.send((tag_out, result_packet)) - async def process_one(tag: TagProtocol, packet: PacketProtocol) -> None: - try: - ( - tag_out, - result_packet, - ) = await self._async_process_packet_internal(tag, packet) - if result_packet is not None: - await output.send((tag_out, result_packet)) - finally: - if sem is not None: - sem.release() - - async with asyncio.TaskGroup() as tg: - async for tag, packet in inputs[0]: - if sem is not None: - await sem.acquire() - tg.create_task(process_one(tag, packet)) + if observer is not None: + observer.on_node_end(self) finally: await output.close() diff --git a/src/orcapod/core/nodes/operator_node.py b/src/orcapod/core/nodes/operator_node.py index be54ad5c..b5145d60 100644 --- a/src/orcapod/core/nodes/operator_node.py +++ b/src/orcapod/core/nodes/operator_node.py @@ -30,6 +30,8 @@ if TYPE_CHECKING: import pyarrow as pa + + from orcapod.pipeline.observer import ExecutionObserver else: pa = LazyModule("pyarrow") @@ -432,19 +434,27 @@ def get_cached_output(self) -> "StreamProtocol | None": def execute( self, *input_streams: StreamProtocol, + observer: "ExecutionObserver | None" = None, ) -> list[tuple[TagProtocol, PacketProtocol]]: """Execute input streams: compute, persist, and cache. - Internal method for orchestrators. The caller must guarantee that - the input streams' identities (content hash, schema) match - ``self._input_streams``. No validation is performed. - Args: *input_streams: Input streams to execute. + observer: Optional execution observer for hooks. Returns: Materialized list of (tag, packet) pairs. 
""" + if observer is not None: + observer.on_node_start(self) + + # Check REPLAY cache first + cached_output = self.get_cached_output() + if cached_output is not None: + output = list(cached_output.iter_packets()) + if observer is not None: + observer.on_node_end(self) + return output # Compute result_stream = self._operator.process(*input_streams) @@ -470,6 +480,8 @@ def execute( ): self._store_output_stream(self._cached_output_stream) + if observer is not None: + observer.on_node_end(self) return output def _compute_and_store(self) -> None: @@ -628,6 +640,8 @@ async def async_execute( self, inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + *, + observer: "ExecutionObserver | None" = None, ) -> None: """Async execution with cache mode handling when DB is attached. @@ -638,16 +652,28 @@ async def async_execute( - REPLAY: emit from DB, close output. - OFF: delegate to operator, forward results. - LOG: delegate to operator, forward + collect results, then store in DB. + + Args: + inputs: Sequence of readable channels from upstream nodes. + output: Writable channel for output (tag, packet) pairs. + observer: Optional execution observer for hooks. 
""" if self._pipeline_database is None: # Simple delegation without DB + if observer is not None: + observer.on_node_start(self) hashes = [s.pipeline_hash() for s in self._input_streams] await self._operator.async_execute( inputs, output, input_pipeline_hashes=hashes ) + if observer is not None: + observer.on_node_end(self) return try: + if observer is not None: + observer.on_node_start(self) + if self._cache_mode == CacheMode.REPLAY: self._replay_from_cache() assert self._cached_output_stream is not None @@ -684,6 +710,9 @@ async def forward() -> None: self._store_output_stream(stream) self._update_modified_time() + + if observer is not None: + observer.on_node_end(self) finally: await output.close() diff --git a/src/orcapod/core/nodes/source_node.py b/src/orcapod/core/nodes/source_node.py index 4c2f6e9c..df8d47b1 100644 --- a/src/orcapod/core/nodes/source_node.py +++ b/src/orcapod/core/nodes/source_node.py @@ -2,11 +2,11 @@ from __future__ import annotations -from collections.abc import Iterator, Sequence +from collections.abc import Iterator from typing import TYPE_CHECKING, Any from orcapod import contexts -from orcapod.channels import ReadableChannel, WritableChannel +from orcapod.channels import WritableChannel from orcapod.config import Config, DEFAULT_CONFIG from orcapod.core.streams.base import StreamBase from orcapod.protocols import core_protocols as cp @@ -15,6 +15,8 @@ if TYPE_CHECKING: import pyarrow as pa + from orcapod.pipeline.observer import ExecutionObserver + class SourceNode(StreamBase): """Represents a root source stream in the computation graph.""" @@ -234,21 +236,56 @@ def iter_packets(self) -> Iterator[tuple[cp.TagProtocol, cp.PacketProtocol]]: return iter(self._cached_results) return self.stream.iter_packets() + def execute( + self, + *, + observer: "ExecutionObserver | None" = None, + ) -> list[tuple[cp.TagProtocol, cp.PacketProtocol]]: + """Execute this source: materialize packets and return. 
+ + Args: + observer: Optional execution observer for hooks. + + Returns: + List of (tag, packet) tuples. + """ + if self.stream is None: + raise RuntimeError( + "SourceNode in read-only mode has no stream data available" + ) + if observer is not None: + observer.on_node_start(self) + result = list(self.stream.iter_packets()) + self._cached_results = result + if observer is not None: + observer.on_node_end(self) + return result + def run(self) -> None: """No-op for source nodes — data is already available.""" async def async_execute( self, - inputs: Sequence[ReadableChannel[tuple[cp.TagProtocol, cp.PacketProtocol]]], output: WritableChannel[tuple[cp.TagProtocol, cp.PacketProtocol]], + *, + observer: "ExecutionObserver | None" = None, ) -> None: - """Push all (tag, packet) pairs from the wrapped stream to the output channel.""" + """Push all (tag, packet) pairs from the wrapped stream to the output channel. + + Args: + output: Channel to write results to. + observer: Optional execution observer for hooks. + """ if self.stream is None: raise RuntimeError( "SourceNode in read-only mode has no stream data available" ) try: + if observer is not None: + observer.on_node_start(self) for tag, packet in self.stream.iter_packets(): await output.send((tag, packet)) + if observer is not None: + observer.on_node_end(self) finally: await output.close() diff --git a/src/orcapod/pipeline/async_orchestrator.py b/src/orcapod/pipeline/async_orchestrator.py index 61833e1d..04a6e7d3 100644 --- a/src/orcapod/pipeline/async_orchestrator.py +++ b/src/orcapod/pipeline/async_orchestrator.py @@ -1,10 +1,8 @@ """Async pipeline orchestrator for push-based channel execution. -Walks a compiled ``Pipeline``'s persistent node graph and launches all -nodes concurrently via ``asyncio.TaskGroup``, wiring them together with -bounded channels. After execution, results are available in the -pipeline databases via the usual ``get_all_records()`` / ``as_source()`` -accessors on each persistent node. 
+Walks a compiled pipeline's node graph and launches all nodes concurrently +via ``asyncio.TaskGroup``, wiring them together with bounded channels. +Uses TypeGuard dispatch with tightened per-type async_execute signatures. """ from __future__ import annotations @@ -15,93 +13,100 @@ from typing import TYPE_CHECKING, Any from orcapod.channels import BroadcastChannel, Channel -from orcapod.types import PipelineConfig +from orcapod.pipeline.result import OrchestratorResult +from orcapod.protocols.node_protocols import ( + is_function_node, + is_operator_node, + is_source_node, +) if TYPE_CHECKING: import networkx as nx - from orcapod.pipeline.graph import Pipeline + from orcapod.pipeline.observer import ExecutionObserver + from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol logger = logging.getLogger(__name__) class AsyncPipelineOrchestrator: - """Execute a compiled ``Pipeline`` asynchronously using channels. + """Execute a compiled pipeline asynchronously using channels. - After ``Pipeline.compile()``, the orchestrator: + After compilation, the orchestrator: - 1. Walks ``Pipeline._node_graph`` (persistent nodes) in topological - order. + 1. Walks the node graph in topological order. 2. Creates bounded channels (or broadcast channels for fan-out) between connected nodes. 3. Launches every node's ``async_execute`` concurrently via - ``asyncio.TaskGroup``. + ``asyncio.TaskGroup``, using TypeGuard dispatch for per-type + signatures. - Results are written to the pipeline databases by the persistent - nodes themselves (``FunctionNode``, ``OperatorNode`` - in LOG mode, etc.). After ``run()`` returns, callers retrieve data - via ``pipeline.