diff --git a/.claude/settings.json b/.claude/settings.json new file mode 100644 index 00000000..0f243d7a --- /dev/null +++ b/.claude/settings.json @@ -0,0 +1,5 @@ +{ + "enabledPlugins": { + "superpowers@claude-plugins-official": true + } +} diff --git a/.github/workflows/run-objective-tests.yml b/.github/workflows/run-objective-tests.yml new file mode 100644 index 00000000..b245f6de --- /dev/null +++ b/.github/workflows/run-objective-tests.yml @@ -0,0 +1,40 @@ +name: Run Objective Tests + +on: + push: + branches: [main, dev] + pull_request: + branches: [main, dev] + workflow_dispatch: # Allows manual triggering + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + + - name: Install uv + uses: astral-sh/setup-uv@v5 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install system dependencies + run: sudo apt-get update && sudo apt-get install -y graphviz libgraphviz-dev + + - name: Install dependencies + run: uv sync --locked --all-extras --dev + + - name: Run objective tests + run: uv run pytest test-objective/ -v --cov=src --cov-report=term-missing --cov-report=xml + + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 2929f239..1b5e609b 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -2,9 +2,9 @@ name: Run Tests on: push: - branches: [main] + branches: [main, dev] pull_request: - branches: [main] + branches: [main, dev] workflow_dispatch: # Allows manual triggering jobs: @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.10", "3.11", "3.12"] + python-version: ["3.11", "3.12"] steps: - uses: actions/checkout@v4 @@ -25,6 +25,9 @@ jobs: with: python-version: ${{ matrix.python-version }} + - name: Install system dependencies + run: sudo apt-get update && sudo apt-get install -y graphviz libgraphviz-dev + - name: Install dependencies run: uv sync --locked --all-extras --dev diff --git a/.gitignore b/.gitignore index 1abf5b48..c7feac3a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +# NFS cache files +.nfs* # Ignore data files in notebooks folder notebooks/**/*.json notebooks/**/*.yaml @@ -13,7 +15,7 @@ notebooks/**/*.db # skip any computation results **/results -# Ignore common data types by default +# Ignore common data types by default *.csv *.parquet *.xls @@ -106,9 +108,6 @@ instance/ # Scrapy stuff: .scrapy -# Sphinx documentation -docs/_build/ - # PyBuilder .pybuilder/ target/ @@ -219,3 +218,9 @@ dj_*_conf.json # pixi environments .pixi/* !.pixi/config.toml + +# Superpowers (Claude Code skill artifacts) +superpowers/ + +# General JSON files at project root +/*.json diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..12724a53 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,21 @@ +repos: + - repo: https://github.com/tsvikas/sync-with-uv + rev: v0.4.0 # replace with the latest version + hooks: + - id: sync-with-uv + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.14.4 + hooks: + - id: ruff-format + types_or: [ python, pyi ] + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v6.0.0 + hooks: + - id: trailing-whitespace + types_or: [ python, pyi ] + - id: end-of-file-fixer + types_or: [ python, pyi ] + - id: check-yaml + - id: check-added-large-files + - id: check-merge-conflict diff --git a/.zed/rules b/.zed/rules new file mode 100644 index 00000000..dc83f5e7 --- /dev/null +++ b/.zed/rules @@ -0,0 +1,292 @@ +## Running commands + +Always run Python commands via `uv run`, e.g.: + + uv run pytest tests/ + uv run python -c "..." + +Never use `python`, `pytest`, or `python3` directly. + +## Updating agent instructions + +When adding or changing any instruction, update BOTH: +- CLAUDE.md (for Claude Code) +- .zed/rules (for Zed AI) + +## Design issues log + +DESIGN_ISSUES.md at the project root is the canonical log of known design problems, bugs, and +code quality issues. + +When fixing a bug or addressing a design problem: +1. Check DESIGN_ISSUES.md first — if a matching issue exists, update its status to + "in progress" while working and "resolved" once done, adding a brief Fix: note. +2. If no matching issue exists, ask the user whether it should be added before proceeding. + If yes, add it (status "open" or "in progress" as appropriate). + +When discovering a new issue that won't be fixed immediately, ask the user whether it should be +logged in DESIGN_ISSUES.md before adding it. + +## Superpowers artifacts + +Place all superpowers-related artifacts (design specs, plans, etc.) in the superpowers/ +directory at the project root — NOT under docs/. The docs/ directory is reserved for +actual library documentation. + +- Specs go in superpowers/specs/ + +## Backward compatibility + +This is a greenfield project pre-v0.1.0. Do NOT add backward-compatibility shims, +re-exports, aliases, or deprecation wrappers when making design or implementation changes. +Just change the code and update all references directly. + +## No sys.modules hacks + +Never manipulate sys.modules directly (e.g. sys.modules.setdefault). If a subpackage +import path doesn't work, create a proper re-export package with an __init__.py instead. + +## Docstrings + +Use Google style (https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings) +Python docstrings everywhere. + +## Linear issue tracking + +All work must be linked to a Linear issue. Before starting any feature, bug fix, or +refactor: + +1. Check for an existing issue — search Linear for a corresponding issue. +2. If none exists — ask the developer whether to create one. Do not proceed without + either a linked issue or explicit approval to skip. +3. When a new issue is discovered during development (bug, design problem, deferred + work), create a corresponding Linear issue using the template below. + +When creating Linear issues, always use this template for the description: + + ## Overview + What is this project about? Describe the problem space and the high-level approach. + + ## Goals & Success Criteria + * Specific, measurable outcomes. + + ## Scope & Boundaries + (Optional — remove if not needed.) + In scope: + * ... + Out of scope: + * ... + + ## Dependencies & Risks + (Optional — remove if none.) + * ... + + ## Resources & References + (Optional — remove if none.) + * ... + + ## Milestones + (Optional — only for projects longer than ~4 weeks. Remove for shorter projects.) + * ... + +Remove any optional sections that don't apply rather than leaving them empty. + +### Branches and PRs + +When working on a feature, create and checkout a git branch using the gitBranchName +returned by the primary Linear issue (e.g. eywalker/plt-911-add-documentation-for-orcapod-python). + +If a feature branch / PR corresponds to multiple Linear issues, list all of them in the +PR description body so that Linear's GitHub integration auto-tracks the PR against each +issue. Use the format "Fixes PLT-123" or "Closes PLT-123" (GitHub magic words) for issues +that the PR fully resolves, and simply mention "PLT-456" for issues that are related but +not fully resolved by the PR. + +## Responding to PR reviews + +When asked to respond to PR reviewer comments: + +1. **Fetch and present** — Read all review comments, then present a response plan as a table: + each comment, its severity, whether to fix or explain, and the proposed action. +2. **Wait for approval** — Let the user approve the plan before making changes. +3. **Fix, then reply** — Make all fixes in a single commit, then post replies to each + reviewer comment explaining what was done (or why it was declined). + +Never make fixes silently or skip the plan step. + +## Git commits + +Always use Conventional Commits style (https://www.conventionalcommits.org/): + + (): + +Common types: feat, fix, refactor, test, docs, chore, perf, ci. + +Examples: +- feat(schema): add optional_fields to Schema +- fix(packet_function): reject variadic parameters at construction +- test(function_pod): add schema validation tests +- refactor(schema_utils): use Schema.optional_fields directly + +--- + +## Project layout + +src/orcapod/ + types.py — Schema, ColumnConfig, ContentHash + system_constants.py — Column prefixes and separators + errors.py — InputValidationError, DuplicateTagError, FieldNotResolvableError + config.py — Config dataclass + contexts/ — DataContext (semantic_hasher, arrow_hasher, type_converter) + protocols/ + hashing_protocols.py — PipelineElementProtocol, ContentIdentifiableProtocol + core_protocols/ — StreamProtocol, PodProtocol, SourceProtocol, + PacketFunctionProtocol, DatagramProtocol, TagProtocol, + PacketProtocol, TrackerProtocol + core/ + base.py — ContentIdentifiableBase, PipelineElementBase, TraceableBase + static_output_pod.py — StaticOutputPod (operator base), DynamicPodStream + function_pod.py — FunctionPod, FunctionPodStream, FunctionNode + packet_function.py — PacketFunctionBase, PythonPacketFunction, CachedPacketFunction + operator_node.py — OperatorNode (DB-backed operator execution) + tracker.py — Invocation tracking + datagrams/ + datagram.py — Datagram (unified dict/Arrow backing, lazy conversion) + tag_packet.py — Tag (+ system tags), Packet (+ source info) + sources/ + base.py — RootSource (abstract, no upstream) + arrow_table_source.py — Core source — all other sources delegate to it + derived_source.py — DerivedSource (backed by FunctionNode/OperatorNode DB) + csv_source.py, dict_source.py, list_source.py, + data_frame_source.py, delta_table_source.py — Delegating wrappers + source_registry.py — SourceRegistry for provenance resolution + streams/ + base.py — StreamBase (abstract) + arrow_table_stream.py — ArrowTableStream (concrete, immutable) + operators/ + base.py — UnaryOperator, BinaryOperator, NonZeroInputOperator + join.py — Join (N-ary inner join, commutative) + merge_join.py — MergeJoin (binary, colliding cols → sorted list[T]) + semijoin.py — SemiJoin (binary, non-commutative) + batch.py — Batch (group rows, types become list[T]) + column_selection.py — Select/Drop Tag/Packet columns + mappers.py — MapTags, MapPackets (rename columns) + filters.py — PolarsFilter + hashing/ + semantic_hashing/ — BaseSemanticHasher, type handlers + semantic_types/ — Type conversion (Python ↔ Arrow) + databases/ — ArrowDatabaseProtocol implementations (Delta Lake, in-memory) + utils/ + arrow_data_utils.py — System tag manipulation, source info, column helpers + arrow_utils.py — Arrow table utilities + schema_utils.py — Schema extraction, union, intersection, compatibility + lazy_module.py — LazyModule for deferred heavy imports + +tests/ + test_core/ + datagrams/ — Lazy conversion, dict/Arrow round-trip + sources/ — Source construction, protocol conformance, DerivedSource + streams/ — ArrowTableStream behavior + function_pod/ — FunctionPod, FunctionNode, pipeline hash integration + operators/ — All operators, OperatorNode, MergeJoin + packet_function/ — PacketFunction, CachedPacketFunction + test_hashing/ — Semantic hasher, hash stability + test_databases/ — Delta Lake, in-memory, no-op databases + test_semantic_types/ — Type converter tests + +--- + +## Architecture overview + +See orcapod-design.md at the project root for the full design specification. + +### Core data flow + + RootSource → ArrowTableStream → [Operator / FunctionPod] → ArrowTableStream → ... + +Every stream is an immutable sequence of (Tag, Packet) pairs backed by a PyArrow Table. +Tag columns are join keys and metadata; packet columns are the data payload. + +### Core abstractions + +Datagram (core/datagrams/datagram.py) — immutable data container with lazy dict ↔ Arrow +conversion. Two specializations: +- Tag — metadata columns + hidden system tag columns for provenance tracking +- Packet — data columns + per-column source info provenance tokens + +Stream (core/streams/arrow_table_stream.py) — immutable (Tag, Packet) sequence. +Key methods: output_schema(), keys(), iter_packets(), as_table(). + +Source (core/sources/) — produces a stream from external data. ArrowTableSource is the core +implementation; CSV/Delta/DataFrame/Dict/List sources all delegate to it internally. Each +source adds source-info columns and a system tag column. DerivedSource wraps a +FunctionNode/OperatorNode's DB records as a new source. + +Function Pod (core/function_pod.py) — wraps a PacketFunction that transforms individual +packets. Never inspects tags. Two execution models: +- FunctionPod → FunctionPodStream: lazy, in-memory +- FunctionNode: DB-backed, two-phase (yield cached results first, then compute missing) + +Operator (core/operators/) — structural pod transforming streams without synthesizing new +packet values. All subclass StaticOutputPod: +- UnaryOperator — 1 input (Batch, Select/Drop columns, Map, Filter) +- BinaryOperator — 2 inputs (MergeJoin, SemiJoin) +- NonZeroInputOperator — 1+ inputs (Join) + +OperatorNode (core/operator_node.py) — DB-backed operator execution, analogous to +FunctionNode. + +### Strict operator / function pod boundary + +Operators: inspect tags (never packet content), can rename columns, cannot synthesize values. +Function Pods: inspect packet content (never tags), synthesize new values, cached by content. + +### Two identity chains + +Every pipeline element has two parallel hashes: + +1. content_hash() — data-inclusive. Changes when data changes. Used for deduplication. +2. pipeline_hash() — schema + topology only. Ignores data content. Used for DB path scoping + so that different sources with identical schemas share database tables. + +Base case: RootSource.pipeline_identity_structure() returns (tag_schema, packet_schema). +Each downstream node's pipeline hash commits to its own identity plus upstream pipeline +hashes, forming a Merkle chain. + +### Column naming conventions + + __ prefix — System metadata (ColumnConfig meta) + _source_ prefix — Source info provenance (ColumnConfig source) + _tag:: prefix — System tag (ColumnConfig system_tags) + _context_key — Data context (ColumnConfig context) + +Prefixes are computed from SystemConstant in system_constants.py. + +### System tag evolution rules + +1. Name-preserving — single-stream ops. Column name/value pass through unchanged. +2. Name-extending — multi-input ops. System tag column name gets + ::{pipeline_hash}:{canonical_position} appended. Commutative operators sort by + pipeline_hash and sort system tag values per row. +3. Type-evolving — aggregation ops. Column type changes from str to list[str]. + +### Key patterns + +- LazyModule("pyarrow") — deferred import for heavy deps. Used in + if TYPE_CHECKING: / else: blocks. +- Argument symmetry — operators return frozenset (commutative) or tuple (ordered). +- StaticOutputPod.process() → DynamicPodStream — wraps static_process() with staleness + detection and automatic recomputation. +- Source delegation — CSVSource, DictSource, etc. create an internal ArrowTableSource. + +### Important implementation details + +- ArrowTableSource raises ValueError if any tag_columns are not in the table. +- ArrowTableStream requires at least one packet column; raises ValueError otherwise. +- FunctionNode Phase 1 returns ALL records in the shared pipeline_path DB table. + Phase 2 skips inputs whose hash is already in the DB. +- Empty data → ArrowTableSource raises ValueError("Table is empty"). +- DerivedSource before run() → raises ValueError (no computed records). +- Join requires non-overlapping packet columns; raises InputValidationError on collision. +- MergeJoin requires colliding columns to have identical types; merges into sorted list[T]. +- Operators predict output schema (including system tag names) without computation. diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..722c6d2a --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,332 @@ +# Claude Code instructions for orcapod-python + +## Running commands + +Always run Python commands via `uv run`, e.g.: + +``` +uv run pytest tests/ +uv run python -c "..." +``` + +Never use `python`, `pytest`, or `python3` directly. + +## Updating agent instructions + +When adding or changing any instruction, update BOTH: +- `CLAUDE.md` (for Claude Code) +- `.zed/rules` (for Zed AI) + +## Design issues log + +`DESIGN_ISSUES.md` at the project root is the canonical log of known design problems, bugs, and +code quality issues. + +When fixing a bug or addressing a design problem: +1. Check `DESIGN_ISSUES.md` first — if a matching issue exists, update its status to + `in progress` while working and `resolved` once done, adding a brief **Fix:** note. +2. If no matching issue exists, ask the user whether it should be added before proceeding. + If yes, add it (status `open` or `in progress` as appropriate). + +When discovering a new issue that won't be fixed immediately, ask the user whether it should be +logged in `DESIGN_ISSUES.md` before adding it. + +## Superpowers artifacts + +Place all superpowers-related artifacts (design specs, plans, etc.) in the `superpowers/` +directory at the project root — **not** under `docs/`. The `docs/` directory is reserved for +actual library documentation. + +- Specs go in `superpowers/specs/` + +## Backward compatibility + +This is a greenfield project pre-v0.1.0. Do **not** add backward-compatibility shims, +re-exports, aliases, or deprecation wrappers when making design or implementation changes. +Just change the code and update all references directly. + +## No `sys.modules` hacks + +Never manipulate `sys.modules` directly (e.g. `sys.modules.setdefault`). If a subpackage +import path doesn't work, create a proper re-export package with an `__init__.py` instead. + +## Docstrings + +Use [Google style](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings) +Python docstrings everywhere. + +## Linear issue tracking + +All work must be linked to a Linear issue. Before starting any feature, bug fix, or +refactor: + +1. **Check for an existing issue** — search Linear for a corresponding issue. +2. **If none exists** — ask the developer whether to create one. Do not proceed without + either a linked issue or explicit approval to skip. +3. **When a new issue is discovered** during development (bug, design problem, deferred + work), create a corresponding Linear issue using the template below. + +When creating Linear issues, always use this template for the description: + +``` +## Overview +What is this project about? Describe the problem space and the high-level approach. + +## Goals & Success Criteria +* Specific, measurable outcomes. + +## Scope & Boundaries +(Optional — remove if not needed.) +In scope: +* ... +Out of scope: +* ... + +## Dependencies & Risks +(Optional — remove if none.) +* ... + +## Resources & References +(Optional — remove if none.) +* ... + +## Milestones +(Optional — only for projects longer than ~4 weeks. Remove for shorter projects.) +* ... +``` + +Remove any optional sections that don't apply rather than leaving them empty. + +### Branches and PRs + +When working on a feature, create and checkout a git branch using the `gitBranchName` +returned by the primary Linear issue (e.g. `eywalker/plt-911-add-documentation-for-orcapod-python`). + +If a feature branch / PR corresponds to multiple Linear issues, list all of them in the +PR description body so that Linear's GitHub integration auto-tracks the PR against each +issue. Use the format `Fixes PLT-123` or `Closes PLT-123` (GitHub magic words) for issues +that the PR fully resolves, and simply mention `PLT-456` for issues that are related but +not fully resolved by the PR. + +## Responding to PR reviews + +When asked to respond to PR reviewer comments: + +1. **Fetch and present** — Read all review comments, then present a response plan as a table: + each comment, its severity, whether to fix or explain, and the proposed action. +2. **Wait for approval** — Let the user approve the plan before making changes. +3. **Fix, then reply** — Make all fixes in a single commit, then post replies to each + reviewer comment explaining what was done (or why it was declined). + +Never make fixes silently or skip the plan step. + +## Git commits + +Always use [Conventional Commits](https://www.conventionalcommits.org/) style: + +``` +(): +``` + +Common types: `feat`, `fix`, `refactor`, `test`, `docs`, `chore`, `perf`, `ci`. + +Examples: +- `feat(schema): add optional_fields to Schema` +- `fix(packet_function): reject variadic parameters at construction` +- `test(function_pod): add schema validation tests` +- `refactor(schema_utils): use Schema.optional_fields directly` + +--- + +## Project layout + +``` +src/orcapod/ +├── types.py # Schema, ColumnConfig, ContentHash +├── system_constants.py # Column prefixes and separators +├── errors.py # InputValidationError, DuplicateTagError, FieldNotResolvableError +├── config.py # Config dataclass +├── contexts/ # DataContext (semantic_hasher, arrow_hasher, type_converter) +├── protocols/ +│ ├── hashing_protocols.py # PipelineElementProtocol, ContentIdentifiableProtocol +│ └── core_protocols/ # StreamProtocol, PodProtocol, SourceProtocol, +│ # PacketFunctionProtocol, DatagramProtocol, TagProtocol, +│ # PacketProtocol, TrackerProtocol +├── core/ +│ ├── base.py # ContentIdentifiableBase, PipelineElementBase, TraceableBase +│ ├── static_output_pod.py # StaticOutputPod (operator base), DynamicPodStream +│ ├── function_pod.py # FunctionPod, FunctionPodStream, FunctionNode +│ ├── packet_function.py # PacketFunctionBase, PythonPacketFunction, CachedPacketFunction +│ ├── operator_node.py # OperatorNode (DB-backed operator execution) +│ ├── tracker.py # Invocation tracking +│ ├── datagrams/ +│ │ ├── datagram.py # Datagram (unified dict/Arrow backing, lazy conversion) +│ │ └── tag_packet.py # Tag (+ system tags), Packet (+ source info) +│ ├── sources/ +│ │ ├── base.py # RootSource (abstract, no upstream) +│ │ ├── arrow_table_source.py # Core source — all other sources delegate to it +│ │ ├── derived_source.py # DerivedSource (backed by FunctionNode/OperatorNode DB) +│ │ ├── csv_source.py, dict_source.py, list_source.py, +│ │ │ data_frame_source.py, delta_table_source.py # Delegating wrappers +│ │ └── source_registry.py # SourceRegistry for provenance resolution +│ ├── streams/ +│ │ ├── base.py # StreamBase (abstract) +│ │ └── arrow_table_stream.py # ArrowTableStream (concrete, immutable) +│ └── operators/ +│ ├── base.py # UnaryOperator, BinaryOperator, NonZeroInputOperator +│ ├── join.py # Join (N-ary inner join, commutative) +│ ├── merge_join.py # MergeJoin (binary, colliding cols → sorted list[T]) +│ ├── semijoin.py # SemiJoin (binary, non-commutative) +│ ├── batch.py # Batch (group rows, types become list[T]) +│ ├── column_selection.py # Select/Drop Tag/Packet columns +│ ├── mappers.py # MapTags, MapPackets (rename columns) +│ └── filters.py # PolarsFilter +├── hashing/ +│ └── semantic_hashing/ # BaseSemanticHasher, type handlers +├── semantic_types/ # Type conversion (Python ↔ Arrow) +├── databases/ # ArrowDatabaseProtocol implementations (Delta Lake, in-memory) +└── utils/ + ├── arrow_data_utils.py # System tag manipulation, source info, column helpers + ├── arrow_utils.py # Arrow table utilities + ├── schema_utils.py # Schema extraction, union, intersection, compatibility + └── lazy_module.py # LazyModule for deferred heavy imports + +tests/ +├── test_core/ +│ ├── datagrams/ # Lazy conversion, dict/Arrow round-trip +│ ├── sources/ # Source construction, protocol conformance, DerivedSource +│ ├── streams/ # ArrowTableStream behavior +│ ├── function_pod/ # FunctionPod, FunctionNode, pipeline hash integration +│ ├── operators/ # All operators, OperatorNode, MergeJoin +│ └── packet_function/ # PacketFunction, CachedPacketFunction +├── test_hashing/ # Semantic hasher, hash stability +├── test_databases/ # Delta Lake, in-memory, no-op databases +└── test_semantic_types/ # Type converter tests +``` + +--- + +## Architecture overview + +See `orcapod-design.md` at the project root for the full design specification. + +### Core data flow + +``` +RootSource → ArrowTableStream → [Operator / FunctionPod] → ArrowTableStream → ... +``` + +Every stream is an immutable sequence of (Tag, Packet) pairs backed by a PyArrow Table. +Tag columns are join keys and metadata; packet columns are the data payload. + +### Core abstractions + +**Datagram** (`core/datagrams/datagram.py`) — immutable data container with lazy dict ↔ Arrow +conversion. Two specializations: +- **Tag** — metadata columns + hidden system tag columns for provenance tracking +- **Packet** — data columns + per-column source info provenance tokens + +**Stream** (`core/streams/arrow_table_stream.py`) — immutable (Tag, Packet) sequence. +Key methods: `output_schema()`, `keys()`, `iter_packets()`, `as_table()`. + +**Source** (`core/sources/`) — produces a stream from external data. `ArrowTableSource` is the +core implementation; CSV/Delta/DataFrame/Dict/List sources all delegate to it internally. Each +source adds source-info columns and a system tag column. `DerivedSource` wraps a +FunctionNode/OperatorNode's DB records as a new source. + +**Function Pod** (`core/function_pod.py`) — wraps a `PacketFunction` that transforms individual +packets. Never inspects tags. Two execution models: +- `FunctionPod` → `FunctionPodStream`: lazy, in-memory +- `FunctionNode`: DB-backed, two-phase (yield cached results first, then compute missing) + +**Operator** (`core/operators/`) — structural pod transforming streams without synthesizing new +packet values. All subclass `StaticOutputPod`: +- `UnaryOperator` — 1 input (Batch, Select/Drop columns, Map, Filter) +- `BinaryOperator` — 2 inputs (MergeJoin, SemiJoin) +- `NonZeroInputOperator` — 1+ inputs (Join) + +**OperatorNode** (`core/operator_node.py`) — DB-backed operator execution, analogous to +FunctionNode. + +### Strict operator / function pod boundary + +| | Operator | Function Pod | +|---|---|---| +| Inspects packet content | Never | Yes | +| Inspects / uses tags | Yes | No | +| Can rename columns | Yes | No | +| Synthesizes new values | No | Yes | +| Stream arity | Configurable | Single in, single out | + +### Two identity chains + +Every pipeline element has two parallel hashes: + +1. **`content_hash()`** — data-inclusive. Changes when data changes. Used for deduplication + and memoization. +2. **`pipeline_hash()`** — schema + topology only. Ignores data content. Used for DB path + scoping so that different sources with identical schemas share database tables. + +Base case: `RootSource.pipeline_identity_structure()` returns `(tag_schema, packet_schema)`. +Each downstream node's pipeline hash commits to its own identity plus the pipeline hashes of +its upstreams, forming a Merkle chain. + +The pipeline hash uses a **resolver pattern** — `PipelineElementProtocol` objects route through +`pipeline_hash()`, other `ContentIdentifiable` objects route through `content_hash()`. + +### Column naming conventions + +| Prefix | Meaning | Example | Controlled by | +|--------|---------|---------|---------------| +| `__` | System metadata | `__packet_id`, `__pod_version` | `ColumnConfig(meta=True)` | +| `_source_` | Source info provenance | `_source_age` | `ColumnConfig(source=True)` | +| `_tag::` | System tag | `_tag::source:abc123` | `ColumnConfig(system_tags=True)` | +| `_context_key` | Data context | `_context_key` | `ColumnConfig(context=True)` | + +Prefixes are computed from `SystemConstant` in `system_constants.py`. The `constants` singleton +(with no global prefix) is used throughout. + +### System tag evolution rules + +1. **Name-preserving** — single-stream ops (filter, select, map). Column name and value pass + through unchanged. +2. **Name-extending** — multi-input ops (join, merge join). Each input's system tag column + name gets `::{pipeline_hash}:{canonical_position}` appended. Commutative operators + canonically order inputs by `pipeline_hash` and sort system tag values per row. +3. **Type-evolving** — aggregation ops (batch). Column type changes from `str` to `list[str]`. + +### Schema types and ColumnConfig + +`Schema` (`types.py`) — immutable `Mapping[str, DataType]` with `optional_fields` support. +`output_schema()` always returns `(tag_schema, packet_schema)` as a tuple of Schemas. + +`ColumnConfig` (`types.py`) — frozen dataclass controlling which column groups are included. +Fields: `meta`, `context`, `source`, `system_tags`, `content_hash`, `sort_by_tags`. +Normalize via `ColumnConfig.handle_config(columns, all_info)` at the top of `output_schema()` +and `as_table()` methods. `all_info=True` sets everything to True. + +### Key patterns + +- **`LazyModule("pyarrow")`** — deferred import for heavy deps (pyarrow, polars). Used in + `if TYPE_CHECKING:` / `else:` blocks at module level. +- **Argument symmetry** — each operator declares `argument_symmetry(streams)` returning + `frozenset` (commutative) or `tuple` (ordered). Determines how upstream hashes combine. +- **`StaticOutputPod.process()` → `DynamicPodStream`** — wraps `static_process()` output + with timestamp-based staleness detection and automatic recomputation. +- **Source delegation** — CSVSource, DictSource, etc. all create an internal + `ArrowTableSource` and delegate every method to it. + +### Important implementation details + +- `ArrowTableSource.__init__` raises `ValueError` if any `tag_columns` are not in the table. +- `ArrowTableStream` requires at least one packet column; raises `ValueError` otherwise. +- `FunctionNode.iter_packets()` Phase 1 returns ALL records in the shared `pipeline_path` + DB table (not filtered to current inputs). Phase 2 skips inputs whose hash is already + in the DB. +- Empty data → `ArrowTableSource` raises `ValueError("Table is empty")`. +- `DerivedSource` before `run()` → raises `ValueError` (no computed records). +- Join requires non-overlapping packet columns; raises `InputValidationError` on collision. +- MergeJoin requires colliding packet columns to have identical types; merges into sorted + `list[T]` with source columns reordered to match. +- Operators predict their output schema (including system tag column names) without + performing the actual computation. diff --git a/DESIGN_ISSUES.md b/DESIGN_ISSUES.md new file mode 100644 index 00000000..9dcab19b --- /dev/null +++ b/DESIGN_ISSUES.md @@ -0,0 +1,916 @@ +# Design & Implementation Issues + +A running log of identified design problems, bugs, and code quality issues. +Each item has a status: `open`, `in progress`, or `resolved`. + +**Severity guide:** +- **critical** — Correctness bugs, silent data loss, or security issues. +- **high** — Broken or incomplete features that affect users or downstream consumers. +- **medium** — Performance, error-handling, or code-quality issues worth fixing in the + normal course of development. +- **low** — Nice-to-haves, cosmetic improvements, or speculative refactors. +- **trivial** — Typos, dead comments, purely cosmetic. + +--- + +## `src/orcapod/core/base.py` + +### B1 — `PipelineElementBase` should be merged into `TraceableBase` +**Status:** resolved +**Severity:** medium + +`TraceableBase` and `PipelineElementBase` co-occur in every active computation-node class +(`StreamBase`, `PacketFunctionBase`, `_FunctionPodBase`). The two current exceptions are design +gaps rather than intentional choices: + +- `StaticOutputPod(TraceableBase)` — should implement `PipelineElementProtocol`; its absence + forced `DynamicPodStream.pipeline_identity_structure()` to include an `isinstance` check as + a workaround. +- `Invocation(TraceableBase)` — legacy tracking mechanism, planned for revision. + +Note: merging into `TraceableBase` is correct at the *computation-node* level. +`ContentIdentifiableBase` (which `TraceableBase` builds on) should **not** absorb +`PipelineElementBase` — data datagrams (`Tag`, `Packet`) are legitimately content-identifiable +without being pipeline elements. + +**Fix:** Added `PipelineElementBase` to `TraceableBase`'s bases. Added +`pipeline_identity_structure()` to `StaticOutputPod`. Removed redundant explicit +`PipelineElementBase` from `StreamBase`, `ArrowTableStream`, `PacketFunctionBase`, +`_FunctionPodBase`, `FunctionPodStream`, `FunctionNode`, `OperatorNode`, and +`DynamicPodStream` declarations. + +--- + +### B2 — Mutable `data_context` setter may invalidate cached state +**Status:** open +**Severity:** medium + +`DataContextMixin.data_context` (line ~92) has a property setter that allows runtime context +changes. If a stream or pod has already cached schemas or hashes derived from the previous +context, those caches silently become stale. + +Options: (1) remove the setter and make context immutable after construction, or (2) add cache +invalidation on context change and document when changing context is safe. + +--- + +## `src/orcapod/core/packet_function.py` + +### P1 — `parse_function_outputs` is dead code +**Status:** resolved +**Severity:** medium +`parse_function_outputs` is a module-level function with a `self` parameter, suggesting it was +originally a method. It is never called. `PythonPacketFunction.call` duplicates its logic verbatim. +Should be deleted or wired up as a method on `PacketFunctionBase` and used from `call`. + +**Fix:** Converted to a proper standalone function `parse_function_outputs(output_keys, values)`. +Replaced the duplicated unpacking block in `PythonPacketFunction.call` with a call to it. +Updated tests accordingly. + +--- + +### P2 — `CachedPacketFunction.call` silently drops the `RESULT_COMPUTED_FLAG` +**Status:** resolved +**Severity:** high +On a cache miss, the flag is set but the result is discarded: +```python +output_packet.with_meta_columns(**{self.RESULT_COMPUTED_FLAG: True}) # return value ignored +``` +If `with_meta_columns` returns a new packet (immutable update), the flag is never actually +attached. + +**Fix:** Assigned the return value: `output_packet = output_packet.with_meta_columns(...)`. +Added tests verifying the flag is `True` on cache miss and `False` on cache hit. + +--- + +### P3 — `PacketFunctionWrapper.__init__` passes no `version` to `PacketFunctionBase` +**Status:** open +**Severity:** medium +`PacketFunctionBase.__init__` requires a `version` string and parses it into +`_major_version`/`_minor_version`. `PacketFunctionWrapper` calls `super().__init__(**kwargs)` +without a `version`, so it either crashes (no version in kwargs) or silently defaults to `"v0.0"`. +Those parsed fields are then shadowed by the delegating properties, making them dead state. +Options: pass the inner function's version through, or avoid calling the base version-parsing +logic entirely. + +--- + +### P4 — `PythonPacketFunction` computes the output schema hash twice +**Status:** open +**Severity:** low +`__init__` stores `self._output_schema_hash` (line ~289). `PacketFunctionBase` also lazily +caches `self._output_packet_schema_hash` (different attribute name) via +`output_packet_schema_hash`. Two fields holding the same value. One is redundant. + +--- + +### P5 — Large dead commented-out block in `get_all_cached_outputs` +**Status:** resolved +**Severity:** low +The block commenting out `pod_id_columns` removal is leftover from an old design. It makes it +ambiguous whether system columns are actually filtered. + +**Fix:** Deleted the commented-out block. + +--- + +### P6 — Cache-matching policy (`match_tier`) accepted but never used +**Status:** open +**Severity:** high + +`CachedPacketFunction.get_cached_output_for_packet()` (line ~547) accepts a `match_tier` +parameter that is documented in the interface but completely ignored. Cache lookups always use +exact matching. Two inline TODOs mark this: +- `# TODO: add match based on match_tier if specified` +- `# TODO: implement matching policy/strategy` + +This means any caller passing a non-default `match_tier` silently gets exact-match behavior, +which could lead to unnecessary cache misses or incorrect assumptions about cache hit semantics. + +Requires: design a matching strategy interface; implement at least exact and fuzzy tiers. + +--- + +### P7 — `PythonPacketFunction.__init__` unconditionally extracts git info +**Status:** open +**Severity:** medium + +`PythonPacketFunction.__init__()` (line ~324) always calls `get_git_info()`, which runs git +subprocess commands. This fails or significantly slows initialization in non-git environments +(CI containers, notebooks, deployed services). + +`# TODO: turn this into optional addition` + +Fix: add an `include_git_info=True` parameter; skip extraction when `False`. + +--- + +## `src/orcapod/core/function_pod.py` + +### F1 — `_FunctionPodBase.process` is `@abstractmethod` with unreachable body code +**Status:** resolved +**Severity:** high +The method is decorated `@abstractmethod` but has real logic after the `...` (handle_input_streams, +schema validation, tracker recording, FunctionPodStream construction). Since Python never executes +the body of an abstract method via normal dispatch, this code is unreachable. `FunctionPod` +then duplicates this logic verbatim. + +**Fix:** Removed the unreachable body code from `_FunctionPodBase.process()`, keeping it as +a pure abstract method with only `...`. `FunctionPod.process()` retains its own concrete +implementation. + +--- + +### F2 — Typo in `TrackedPacketFunctionPod` docstring +**Status:** open +**Severity:** trivial +`"A think wrapper"` should be `"A thin wrapper"`. + +--- + +### F3 — Dual URI computation paths in the class hierarchy +**Status:** open +**Severity:** low +`TrackedPacketFunctionPod.uri` assembles the URI from `self.packet_function.*` with its own lazy +schema-hash cache. `WrappedFunctionPod.uri` simply delegates to `self._function_pod.uri`. These +should agree (and do, after the `packet_function` fix), but having two independent implementations +makes future changes fragile. + +--- + +### F4 — `FunctionPodNode` is not a subclass of `TrackedPacketFunctionPod` +**Status:** open +**Severity:** medium +`FunctionPodNode` reimplements `process_packet`, `process`, `__call__`, `output_schema`, +`validate_inputs`, and `argument_symmetry` from scratch rather than inheriting from +`SimpleFunctionPod`/`TrackedPacketFunctionPod` and overriding the parts that differ +(fixed input stream, pipeline record writing). The result is a large amount of structural +duplication that diverges silently over time. + +--- + +### F5 — `FunctionPodStream` and `FunctionPodNodeStream` are near-identical copy-pastes +**Status:** open +**Severity:** medium +`iter_packets`, `as_table` (including content_hash and sort_by_tags logic), `keys`, +`output_schema`, `source`, and `upstreams` are duplicated almost line-for-line. The only +behavioural differences are: +- `FunctionPodNodeStream` has `refresh_cache()` +- `FunctionPodNodeStream.output_schema` reads from `_fp_node._cached_packet_function` directly + +A shared base stream class would eliminate the duplication. + +--- + +### F6 — `WrappedFunctionPod.process` makes the wrapper transparent to observability +**Status:** open +**Severity:** medium +`process` simply calls `self._function_pod.process(...)`, so the returned stream's `source` is +the *inner* pod, not the `WrappedFunctionPod`. Anything that inspects `stream.source` (e.g. +tracking, lineage) will see the inner pod and be unaware of the wrapper. Whether this is +intentional should be documented; if not, `process` needs to construct a new stream whose source +is `self`. + +--- + +### F7 — TOCTOU race in `FunctionPodNode.add_pipeline_record` +**Status:** open +**Severity:** medium +The method checks for an existing record with `get_record_by_id` and skips insertion if found. +But it then calls `add_record(..., skip_duplicates=False)`, which will raise on a duplicate. A +race between the lookup and the insert (e.g. two concurrent processes handling the same tag+packet) +would cause a crash instead of a graceful skip. Should use `skip_duplicates=True` for consistency +with the intent. + +--- + +### F8 — `CallableWithPod` protocol placement breaks logical grouping +**Status:** open +**Severity:** low +`CallableWithPod` is defined between `FunctionPodStream` and `function_pod`, breaking the natural +grouping. It should be co-located with `function_pod` or moved to the protocols module. + +--- + +### F11 — Schema validation raises `ValueError` instead of custom exception +**Status:** open +**Severity:** high + +`_validate_input_schema()` (line ~162) raises a generic `ValueError` when the packet schema +is incompatible: +```python +# TODO: use custom exception type for better error handling +``` + +The codebase already has `InputValidationError` (in `errors.py`) which is the correct exception +for this case. Using `ValueError` means callers cannot distinguish schema incompatibility from +other value errors without string-matching the message. + +Fix: change `ValueError` to `InputValidationError`. + +--- + +### F12 — System tag columns excluded from cache entry ID +**Status:** open +**Severity:** high + +`FunctionPodNode.record_packet_for_cache()` (line ~1077) builds a tag table for entry-ID +computation but excludes system tag columns: +```python +# TODO: add system tag columns +``` + +Two packets with identical user tags but different provenance (arriving from different +pipeline branches, thus having different system tags) produce the same cache key. This can +cause cache collisions where a result computed for one pipeline branch is returned for +another. + +Fix: include system tag columns in the `tag_with_hash` table before computing the entry ID hash. + +--- + +### F13 — `_FunctionPodBase.output_schema()` omits source-info columns +**Status:** open +**Severity:** medium + +`output_schema()` (line ~238) does not include source-info columns even when `ColumnConfig` +requests them: +```python +# TODO: handle and extend to include additional columns +``` + +This means callers using `columns={"source": True}` on a FunctionPod's output schema get an +incomplete schema, inconsistent with `as_table()` which does include source columns. + +--- + +### F14 — `FunctionPodStream.as_table()` uses Polars detour for Arrow sorting +**Status:** open +**Severity:** medium + +`as_table()` (line ~568) converts Arrow → Polars → sort → Arrow when sorting by tags: +```python +# TODO: reimplement using polars natively +``` + +The comment is misleading — the fix is actually to use PyArrow's native `.sort_by()` method +directly, eliminating the Polars dependency for this code path and reducing conversion overhead. + +--- + +### F10 — `FunctionPodNodeStream.iter_packets` recomputes every packet on every call +**Status:** resolved +**Severity:** high +`iter_packets` always iterates the full input stream and calls `process_packet` for every packet, +even when results are already stored in the result/pipeline databases. This defeats the purpose +of the two-database design (result DB + pipeline DB) used to cache computed outputs. + +**Fix:** Refactored `iter_packets` to first call `FunctionPodNode.get_all_records(columns={"meta": True})` +to load already-computed (tag, output-packet) pairs from the databases (mirroring the legacy +`PodNodeStream` design), yield those via `TableStream`, then collect the set of already-processed +`INPUT_PACKET_HASH` values and only call `process_packet` for input packets not yet in the DB. +Also added `FunctionPodNode.get_all_records(columns, all_info)` using `ColumnConfig` to control +which column groups (meta, source, system_tags) are returned. + +--- + +## `src/orcapod/core/cached_function_pod.py` / `src/orcapod/core/packet_function.py` + +### CFP1 — Extract shared result caching logic from CachedPacketFunction and CachedFunctionPod +**Status:** resolved +**Severity:** medium + +`CachedPacketFunction` and `CachedFunctionPod` implement nearly identical result caching +logic: DB lookup by `INPUT_PACKET_HASH_COL`, conflict resolution by most-recent timestamp, +record storage with variation/execution/timestamp columns, and a `RESULT_COMPUTED_FLAG` +meta column. The match tier / matching policy design (P6) will also need to apply to both. + +**Fix:** Extracted `ResultCache` class (`src/orcapod/core/result_cache.py`) that owns the DB, +record path, lookup (with `additional_constraints` for future match tiers), store, conflict +resolution, and auto-flush logic. Both `CachedPacketFunction` and `CachedFunctionPod` now +delegate to a `ResultCache` instance. The match tier strategy (P6) can be implemented once +on `ResultCache.lookup` via `additional_constraints`. + +--- + +## `src/orcapod/core/nodes/function_node.py` + +### FN1 — `FunctionNode.async_execute` Phase 2 was fully sequential +**Status:** resolved +**Severity:** high + +`FunctionNode.async_execute` (formerly `PersistentFunctionNode`) had a fully sequential Phase 2 +— each packet was awaited one at a time in a simple `async for` loop. This meant async packet +functions (which can overlap I/O via `await`) got no concurrency benefit when run through the +Pipeline API. + +The parent `FunctionNode.async_execute` already had the correct concurrent pattern using +`asyncio.Semaphore + TaskGroup`, but the persistent override did not replicate it. + +**Fix:** Rewrote Phase 2 to use the same `Semaphore + TaskGroup` pattern as the parent class. +Phase 1 (replay cached results from DB) remains unchanged. Concurrency is controlled via +`NodeConfig.max_concurrency` and `PipelineConfig`, resolved through `resolve_concurrency()`. + +--- + +## `src/orcapod/core/sources/` + +### S1 — `source_name` and `source_id` are redundant and inconsistent +**Status:** resolved +**Severity:** high + +`RootSource` defines `source_id` (canonical registry key, defaults to content hash). +`ArrowTableSource` defines `source_name` (provenance token prefix, defaults to `source_id`). +These are intended as the same concept — a stable name for the source — but they're two +separate parameters that can silently diverge: + +- **Provenance tokens** embed `source_name` (e.g. `"heights::row_0"`) +- **SourceRegistry** is keyed by `source_id` +- If they differ, provenance tokens cannot be resolved via the registry + +Delegating sources make this worse: +- `CSVSource` sets `source_name = file_path` but never sets `source_id` → registry key is a + content hash while provenance tokens use the file path +- `DeltaTableSource` sets `source_name = resolved.name` but never sets `source_id` → same issue + +Additionally, delegating sources all return `self._arrow_source.identity_structure()` which is +`("ArrowTableSource", tag_columns, table_hash)`. This means the outer source type (CSV, Delta, +etc.) is invisible to the content hash, and `source_id` (defaulting to content hash) will be +identical for a CSVSource and an ArrowTableSource with the same data. + +**Fix:** Dropped `source_name` entirely. `source_id` is now the single identifier used for +provenance strings, registry key, and `computed_label()`. Delegating sources set `source_id` +to their meaningful default (`CSVSource` → `file_path`, `DeltaTableSource` → `resolved.name`). +All delegating sources now pass `source_id=self.source_id` to their inner `ArrowTableSource`. +Added `computed_label()` to `RootSource` returning `_explicit_source_id`. + +--- + +### F9 — `as_table()` crashes with `KeyError` on empty stream +**Status:** resolved +**Severity:** high +Both `FunctionPodStream.as_table()` and `FunctionPodNodeStream.as_table()` unconditionally call +`.drop([constants.CONTEXT_KEY])` on the tags table built from the accumulated packets. When the +stream is empty (e.g. because the packet function is inactive), `iter_packets()` yields nothing, +`tag_schema` stays `None`, and `pa.Table.from_pylist([], schema=None)` produces a zero-column +table. The subsequent `.drop([constants.CONTEXT_KEY])` then raises `KeyError` because the column +does not exist. + +**Fix:** Guarded both `.drop([constants.CONTEXT_KEY])` calls in `FunctionPodStream.as_table()` and +`FunctionPodNodeStream.as_table()` with a column-existence check. Also made the final +`output_table = self._cached_output_table.drop(drop_columns)` safe by filtering `drop_columns` +to only columns that exist in the table. + +--- + +## `src/orcapod/core/streams/` + +### ST1 — `drop_packet_columns` may leave orphan source-info columns +**Status:** open +**Severity:** medium + +The `StreamProtocol.drop_packet_columns()` method (line ~309) drops data columns but it is +unclear whether the corresponding `_source_` columns are also removed: +```python +# TODO: check to make sure source columns are also dropped +``` + +If source-info columns survive after the data column is dropped, downstream consumers may see +stale provenance data or schema mismatches. + +--- + +### ST2 — `iter_packets()` does not support table batch streaming +**Status:** open +**Severity:** low + +`ArrowTableStream.iter_packets()` (line ~261) works only with fully materialized Arrow tables, +not with `RecordBatchReader` or lazy batch iteration: +```python +# TODO: make it work with table batch stream +``` + +Relevant for future streaming/chunked processing of large datasets. + +--- + +### ST3 — Column selection operators duplicate `validate_unary_input()` five times +**Status:** open +**Severity:** medium + +`SelectTagColumns`, `SelectPacketColumns`, `DropTagColumns`, `DropPacketColumns` (in +`column_selection.py:58`, `137`, `214`, `292`) and `PolarsFilterByPacketColumns` +(`filters.py:135`) each have near-identical `validate_unary_input()` implementations. All are +marked: +```python +# TODO: remove redundant logic +``` + +The only difference between them is which key set (tag vs. packet) is checked and the error +message text. A shared parameterized validation helper would eliminate the duplication. + +--- + +## `src/orcapod/core/operators/` — Async execution + +### O1 — Operators use barrier-mode `async_execute` only; streaming/incremental overrides needed +**Status:** in progress +**Severity:** medium + +All operators originally used the default barrier-mode `async_execute` inherited from +`StaticOutputPod`: collect all input rows into memory, materialize to `ArrowTableStream`(s), +run the existing sync `static_process`, then emit results. This works correctly but negates the +latency and memory benefits of the push-based channel model. + +Three categories of improvement are planned: + +1. **Streaming overrides (row-by-row, zero buffering)** — for operators that process rows + independently: + - ~~`PolarsFilter` — evaluate predicate per row, emit or drop immediately~~ (kept barrier: + Polars expressions require DataFrame context for evaluation) + - `MapTags` / `MapPackets` — rename columns per row, emit immediately ✅ + - `SelectTagColumns` / `SelectPacketColumns` — project columns per row, emit immediately ✅ + - `DropTagColumns` / `DropPacketColumns` — drop columns per row, emit immediately ✅ + +2. **Incremental overrides (stateful, eager emit)** — for multi-input operators that can + produce partial results before all inputs are consumed: + - `Join` — symmetric hash join for 2 inputs (streaming, with correct + system-tag name-extending via `input_pipeline_hashes` passed directly + to `async_execute`); barrier fallback for N>2 inputs via `static_process`. ✅ + - `MergeJoin` — kept barrier: complex column-merging logic + - `SemiJoin` — build right, stream left through hash lookup ✅ + +3. **Streaming accumulation:** + - `Batch` — emit full batches as they accumulate (`batch_size > 0`); barrier fallback + when `batch_size == 0` (batch everything) ✅ + +**Remaining:** `PolarsFilter` (barrier), `MergeJoin` (barrier) could receive incremental +overrides in the future but require careful handling of Polars expression evaluation and +system-tag evolution respectively. + +--- + +## `src/orcapod/core/` — AddResult pod and Pod Groups + +### G1 — `AddResult`: a first-class pod type for packet enrichment +**Status:** open +**Severity:** medium + +The most common pipeline pattern is *enrichment*: run a function on a packet and append the +computed columns to the original packet rather than replacing it. This is logically equivalent +to `FunctionPod → Join(original, computed)`, but implementing it as a first-class pod type is +both simpler and more efficient. + +#### Why a dedicated pod type, not a composite + +A naïve decomposition into `FunctionPod + Join` works but has unnecessary overhead: + +1. **Materialization waste** — FunctionPod produces an intermediate stream that is only created + to be immediately joined back. AddResult can compute new columns and merge them into the + original packet in a single pass, with no intermediate stream. +2. **Redundant tag matching** — Join must re-match tags that trivially correspond (they came + from the same input row). AddResult already holds the (tag, packet) pair and can skip the + matching entirely. +3. **Simpler async path** — streams row-by-row like FunctionPod: read (tag, packet), call + the packet function, merge original packet columns + new columns, emit. No broadcast, + passthrough channel, or rejoin wiring needed. + +#### Provenance model: fused implementation, not a third category + +The pipeline's provenance guarantees rest on a clean two-category model: + +| Category | Produces new data? | Provenance role | +|---|---|---| +| **Source / FunctionPod** | Yes | Provenance tracked — new values are attributed | +| **Operator** | No | Provenance transparent — every output value traces to a Source or FunctionPod | + +This is powerful because provenance tracking only happens at Source and FunctionPod boundaries. +Operators are "free" — they restructure but never create values that need attribution. + +**AddResult does not introduce a third provenance category.** It is a *fused implementation* +of a pattern fully expressible in the existing model (`FunctionPod + Join`). Its provenance +semantics are *derived from* the decomposition, not an extension of the model: + +- **Preserved columns** — passed through from upstream, provenance transparent (operator-like). + Source-info columns pass through unchanged, exactly as Join would propagate them. +- **Computed columns** — produced by the wrapped PacketFunction, provenance tracked + (function-pod-like). Source-info columns reference the PacketFunction, exactly as + FunctionPod would attribute them. + +There is no third kind of output column. Every column in an AddResult output has a clear +provenance story that maps directly to an existing category. The fusion is an optimization — +analogous to a database query optimizer fusing filter+project into a single scan without +changing relational algebra semantics. + +This means the theoretical model stays clean (Source, Operator, FunctionPod), and AddResult +is justified as a performance/ergonomic optimization whose correctness can be verified by +checking equivalence with its decomposition. + +#### API sketch + +```python +# Sync +grade_pf = PythonPacketFunction(compute_letter_grade, output_keys="letter_grade") +enriched = AddResult(grade_pf).process(stream) +# enriched has all original columns + "letter_grade" + +# Async (streaming, row-by-row) +await AddResult(grade_pf).async_execute([input_ch], output_ch) +``` + +#### Implementation notes + +- `output_schema()` returns `(input_tag_schema, input_packet_schema | function_output_schema)` + — the union of original packet columns and new computed columns. +- Must raise `InputValidationError` if function output keys collide with existing packet + column names (same constraint as Join on overlapping packet columns). +- `pipeline_hash` should behave as if the decomposition were performed — commits to the + wrapped `PacketFunction`'s identity plus the upstream's pipeline hash. +- Source-info on computed columns references the PacketFunction (as FunctionPod would). + Source-info on preserved columns passes through unchanged (as Join would). +- `async_execute` can use the same semaphore-based concurrency control as `FunctionPod`. + +--- + +## `src/orcapod/hashing/semantic_hashing/` + +### H1 — Semantic hasher does not support PEP 604 union types (`int | None`) +**Status:** open +**Severity:** medium + +The `BaseSemanticHasher` raises `BeartypeDoorNonpepException` when hashing a +`PythonPacketFunction` whose return type uses PEP 604 syntax (`int | None`). +The hasher's `_handle_unknown` path receives `types.UnionType` (the Python 3.10+ type for +`X | Y` expressions) and has no registered handler for it. + +`typing.Optional[int]` also fails (different error path through beartype). + +This means packet functions cannot use union return types — a common pattern for functions +that may filter packets by returning `None`. + +**Workaround:** Use non-union return types and raise/return sentinel values instead. + +**Fix needed:** Register a `TypeHandlerProtocol` for `types.UnionType` (and +`typing.Union`/`typing.Optional`) in the semantic hasher's type handler registry. + +--- + +### G2 — Pod Group abstraction for other composite pod patterns +**Status:** open +**Severity:** low + +Beyond AddResult (which warrants its own pod type — see G1), other composite patterns may +benefit from a general **PodGroup** abstraction that encapsulates a reusable sub-graph behind +a single pod-like interface. + +A PodGroup: +- Accepts one or more input streams and produces one output stream (same interface as a pod) +- Internally contains a fixed sub-graph of pods, operators, and channels +- Hides the internal wiring from the user +- Participates in pipeline hashing as a single composite element + +Potential patterns (to be designed as needs arise): +- **ConditionalPod** — route packets to different pods based on a predicate, merge results +- **FanOutFanIn** — broadcast to N pods, collect and merge/concat results +- **FallbackPod** — try primary pod, fall back to secondary on error/None result + +--- + +## `src/orcapod/databases/delta_lake_databases.py` + +### D1 — `flush()` swallows individual batch write errors +**Status:** open +**Severity:** critical + +`flush()` (line ~817) iterates over all pending batches, logs errors individually, but never +raises. Callers have no way to know that writes partially or fully failed: +```python +# TODO: capture and re-raise exceptions at the end +``` + +This is silent data loss — a batch write can fail and the caller proceeds as if everything +was persisted. + +Fix: accumulate exceptions during the loop; raise an `ExceptionGroup` (or custom aggregate +error) at the end containing all failures. + +--- + +### D2 — `flush_batch()` uses `mode="overwrite"` for new table creation +**Status:** open +**Severity:** high + +`flush_batch()` (line ~856) creates new Delta tables with `mode="overwrite"`: +```python +# TODO: reconsider mode="overwrite" here +``` + +If a table already exists at that path (race condition, stale state, or misconfigured pipeline +path), existing data is silently destroyed. Should use `mode="error"` to fail fast, or +`mode="append"` with an explicit existence check. + +--- + +### D3 — `_refresh_existing_ids_cache()` loads entire Delta table into memory +**Status:** open +**Severity:** high + +The method (line ~252) calls `to_pyarrow_table()` on the full Delta table just to extract the +ID column: +```python +# TODO: replace this with more targetted loading of only the target column and in batches +``` + +For large tables, this is a critical memory bottleneck. Delta Lake supports column projection +(`columns=[id_col]`) and batch reading, which would reduce memory usage by orders of magnitude. + +--- + +### D4 — `_refresh_existing_ids_cache()` catches missing column reactively +**Status:** open +**Severity:** high + +In the same method (line ~257), if the ID column doesn't exist, a `KeyError` is caught as a +fallback: +```python +# TODO: replace this with proper checking of the table schema first! +``` + +Schema should be validated proactively by loading schema metadata before attempting to read +the column. + +--- + +### D5 — `_handle_schema_compatibility()` uses naive equality and broad exception catching +**Status:** open +**Severity:** medium + +The method (lines ~467, ~485) compares schemas with simple equality and catches all exceptions: +```python +# TODO: perform more careful check +# TODO: perform more careful error check +``` + +Should implement nuanced schema comparison (e.g., field-order invariance, nullable vs. +non-nullable promotion) and catch specific exceptions rather than bare `except`. + +--- + +### D6 — `defaultdict` used for `_cache_dirty` is not serializable +**Status:** open +**Severity:** medium + +`__init__()` (line ~69) initializes `_cache_dirty` as `defaultdict(bool)`: +```python +# TODO: reconsider this approach as this is NOT serializable +``` + +`defaultdict` is not pickle-serializable, which blocks multiprocessing or serialization +use cases. Fix: use a regular dict with `.get(key, False)`. + +--- + +### D7 — Silent deduplication in `_deduplicate_within_table()` +**Status:** open +**Severity:** medium + +The method (line ~383) silently drops duplicate rows with no warning: +```python +# TODO: consider erroring out if duplicates are found +``` + +Duplicates may indicate an upstream bug. Should at least log a warning; consider making +behavior configurable (warn, error, or silent). + +--- + +## `src/orcapod/hashing/` + +### H1 — `FunctionSignatureExtractor` ignores `input_types` and `output_types` parameters +**Status:** open +**Severity:** critical + +`extract_function_info()` (`function_info_extractors.py:36`) accepts `input_typespec` and +`output_typespec` but never incorporates them into the extracted signature string: +```python +# FIXME: Fix this implementation!! +# BUG: Currently this is not using the input_types and output_types parameters +``` + +The extracted signature is therefore type-agnostic — two functions with identical names but +different type annotations produce the same hash. This is a correctness bug that can cause +cache collisions between type-overloaded functions. + +Fix: wire the type specs into the signature string; update tests. + +--- + +### H2 — Arrow hasher processes full table at once +**Status:** open +**Severity:** medium + +`SemanticArrowHasher._process_table_columns()` (`arrow_hashers.py:104`) calls `to_pylist()` +on the entire table, loading all rows into Python memory: +```python +# TODO: Process in batchwise/chunk-wise fashion for memory efficiency +``` + +For large tables, this is a significant memory bottleneck. Should use Arrow's `to_batches()` +for chunk-wise processing. + +--- + +### H3 — Visitor pattern does not traverse map types +**Status:** open +**Severity:** medium + +`visit_map()` in `visitors.py:225` is a pass-through that does not recurse into map +keys/values: +```python +TODO: Implement proper map traversal if needed for semantic types in keys/values. +``` + +Semantic types nested inside map columns will not be processed during hashing, leading to +incorrect or incomplete hash values. + +--- + +### H4 — Legacy backwards-compatible exports in `hashing/__init__.py` +**Status:** open +**Severity:** low + +The module (line ~141) re-exports old API names for backwards compatibility: +```python +# TODO: remove legacy section +``` + +Should be removed in the next breaking release. Consider adding deprecation warnings first. + +--- + +## `src/orcapod/utils/` + +### U1 — Source-info column type hard-coded to `large_string` +**Status:** open +**Severity:** critical + +In `add_source_info_to_table()` (`arrow_utils.py:604`), when source info is a collection it is +unconditionally cast to `pa.list_(pa.large_string())`: +```python +# TODO: this won't work other data types!!! +``` + +Any non-string collection values will fail or silently corrupt data. The logic also has an +unclear nested isinstance check (line ~602: `# TODO: clean up the logic here`). + +Fix: inspect collection element types to select the appropriate Arrow type; refactor the +conditional logic. + +--- + +### U2 — Bare `except` in `get_git_info()` +**Status:** open +**Severity:** high + +`get_git_info()` (`git_utils.py:55`) catches all exceptions including `KeyboardInterrupt` and +`SystemExit`: +```python +except: # TODO: specify exception +``` + +Fix: catch `(OSError, subprocess.SubprocessError, FileNotFoundError)` specifically. + +--- + +### U3 — `check_arrow_schema_compatibility()` lacks strict mode and type coercion +**Status:** open +**Severity:** high + +The function (`arrow_utils.py:433`, `462`) documents strict vs. non-strict behavior but only +partially implements the non-strict path: +```python +# TODO: add strict comparison +# TODO: if not strict, allow type coercion +``` + +Currently, the function always raises on type mismatch instead of coercing compatible types +when in non-strict mode. Users cannot choose between strict field-order checking and permissive +type promotion. + +--- + +### U4 — `is_subhint` does not handle type invariance +**Status:** open +**Severity:** high + +`check_schema_compatibility()` (`schema_utils.py:37`) uses beartype's `is_subhint` which +treats all generics as covariant: +```python +# TODO: is_subhint does not handle invariance properly +``` + +For mutable containers (`list[int]` vs `list[float]`), this produces false positives — +schemas are reported as compatible when they are not (e.g., a `list[int]` field is accepted +where `list[float]` is expected, but appending a float to `list[int]` would fail at runtime). + +Options: add an invariance-aware wrapper, switch to a stricter type comparison, or document +the limitation prominently. + +--- + +## `src/orcapod/contexts/registry.py` + +### C1 — Redundant manual validation duplicates JSON Schema checks +**Status:** open +**Severity:** medium + +`_load_spec_file()` (line ~141) performs manual required-field checking followed by JSON Schema +validation: +```python +# TODO: clean this up -- sounds redundant to the validation performed by schema check +``` + +The manual check is fully subsumed by the JSON Schema validation. Either remove the manual +checks, or make the JSON Schema validation optional and keep manual checks as the fallback. + +--- + +## `src/orcapod/core/tracker.py` + +### T1 — `SourceNode.identity_structure()` assumes root source +**Status:** open +**Severity:** medium + +The method (line ~163) delegates directly to `stream.identity_structure()`: +```python +# TODO: revisit this logic for case where stream is not a root source +``` + +For derived sources (e.g., `DerivedSource`), the stream may not have a meaningful +`identity_structure()`. Needs an isinstance check or protocol-based dispatch. + +--- + +## `src/orcapod/core/nodes/` — Config and context delegation chain + +### T2 — `orcapod_config` not on any protocol; delegation chain needs review +**Status:** open +**Severity:** medium + +Nodes (`SourceNode`, `FunctionNode`, `OperatorNode`) now delegate `data_context` to their +wrapped entity via property overrides, ensuring transparent context pass-through. However, +`orcapod_config` is only on `DataContextMixin` (concrete base), not on any protocol +(`StreamProtocol`, `PodProtocol`, etc.). + +Open questions: +1. Should `orcapod_config` be added to a protocol (e.g. `TraceableProtocol`)? + Adding it couples protocol consumers to the `Config` type. Leaving it off means + nodes can't transparently delegate config the same way they delegate context. +2. Should `Pipeline` optionally hold `data_context` and/or `config` to allow + pipeline-level overrides that propagate to all nodes during `compile()`? +3. The current chain is: Pipeline (no context) → Node (delegates to wrapped entity) → + wrapped entity (owns context). Should there be a way for Pipeline to inject a + context override? + +--- diff --git a/TESTING_PLAN.md b/TESTING_PLAN.md new file mode 100644 index 00000000..3d7e3e18 --- /dev/null +++ b/TESTING_PLAN.md @@ -0,0 +1,680 @@ +# Comprehensive Specification-Derived Testing Plan + +## Context + +The orcapod-python codebase has grown complex with many interdependent components. Existing tests were often written by the same agent that implemented the code, risking "self-affirmation" — tests that validate what was built rather than what was specified. This plan creates an independent test suite derived purely from **design documents, protocol definitions, and interface contracts**, organized in a new `test-objective/` root folder. + +## Approach: Specification-First Testing + +Tests are derived from these specification sources (NOT from reading implementation code): +1. `orcapod-design.md` — the canonical design specification +2. Protocol definitions in `src/orcapod/protocols/` — interface contracts +3. Type annotations and docstrings — method signatures and documented behavior +4. `CLAUDE.md` architecture overview — documented invariants and constraints +5. `DESIGN_ISSUES.md` — known bugs that tests should catch + +## Deliverables + +### 1. `TESTING_PLAN.md` — comprehensive test case catalog at project root +### 2. `test-objective/` — concrete test implementations at project root + +--- + +## File Structure + +``` +test-objective/ +├── conftest.py # Shared fixtures (sources, streams, functions) +├── unit/ +│ ├── __init__.py +│ ├── test_types.py # Schema, ColumnConfig, ContentHash +│ ├── test_datagram.py # Datagram core behavior +│ ├── test_tag.py # Tag (system tags, ColumnConfig filtering) +│ ├── test_packet.py # Packet (source info, provenance) +│ ├── test_stream.py # ArrowTableStream construction & iteration +│ ├── test_sources.py # All source types + error conditions +│ ├── test_source_registry.py # SourceRegistry CRUD + edge cases +│ ├── test_packet_function.py # PythonPacketFunction + CachedPacketFunction +│ ├── test_function_pod.py # FunctionPod, FunctionPodStream +│ ├── test_operators.py # All operators (Join, MergeJoin, SemiJoin, etc.) +│ ├── test_nodes.py # FunctionNode, OperatorNode, Persistent variants +│ ├── test_hashing.py # SemanticHasher, TypeHandlerRegistry, handlers +│ ├── test_databases.py # InMemory, DeltaLake, NoOp databases +│ ├── test_schema_utils.py # Schema extraction, union, intersection +│ ├── test_arrow_utils.py # Arrow table/schema utilities +│ ├── test_arrow_data_utils.py # System tags, source info, column helpers +│ ├── test_semantic_types.py # UniversalTypeConverter, SemanticTypeRegistry +│ ├── test_contexts.py # DataContext resolution, validation +│ ├── test_tracker.py # BasicTrackerManager, GraphTracker +│ └── test_lazy_module.py # LazyModule deferred import behavior +├── integration/ +│ ├── __init__.py +│ ├── test_pipeline_flows.py # End-to-end pipeline scenarios +│ ├── test_caching_flows.py # DB-backed caching (FunctionNode, OperatorNode) +│ ├── test_hash_invariants.py # Hash stability & Merkle chain properties +│ ├── test_provenance.py # System tag lineage through pipelines +│ └── test_column_config_filtering.py # ColumnConfig behavior across all components +└── property/ + ├── __init__.py + ├── test_schema_properties.py # Hypothesis-based schema algebra + ├── test_hash_properties.py # Hash determinism, collision resistance + └── test_operator_algebra.py # Commutativity, associativity, idempotency +``` + +--- + +## Unit Test Cases by Module + +### 1. `test_types.py` — Schema, ColumnConfig, ContentHash + +**Schema:** +- `test_schema_construction_from_dict` — Schema({"a": int, "b": str}) stores correct fields +- `test_schema_construction_with_kwargs` — Schema(fields, x=int) merges kwargs with precedence +- `test_schema_optional_fields` — optional_fields stored as frozenset, not in required_fields +- `test_schema_required_fields` — required_fields = all fields minus optional_fields +- `test_schema_immutability` — Schema is an immutable Mapping (no __setitem__) +- `test_schema_merge_compatible` — Schema.merge() combines non-conflicting schemas +- `test_schema_merge_type_conflict_raises` — Schema.merge() raises ValueError on type conflicts +- `test_schema_with_values_overrides_silently` — with_values() overrides without errors +- `test_schema_select_existing_fields` — select() returns subset +- `test_schema_select_missing_field_raises` — select() raises KeyError on missing field +- `test_schema_drop_existing_fields` — drop() removes fields +- `test_schema_drop_missing_field_silent` — drop() silently ignores missing fields +- `test_schema_is_compatible_with_superset` — returns True when other is superset +- `test_schema_is_not_compatible_with_subset` — returns False when other is subset +- `test_schema_empty` — Schema.empty() returns zero-field schema +- `test_schema_mapping_interface` — __getitem__, __contains__, __iter__, __len__ work correctly + +**ContentHash:** +- `test_content_hash_immutability` — frozen dataclass, cannot reassign method/digest +- `test_content_hash_to_hex` — to_hex(8) returns 8-char hex string +- `test_content_hash_to_int` — to_int() returns consistent integer +- `test_content_hash_to_uuid` — to_uuid() returns deterministic UUID +- `test_content_hash_to_base64` — to_base64() returns valid base64 +- `test_content_hash_to_string_and_from_string_roundtrip` — from_string(to_string()) == original +- `test_content_hash_display_name` — display_name() returns "method:short_hex" format +- `test_content_hash_equality` — same method+digest are equal +- `test_content_hash_inequality` — different digests are not equal + +**ColumnConfig:** +- `test_column_config_defaults` — all fields False by default +- `test_column_config_all` — ColumnConfig.all() sets everything True +- `test_column_config_data_only` — ColumnConfig.data_only() sets everything False +- `test_column_config_handle_config_dict` — handle_config(dict) normalizes to ColumnConfig +- `test_column_config_handle_config_all_info_override` — all_info=True overrides individual fields +- `test_column_config_frozen` — cannot modify after construction + +### 2. `test_datagram.py` — Datagram + +**Construction:** +- `test_datagram_from_dict` — construct from Python dict +- `test_datagram_from_arrow_table` — construct from pa.Table +- `test_datagram_from_record_batch` — construct from pa.RecordBatch +- `test_datagram_with_meta_info` — meta columns stored separately +- `test_datagram_with_python_schema` — explicit schema used over inference +- `test_datagram_with_record_id` — custom record_id stored as datagram_id + +**Dict-like Access:** +- `test_datagram_getitem_existing_key` — returns correct value +- `test_datagram_getitem_missing_key_raises` — raises KeyError +- `test_datagram_contains` — __contains__ returns True/False correctly +- `test_datagram_iter` — __iter__ yields all data column names +- `test_datagram_get_with_default` — get() returns default for missing keys + +**Lazy Conversion (key invariant):** +- `test_datagram_dict_access_uses_dict_backing` — dict access doesn't trigger Arrow conversion +- `test_datagram_as_table_triggers_arrow_conversion` — as_table() produces Arrow table +- `test_datagram_dict_arrow_roundtrip_preserves_data` — dict→Arrow→dict preserves values +- `test_datagram_arrow_dict_roundtrip_preserves_data` — Arrow→dict→Arrow preserves values + +**Schema Methods:** +- `test_datagram_keys_data_only` — keys() returns only data column names by default +- `test_datagram_keys_all_info` — keys(all_info=True) includes meta columns +- `test_datagram_schema_matches_keys` — schema() field names match keys() +- `test_datagram_arrow_schema_type_consistency` — arrow_schema() types match schema() types + +**Format Conversions:** +- `test_datagram_as_dict` — returns plain Python dict +- `test_datagram_as_table` — returns single-row pa.Table +- `test_datagram_as_arrow_compatible_dict` — values are Arrow-compatible + +**Data Operations (immutability):** +- `test_datagram_select_returns_new_instance` — original unchanged +- `test_datagram_drop_returns_new_instance` — original unchanged +- `test_datagram_rename_returns_new_instance` — original unchanged +- `test_datagram_update_existing_columns_only` — update() only changes existing columns +- `test_datagram_with_columns_new_only` — with_columns() only adds new columns +- `test_datagram_copy_creates_independent_copy` — mutations to copy don't affect original + +**Meta Operations:** +- `test_datagram_get_meta_value_auto_prefixed` — get_meta_value() auto-adds prefix +- `test_datagram_with_meta_columns_returns_new` — immutable update +- `test_datagram_drop_meta_columns_returns_new` — immutable drop + +**Content Hashing:** +- `test_datagram_content_hash_deterministic` — same data → same hash +- `test_datagram_content_hash_changes_with_data` — different data → different hash +- `test_datagram_equality_by_content` — equal content → equal datagrams + +### 3. `test_tag.py` — Tag + +- `test_tag_construction_with_system_tags` — system tags stored separately from data +- `test_tag_system_tags_excluded_from_default_keys` — keys() doesn't show system tags +- `test_tag_system_tags_included_with_column_config` — keys(columns={"system_tags": True}) shows them +- `test_tag_as_dict_excludes_system_tags_by_default` — as_dict() only has data +- `test_tag_as_dict_all_info_includes_system_tags` — as_dict(all_info=True) has everything +- `test_tag_as_table_excludes_system_tags_by_default` +- `test_tag_as_table_all_info_includes_system_tags` +- `test_tag_schema_excludes_system_tags_by_default` +- `test_tag_copy_preserves_system_tags` — copy() includes system tags +- `test_tag_as_datagram_conversion` — as_datagram() returns Datagram (not Tag) +- `test_tag_system_tags_method_returns_copy` — system_tags() returns dict copy, not reference + +### 4. `test_packet.py` — Packet + +- `test_packet_construction_with_source_info` — source_info stored per data column +- `test_packet_source_info_excluded_from_default_keys` — keys() doesn't show _source_ columns +- `test_packet_source_info_included_with_column_config` — keys(columns={"source": True}) +- `test_packet_with_source_info_returns_new` — immutable update +- `test_packet_rename_updates_source_info_keys` — rename() also renames source_info keys +- `test_packet_with_columns_adds_source_info_entry` — new columns get source_info=None +- `test_packet_as_datagram_conversion` — as_datagram() returns Datagram +- `test_packet_as_dict_excludes_source_columns_by_default` +- `test_packet_as_dict_all_info_includes_source_columns` +- `test_packet_copy_preserves_source_info` + +### 5. `test_stream.py` — ArrowTableStream + +**Construction:** +- `test_stream_from_table_with_tag_columns` — tag/packet column separation +- `test_stream_requires_at_least_one_packet_column` — ValueError if no packet columns +- `test_stream_with_system_tag_columns` — system tag columns tracked +- `test_stream_with_source_info` — source info attached to packet columns +- `test_stream_with_producer` — producer property set +- `test_stream_with_upstreams` — upstreams tuple set + +**Schema & Keys:** +- `test_stream_keys_returns_tag_and_packet_keys` — tuple of (tag_keys, packet_keys) +- `test_stream_output_schema_returns_two_schemas` — (tag_schema, packet_schema) +- `test_stream_schema_matches_actual_data` — output_schema() types match as_table() types +- `test_stream_keys_with_column_config` — ColumnConfig filtering works + +**Iteration:** +- `test_stream_iter_packets_yields_tag_packet_pairs` — each yield is (Tag, Packet) +- `test_stream_iter_packets_count_matches_rows` — number of yields = number of rows +- `test_stream_iter_packets_tag_keys_correct` — tag column names match +- `test_stream_iter_packets_packet_keys_correct` — packet column names match +- `test_stream_as_table_matches_iter_packets` — table materialization consistent with iteration + +**Immutability:** +- `test_stream_immutable` — no mutation methods available + +**Format Conversions:** +- `test_stream_as_polars_df` — converts to Polars DataFrame +- `test_stream_as_pandas_df` — converts to Pandas DataFrame +- `test_stream_as_lazy_frame` — converts to Polars LazyFrame + +### 6. `test_sources.py` — All Source Types + +**ArrowTableSource:** +- `test_arrow_source_from_valid_table` — normal construction succeeds +- `test_arrow_source_empty_table_raises` — ValueError("Table is empty") +- `test_arrow_source_missing_tag_column_raises` — ValueError if tag_columns not in table +- `test_arrow_source_adds_system_tag_column` — system tag column added automatically +- `test_arrow_source_adds_source_info_columns` — _source_ columns added +- `test_arrow_source_source_id_set` — source_id property populated +- `test_arrow_source_producer_is_none` — root sources have no producer +- `test_arrow_source_upstreams_empty` — root sources have no upstreams +- `test_arrow_source_resolve_field_by_record_id` — resolves field value +- `test_arrow_source_resolve_field_missing_raises` — FieldNotResolvableError +- `test_arrow_source_pipeline_identity_structure` — returns (tag_schema, packet_schema) +- `test_arrow_source_iter_packets_yields_correct_pairs` +- `test_arrow_source_as_table_has_all_columns` + +**DictSource:** +- `test_dict_source_from_dict_of_lists` — constructs correctly +- `test_dict_source_delegates_to_arrow_table_source` — same behavior as ArrowTableSource +- `test_dict_source_with_tag_columns` + +**ListSource:** +- `test_list_source_from_list_of_dicts` — constructs correctly +- `test_list_source_empty_list_raises` — ValueError + +**CSVSource:** +- `test_csv_source_from_file` — reads CSV correctly +- `test_csv_source_with_tag_columns` + +**DataFrameSource:** +- `test_dataframe_source_from_polars` — constructs from Polars DataFrame +- `test_dataframe_source_from_pandas` — constructs from Pandas DataFrame + +**DerivedSource:** +- `test_derived_source_before_run_raises` — ValueError before upstream has computed +- `test_derived_source_after_run_yields_records` — produces records from upstream node + +### 7. `test_source_registry.py` — SourceRegistry + +- `test_registry_register_and_get` — register then retrieve +- `test_registry_register_empty_id_raises` — ValueError +- `test_registry_register_none_source_raises` — ValueError +- `test_registry_register_same_object_idempotent` — re-register same object is no-op +- `test_registry_register_different_object_same_id_keeps_existing` — warns, keeps existing +- `test_registry_replace_overwrites` — replace() unconditionally overwrites +- `test_registry_replace_returns_old` — returns previous source +- `test_registry_unregister_removes` — removes and returns source +- `test_registry_unregister_missing_raises` — KeyError +- `test_registry_get_missing_raises` — KeyError +- `test_registry_get_optional_missing_returns_none` — returns None +- `test_registry_contains` — __contains__ works +- `test_registry_len` — __len__ works +- `test_registry_iter` — __iter__ yields IDs +- `test_registry_clear` — removes all entries +- `test_registry_list_ids` — returns list of registered IDs + +### 8. `test_packet_function.py` — PythonPacketFunction, CachedPacketFunction + +**PythonPacketFunction:** +- `test_pf_from_simple_function` — wraps a function with explicit output_keys +- `test_pf_infers_input_schema_from_signature` — type annotations → input_packet_schema +- `test_pf_infers_output_schema` — output type annotations or output_keys → output_packet_schema +- `test_pf_rejects_variadic_parameters` — *args, **kwargs raise ValueError +- `test_pf_call_transforms_packet` — call() applies function to packet data +- `test_pf_call_returns_none_if_function_returns_none` — None propagates +- `test_pf_direct_call_bypasses_executor` — direct_call() ignores executor +- `test_pf_call_routes_through_executor` — call() uses executor when set +- `test_pf_version_parsing` — "v1.2" → major_version=1, minor_version_string="2" +- `test_pf_canonical_function_name` — uses function.__name__ or explicit name +- `test_pf_content_hash_deterministic` — same function → same hash +- `test_pf_content_hash_changes_with_function` — different function → different hash +- `test_pf_pipeline_hash_ignores_data` — pipeline_hash based on schema only + +**CachedPacketFunction:** +- `test_cached_pf_cache_miss_computes_and_stores` — first call computes + records +- `test_cached_pf_cache_hit_returns_stored` — second call returns cached result +- `test_cached_pf_skip_cache_lookup_always_computes` — skip_cache_lookup=True forces compute +- `test_cached_pf_skip_cache_insert_doesnt_store` — skip_cache_insert=True skips recording +- `test_cached_pf_get_all_cached_outputs` — returns all stored records as table +- `test_cached_pf_record_path_based_on_function_hash` — record path includes function identity + +### 9. `test_function_pod.py` — FunctionPod, FunctionPodStream + +**FunctionPod:** +- `test_function_pod_process_returns_stream` — process() returns FunctionPodStream +- `test_function_pod_validate_inputs_single_stream` — accepts exactly one stream +- `test_function_pod_validate_inputs_multiple_raises` — rejects multiple streams +- `test_function_pod_output_schema_prediction` — output_schema() matches actual output +- `test_function_pod_callable_alias` — __call__ same as process() +- `test_function_pod_never_modifies_tags` — tags pass through unchanged +- `test_function_pod_transforms_packets` — packets are transformed by function + +**FunctionPodStream:** +- `test_fps_lazy_evaluation` — iter_packets() triggers computation +- `test_fps_producer_is_function_pod` — producer property returns the pod +- `test_fps_upstreams_contains_input_stream` +- `test_fps_keys_matches_pod_output_schema` — keys() consistent with pod.output_schema() +- `test_fps_as_table_materialization` — as_table() returns correct table +- `test_fps_clear_cache_forces_recompute` — clear_cache() resets cached state + +**Decorator:** +- `test_function_pod_decorator_creates_pod_attribute` — @function_pod adds .pod +- `test_function_pod_decorator_with_result_database` — wraps in CachedPacketFunction + +### 10. `test_operators.py` — All Operators + +**Join (N-ary, commutative):** +- `test_join_two_streams_on_common_tags` — inner join on shared tag columns +- `test_join_non_overlapping_packet_columns_required` — InputValidationError on collision +- `test_join_commutative` — join(A, B) == join(B, A) (same rows regardless of order) +- `test_join_three_or_more_streams` — N-ary join works +- `test_join_empty_result_when_no_matches` — disjoint tags → empty stream +- `test_join_system_tag_name_extending` — system tag columns get ::pipeline_hash:position suffix +- `test_join_system_tag_values_sorted_for_commutativity` — canonical ordering of tag values +- `test_join_output_schema_prediction` — output_schema() matches actual output + +**MergeJoin (binary):** +- `test_merge_join_colliding_columns_become_sorted_lists` — same-name packet cols → list[T] +- `test_merge_join_requires_identical_types` — different types raise error +- `test_merge_join_non_colliding_columns_pass_through` — unmatched columns kept as-is +- `test_merge_join_system_tag_name_extending` +- `test_merge_join_output_schema_prediction` — predicts list[T] types correctly + +**SemiJoin (binary, non-commutative):** +- `test_semijoin_filters_left_by_right_tags` — keeps left rows matching right tags +- `test_semijoin_non_commutative` — semijoin(A, B) != semijoin(B, A) in general +- `test_semijoin_preserves_left_packet_columns` — right packet columns dropped +- `test_semijoin_system_tag_name_extending` + +**Batch:** +- `test_batch_groups_rows` — groups rows by tag, aggregates packets +- `test_batch_types_become_lists` — packet column types become list[T] +- `test_batch_system_tag_type_evolving` — system tag type becomes list[str] +- `test_batch_with_batch_size` — batch_size limits group size +- `test_batch_drop_partial_batch` — drop_partial_batch=True drops incomplete groups +- `test_batch_output_schema_prediction` — predicts list[T] types + +**Column Selection (Select/Drop Tag/Packet):** +- `test_select_tag_columns` — keeps only specified tag columns +- `test_select_tag_columns_strict_missing_raises` — strict=True raises on missing column +- `test_select_packet_columns` — keeps only specified packet columns +- `test_drop_tag_columns` — removes specified tag columns +- `test_drop_packet_columns` — removes specified packet columns +- `test_column_selection_system_tag_name_preserving` — system tags unchanged + +**MapTags/MapPackets:** +- `test_map_tags_renames_tag_columns` — renames specified tag columns +- `test_map_tags_drop_unmapped` — drop_unmapped=True removes unrenamed columns +- `test_map_packets_renames_packet_columns` +- `test_map_preserves_system_tags` — system tag columns unchanged (name-preserving) + +**PolarsFilter:** +- `test_polars_filter_with_predicate` — filters rows matching predicate +- `test_polars_filter_with_constraints` — filters by column=value constraints +- `test_polars_filter_preserves_schema` — output schema same as input +- `test_polars_filter_system_tag_name_preserving` + +**Operator Base Classes:** +- `test_unary_operator_rejects_multiple_inputs` — validate_inputs raises for >1 stream +- `test_binary_operator_rejects_wrong_count` — validate_inputs raises for !=2 streams +- `test_nonzero_input_operator_rejects_zero` — validate_inputs raises for 0 streams + +### 11. `test_nodes.py` — FunctionNode, OperatorNode, Persistent variants + +**FunctionNode:** +- `test_function_node_iter_packets` — iterates and transforms all packets +- `test_function_node_process_packet` — transforms single (tag, packet) pair +- `test_function_node_producer_is_function_pod` +- `test_function_node_upstreams` +- `test_function_node_clear_cache` + +**PersistentFunctionNode:** +- `test_persistent_fn_two_phase_iteration` — Phase 1: cached records, Phase 2: compute missing +- `test_persistent_fn_pipeline_path_uses_pipeline_hash` — path includes pipeline_hash +- `test_persistent_fn_caches_computed_results` — computed results stored in DB +- `test_persistent_fn_skips_already_cached` — Phase 2 skips inputs with cached outputs +- `test_persistent_fn_run_eagerly_processes_all` — run() processes all packets +- `test_persistent_fn_as_source_returns_derived_source` — as_source() returns DerivedSource + +**OperatorNode:** +- `test_operator_node_delegates_to_operator` +- `test_operator_node_clear_cache` +- `test_operator_node_run` + +**PersistentOperatorNode:** +- `test_persistent_on_cache_mode_off` — always recomputes +- `test_persistent_on_cache_mode_log` — computes and stores +- `test_persistent_on_cache_mode_replay` — loads from DB, no recompute +- `test_persistent_on_as_source_returns_derived_source` + +### 12. `test_hashing.py` — SemanticHasher, TypeHandlerRegistry + +**BaseSemanticHasher:** +- `test_hasher_primitives` — int, str, float, bool, None hashed deterministically +- `test_hasher_structures` — list, dict, tuple, set expanded structurally +- `test_hasher_content_hash_terminal` — ContentHash inputs returned as-is +- `test_hasher_content_identifiable_uses_identity_structure` — resolves via identity_structure() +- `test_hasher_unknown_type_strict_raises` — TypeError in strict mode +- `test_hasher_deterministic` — same input → same hash always +- `test_hasher_different_inputs_different_hashes` — collision resistance +- `test_hasher_nested_structures` — deeply nested dicts/lists hashed correctly + +**TypeHandlerRegistry:** +- `test_registry_register_and_lookup` — register handler, get_handler returns it +- `test_registry_mro_aware_lookup` — subclass falls back to parent handler +- `test_registry_unregister` — remove handler +- `test_registry_has_handler` — boolean check +- `test_registry_registered_types` — list all registered types +- `test_registry_thread_safety` — concurrent register/lookup doesn't crash + +**Built-in Handlers:** +- `test_path_handler_hashes_file_content` — Path → file content hash +- `test_path_handler_missing_file_raises` — FileNotFoundError +- `test_uuid_handler` — UUID → canonical string +- `test_bytes_handler` — bytes → hex string +- `test_function_handler` — function → signature-based identity +- `test_type_object_handler` — type → "type:module.qualname" +- `test_arrow_table_handler` — pa.Table → content hash + +### 13. `test_databases.py` — InMemory, DeltaLake, NoOp + +**InMemoryArrowDatabase:** +- `test_inmemory_add_and_get_record` — add_record + get_record_by_id roundtrip +- `test_inmemory_add_records_batch` — add_records with multiple rows +- `test_inmemory_get_all_records` — returns all at path +- `test_inmemory_get_records_by_ids` — returns subset by IDs +- `test_inmemory_skip_duplicates` — skip_duplicates=True doesn't raise +- `test_inmemory_pending_batch_semantics` — records not visible before flush() +- `test_inmemory_flush_makes_visible` — flush() commits pending records +- `test_inmemory_invalid_path_raises` — ValueError for empty/invalid paths +- `test_inmemory_get_nonexistent_returns_none` — missing path → None + +**NoOpArrowDatabase:** +- `test_noop_all_writes_silently_discarded` — add_record/add_records don't error +- `test_noop_all_reads_return_none` — get_* always returns None +- `test_noop_flush_noop` — flush() doesn't error + +**DeltaTableDatabase (if available):** +- `test_delta_add_and_get_record` — persistence roundtrip +- `test_delta_flush_writes_to_disk` — data survives flush +- `test_delta_path_validation` — invalid paths rejected + +### 14. `test_schema_utils.py` — Schema Utilities + +- `test_extract_function_schemas_from_annotations` — infers schemas from type hints +- `test_extract_function_schemas_rejects_variadic` — ValueError for *args/**kwargs +- `test_verify_packet_schema_valid` — matching dict passes +- `test_verify_packet_schema_type_mismatch` — mismatched types fail +- `test_check_schema_compatibility` — compatible types pass +- `test_infer_schema_from_dict` — infers types from values +- `test_union_schemas_no_conflict` — merges cleanly +- `test_union_schemas_with_conflict_raises` — TypeError on conflicting types +- `test_intersection_schemas` — returns common fields +- `test_get_compatible_type_int_float` — numeric promotion +- `test_get_compatible_type_incompatible_raises` — TypeError + +### 15. `test_arrow_utils.py` — Arrow Utilities + +- `test_schema_select` — selects subset of arrow schema columns +- `test_schema_select_missing_raises` — KeyError for missing columns +- `test_schema_drop` — drops specified columns +- `test_normalize_to_large_types` — string → large_string, etc. +- `test_pylist_to_pydict` — row-oriented → column-oriented +- `test_pydict_to_pylist` — column-oriented → row-oriented +- `test_pydict_to_pylist_inconsistent_lengths_raises` — ValueError +- `test_hstack_tables` — horizontal concatenation +- `test_hstack_tables_different_row_counts_raises` — ValueError +- `test_hstack_tables_duplicate_columns_raises` — ValueError +- `test_check_arrow_schema_compatibility` — compatible schemas pass +- `test_split_by_column_groups` — splits table into multiple tables + +### 16. `test_arrow_data_utils.py` — System Tags & Source Info + +- `test_add_system_tag_columns` — adds _tag:: prefixed columns +- `test_add_system_tag_columns_empty_table_raises` — ValueError +- `test_add_system_tag_columns_length_mismatch_raises` — ValueError +- `test_append_to_system_tags` — extends existing system tag values +- `test_sort_system_tag_values` — canonical sorting for commutativity +- `test_add_source_info` — adds _source_ prefixed columns +- `test_drop_columns_with_prefix` — removes columns matching prefix +- `test_drop_system_columns` — removes __ and __ prefixed columns + +### 17. `test_semantic_types.py` — UniversalTypeConverter + +- `test_python_to_arrow_type_primitives` — int→int64, str→large_string, etc. +- `test_python_to_arrow_type_list` — list[int]→large_list(int64) +- `test_python_to_arrow_type_dict` — dict→struct +- `test_arrow_to_python_type_roundtrip` — python→arrow→python recovers original +- `test_python_dicts_to_arrow_table` — list of dicts → pa.Table +- `test_arrow_table_to_python_dicts` — pa.Table → list of dicts +- `test_schema_conversion_roundtrip` — Schema→pa.Schema→Schema preserves types + +### 18. `test_contexts.py` — DataContext + +- `test_resolve_context_none_returns_default` — None → default context +- `test_resolve_context_string_version` — "v0.1" → matching context +- `test_resolve_context_datacontext_passthrough` — DataContext returned as-is +- `test_resolve_context_invalid_raises` — ContextResolutionError +- `test_get_available_contexts` — returns sorted version list +- `test_default_context_has_all_components` — type_converter, arrow_hasher, semantic_hasher present + +### 19. `test_tracker.py` — BasicTrackerManager, GraphTracker + +- `test_tracker_manager_register_deregister` — add/remove trackers +- `test_tracker_manager_broadcasts_invocations` — records sent to all active trackers +- `test_tracker_manager_no_tracking_context` — no_tracking() suspends recording +- `test_graph_tracker_records_function_pod_invocation` — node added to graph +- `test_graph_tracker_records_operator_invocation` — node added to graph +- `test_graph_tracker_compile_builds_graph` — compile() produces nx.DiGraph +- `test_graph_tracker_reset_clears_state` + +### 20. `test_lazy_module.py` — LazyModule + +- `test_lazy_module_not_loaded_initially` — is_loaded is False +- `test_lazy_module_loads_on_attribute_access` — accessing attr triggers import +- `test_lazy_module_force_load` — force_load() triggers immediate import +- `test_lazy_module_invalid_module_raises` — ModuleNotFoundError + +--- + +## Integration Test Cases + +### `test_pipeline_flows.py` — End-to-End Pipeline Scenarios + +- `test_source_to_stream_to_single_operator` — Source → Filter → Stream +- `test_source_to_function_pod` — Source → FunctionPod → Stream with transformed packets +- `test_multi_source_join` — Two sources → Join → Stream with combined data +- `test_chained_operators` — Source → Filter → Select → MapTags → Stream +- `test_function_pod_then_operator` — Source → FunctionPod → Filter → Stream +- `test_join_then_batch` — Two sources → Join → Batch → Stream +- `test_semijoin_filters_correctly` — Source A semi-joined with Source B +- `test_merge_join_combines_columns` — Two sources with overlapping columns → MergeJoin +- `test_diamond_pipeline` — Source → [branch A, branch B] → Join → Stream +- `test_pipeline_with_multiple_function_pods` — Source → FunctionPod1 → FunctionPod2 + +### `test_caching_flows.py` — DB-Backed Caching Scenarios + +- `test_persistent_function_node_caches_and_replays` — first run computes, second replays +- `test_persistent_function_node_incremental_update` — new input rows only compute missing +- `test_persistent_operator_node_log_mode` — CacheMode.LOG stores results +- `test_persistent_operator_node_replay_mode` — CacheMode.REPLAY loads from DB +- `test_derived_source_reingestion` — PersistentFunctionNode → DerivedSource → further pipeline +- `test_cached_packet_function_with_inmemory_db` — end-to-end caching flow + +### `test_hash_invariants.py` — Hash Stability & Merkle Chain Properties + +- `test_content_hash_stability_same_data` — identical data → identical hash across runs +- `test_content_hash_changes_with_data` — different data → different hash +- `test_pipeline_hash_ignores_data_content` — same schema, different data → same pipeline_hash +- `test_pipeline_hash_changes_with_schema` — different schema → different pipeline_hash +- `test_pipeline_hash_merkle_chain` — downstream hash commits to upstream hashes +- `test_commutative_join_pipeline_hash_order_independent` — join(A,B) pipeline_hash == join(B,A) +- `test_non_commutative_semijoin_pipeline_hash_order_dependent` — semijoin(A,B) != semijoin(B,A) + +### `test_provenance.py` — System Tag Lineage Tracking + +- `test_source_creates_system_tag_column` — source adds _tag::source:hash column +- `test_unary_operator_preserves_system_tags` — filter/select/map: name+value unchanged +- `test_join_extends_system_tag_names` — multi-input: column names get ::hash:pos suffix +- `test_join_sorts_system_tag_values` — commutative ops sort tag values +- `test_batch_evolves_system_tag_type` — batch: str → list[str] +- `test_full_pipeline_provenance_chain` — source → join → filter → batch: all rules applied + +### `test_column_config_filtering.py` — ColumnConfig Across All Components + +- `test_datagram_column_config_meta` — meta=True includes __ columns +- `test_datagram_column_config_data_only` — all False = data columns only +- `test_tag_column_config_system_tags` — system_tags=True includes _tag:: columns +- `test_packet_column_config_source` — source=True includes _source_ columns +- `test_stream_column_config_all_info` — all_info=True on keys/output_schema/as_table +- `test_stream_column_config_consistency` — keys(), output_schema(), as_table() all respect same config + +--- + +## Property-Based & Advanced Testing (test-objective/property/) + +### `test_schema_properties.py` (using Hypothesis) +- `test_schema_merge_commutative` — merge(A,B) == merge(B,A) when compatible +- `test_schema_select_then_drop_complementary` — select(X) ∪ drop(X) == original +- `test_schema_is_compatible_reflexive` — A.is_compatible_with(A) always True +- `test_schema_optional_fields_subset_of_all_fields` + +### `test_hash_properties.py` (using Hypothesis) +- `test_hash_deterministic` — hash(X) == hash(X) for any X +- `test_hash_changes_with_any_field_mutation` — mutate one value → different hash +- `test_content_hash_string_roundtrip` — from_string(to_string(h)) == h for any h + +### `test_operator_algebra.py` +- `test_join_commutativity` — join(A,B) data == join(B,A) data +- `test_join_associativity` — join(join(A,B),C) data == join(A,join(B,C)) data +- `test_filter_idempotency` — filter(filter(S, P), P) == filter(S, P) +- `test_select_then_select_is_intersection` — select(select(S, X), Y) == select(S, X∩Y) +- `test_drop_then_drop_is_union` — drop(drop(S, X), Y) == drop(S, X∪Y) + +--- + +## Suggestions for More Objective Testing + +### Included in `test-objective/property/`: +1. **Property-based testing** (Hypothesis) — generate random schemas, data, operations and verify algebraic invariants hold +2. **Algebraic property testing** — verify mathematical properties (commutativity of join, idempotency of filter, etc.) + +### Recommended additions (not implemented in this PR, but suggested): +3. **Mutation testing** with `mutmut` — run `uv run mutmut run --paths-to-mutate=src/orcapod/ --tests-dir=test-objective/` to verify tests catch code mutations. A surviving mutant indicates a test gap +4. **Metamorphic testing** — "if I add a row to source A that matches source B's tags, the join output should have one more row" — tests relationships between inputs/outputs without knowing exact expected values +5. **Protocol conformance automation** — use `runtime_checkable` protocols and `isinstance` checks to verify every concrete class satisfies its protocol at import time +6. **Specification oracle** — for each documented behavior in `orcapod-design.md`, create a test that constructs the exact scenario described and verifies the documented outcome +7. **Fuzz testing** — feed malformed inputs (wrong types, extreme sizes, Unicode edge cases) to constructors and verify graceful error handling + +--- + +## Implementation Order + +1. **`conftest.py`** — shared fixtures (reusable sources, streams, packet functions, databases) +2. **`unit/test_types.py`** — foundational types (Schema, ContentHash, ColumnConfig) +3. **`unit/test_datagram.py`**, **`test_tag.py`**, **`test_packet.py`** — data containers +4. **`unit/test_stream.py`** — stream construction and iteration +5. **`unit/test_sources.py`** + **`test_source_registry.py`** — all source types +6. **`unit/test_hashing.py`** — semantic hasher and handlers +7. **`unit/test_schema_utils.py`** + **`test_arrow_utils.py`** + **`test_arrow_data_utils.py`** — utilities +8. **`unit/test_semantic_types.py`** + **`test_contexts.py`** — type conversion and contexts +9. **`unit/test_databases.py`** — database implementations +10. **`unit/test_packet_function.py`** — packet function behavior +11. **`unit/test_function_pod.py`** — function pod and streams +12. **`unit/test_operators.py`** — all operators +13. **`unit/test_nodes.py`** — function/operator nodes +14. **`unit/test_tracker.py`** + **`test_lazy_module.py`** — remaining units +15. **`integration/`** — all integration test files +16. **`property/`** — property-based tests + +## Dependencies + +- **hypothesis** — added as a test dependency for property-based testing in `test-objective/property/` +- **pytest** — test runner (already present) +- DeltaTableDatabase tests marked with `@pytest.mark.slow` (skip with `-m "not slow"`) + +## Verification + +Run the full test suite with: +```bash +uv run pytest test-objective/ -v +``` + +Run only unit tests: +```bash +uv run pytest test-objective/unit/ -v +``` + +Run only integration tests: +```bash +uv run pytest test-objective/integration/ -v +``` + +Run only property tests: +```bash +uv run pytest test-objective/property/ -v +``` + +## Key Files to Modify/Create + +- **New:** `TESTING_PLAN.md` (project root) — the test case catalog document (content mirrors this plan) +- **New:** `test-objective/` directory tree — all files listed in the structure above +- **No modifications** to any existing source code or tests diff --git a/demo_caching.py b/demo_caching.py new file mode 100644 index 00000000..91639a66 --- /dev/null +++ b/demo_caching.py @@ -0,0 +1,358 @@ +""" +End-to-end demo: all three pod caching strategies at work. + +Demonstrates: +1. PersistentSource — always-on cache scoped to content_hash() + - DeltaTableSource with canonical source_id (defaults to dir name) + - Named sources: same name + same schema = same identity (data-independent) + - Unnamed sources: identity determined by table hash (data-dependent) +2. FunctionNode — pipeline_hash()-scoped cache, cross-source sharing + - Two pipelines with different source identities but same schema share one cache table +3. OperatorNode — content_hash()-scoped with CacheMode (OFF/LOG/REPLAY) +""" + +from __future__ import annotations + +import tempfile +from pathlib import Path + +import pyarrow as pa +from deltalake import DeltaTable, write_deltalake + +from orcapod.core.function_pod import FunctionPod +from orcapod.core.nodes import FunctionNode, OperatorNode +from orcapod.core.operators import Join +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.core.sources import ArrowTableSource, DeltaTableSource, PersistentSource +from orcapod.databases import InMemoryArrowDatabase +from orcapod.types import CacheMode + +# Shared databases +source_db = InMemoryArrowDatabase() +pipeline_db = InMemoryArrowDatabase() +result_db = InMemoryArrowDatabase() +operator_db = InMemoryArrowDatabase() + +# ============================================================ +# STEP 1: PersistentSource with DeltaTableSource +# ============================================================ +print("=" * 70) +print("STEP 1: PersistentSource (source pod caching)") +print("=" * 70) + +with tempfile.TemporaryDirectory() as tmpdir: + # --- Create Delta tables on disk --- + patients_path = Path(tmpdir) / "patients" + labs_path = Path(tmpdir) / "labs" + + patients_arrow = pa.table( + { + "patient_id": pa.array(["p1", "p2", "p3"], type=pa.large_string()), + "age": pa.array([30, 45, 60], type=pa.int64()), + } + ) + labs_arrow = pa.table( + { + "patient_id": pa.array(["p1", "p2", "p3"], type=pa.large_string()), + "cholesterol": pa.array([180, 220, 260], type=pa.int64()), + } + ) + write_deltalake(str(patients_path), patients_arrow) + write_deltalake(str(labs_path), labs_arrow) + + # --- DeltaTableSource: source_id defaults to directory name --- + patients_src = DeltaTableSource(patients_path, tag_columns=["patient_id"]) + labs_src = DeltaTableSource(labs_path, tag_columns=["patient_id"]) + + print(f"\n patients_src.source_id: {patients_src.source_id!r}") + print(f" labs_src.source_id: {labs_src.source_id!r}") + print(" (defaults to Delta table directory name)") + + patients = PersistentSource(patients_src, cache_database=source_db) + labs = PersistentSource(labs_src, cache_database=source_db) + + patients.run() + labs.run() + + print(f"\n Patients cache_path: {patients.cache_path}") + print(f" Labs cache_path: {labs.cache_path}") + print( + f" Different tables (different source_id): {patients.cache_path != labs.cache_path}" + ) + + patients_records = patients.get_all_records() + labs_records = labs.get_all_records() + print(f"\n Patients cached rows: {patients_records.num_rows}") + print(f" Labs cached rows: {labs_records.num_rows}") + + # --- Named source identity: same name + same schema = same identity --- + print("\n --- Named source identity ---") + # Rebuild from same Delta dir (same name, same schema) → same content_hash + patients_src_2 = DeltaTableSource(patients_path, tag_columns=["patient_id"]) + patients_2 = PersistentSource(patients_src_2, cache_database=source_db) + print( + f" Same dir, same name → same content_hash: " + f"{patients.content_hash() == patients_2.content_hash()}" + ) + + # Now update the Delta table with new data — same dir name → same identity + patients_arrow_v2 = pa.table( + { + "patient_id": pa.array(["p1", "p2", "p3", "p4"], type=pa.large_string()), + "age": pa.array([30, 45, 60, 25], type=pa.int64()), + } + ) + write_deltalake(str(patients_path), patients_arrow_v2, mode="overwrite") + patients_src_updated = DeltaTableSource(patients_path, tag_columns=["patient_id"]) + patients_updated = PersistentSource(patients_src_updated, cache_database=source_db) + + print( + f" Updated data, same dir name → same source_id: " + f"{patients_src.source_id == patients_src_updated.source_id}" + ) + print( + f" Updated data, same dir name → same content_hash: " + f"{patients.content_hash() == patients_updated.content_hash()}" + ) + print(" (Named sources: identity = name + schema, not data content)") + + # Cumulative caching: new rows accumulate in the same cache table + patients_updated.run() + updated_records = patients_updated.get_all_records() + print( + f" After update + re-run, cached rows: {updated_records.num_rows} " + f"(3 original + 1 new, deduped)" + ) + + # --- Unnamed source: identity determined by table hash --- + print("\n --- Unnamed source identity (no source_id) ---") + t1 = pa.table( + { + "k": pa.array(["a"], type=pa.large_string()), + "v": pa.array([1], type=pa.int64()), + } + ) + t2 = pa.table( + { + "k": pa.array(["b"], type=pa.large_string()), + "v": pa.array([2], type=pa.int64()), + } + ) + unnamed_1 = ArrowTableSource(t1, tag_columns=["k"]) + unnamed_2 = ArrowTableSource(t2, tag_columns=["k"]) + print(f" unnamed_1.source_id: {unnamed_1.source_id!r}") + print(f" unnamed_2.source_id: {unnamed_2.source_id!r}") + print( + f" Different data → different source_id (table hash): " + f"{unnamed_1.source_id != unnamed_2.source_id}" + ) + print( + f" Different data → different content_hash: " + f"{unnamed_1.content_hash() != unnamed_2.content_hash()}" + ) + + # ============================================================ + # STEP 2: FunctionNode — cross-source sharing + # ============================================================ + print("\n" + "=" * 70) + print("STEP 2: FunctionNode (function pod caching)") + print("=" * 70) + + def risk_score(age: int, cholesterol: int) -> float: + """Simple risk = age * 0.5 + cholesterol * 0.3""" + return age * 0.5 + cholesterol * 0.3 + + pf = PythonPacketFunction(risk_score, output_keys="risk") + pod = FunctionPod(packet_function=pf) + + # Pipeline 1: original patients + labs + joined_1 = Join()(patients, labs) + fn_node_1 = FunctionNode( + function_pod=pod, + input_stream=joined_1, + pipeline_database=pipeline_db, + result_database=result_db, + ) + fn_node_1.run() + + print( + f"\n Pipeline 1 source_ids: {patients_src.source_id!r}, {labs_src.source_id!r}" + ) + print(f" Pipeline 1 pipeline_path: {fn_node_1.pipeline_path}") + fn_records_1 = fn_node_1.get_all_records() + print(f" Pipeline 1 stored records: {fn_records_1.num_rows}") + + print(f"\n Pipeline 1 output:") + print(fn_node_1.as_table().to_pandas().to_string(index=False)) + + # Pipeline 2: DIFFERENT sources, SAME schema + # Create completely independent sources with different names + patients_path_b = Path(tmpdir) / "clinic_b_patients" + labs_path_b = Path(tmpdir) / "clinic_b_labs" + write_deltalake( + str(patients_path_b), + pa.table( + { + "patient_id": pa.array(["x1", "x2"], type=pa.large_string()), + "age": pa.array([28, 72], type=pa.int64()), + } + ), + ) + write_deltalake( + str(labs_path_b), + pa.table( + { + "patient_id": pa.array(["x1", "x2"], type=pa.large_string()), + "cholesterol": pa.array([160, 290], type=pa.int64()), + } + ), + ) + patients_b = PersistentSource( + DeltaTableSource(patients_path_b, tag_columns=["patient_id"]), + cache_database=source_db, + ) + labs_b = PersistentSource( + DeltaTableSource(labs_path_b, tag_columns=["patient_id"]), + cache_database=source_db, + ) + + joined_2 = Join()(patients_b, labs_b) + fn_node_2 = FunctionNode( + function_pod=pod, + input_stream=joined_2, + pipeline_database=pipeline_db, + result_database=result_db, + ) + + print(f"\n Pipeline 2 source_ids: {patients_b.source_id!r}, {labs_b.source_id!r}") + print(f" Pipeline 2 pipeline_path: {fn_node_2.pipeline_path}") + print( + f" Same pipeline_path (cross-source sharing): " + f"{fn_node_1.pipeline_path == fn_node_2.pipeline_path}" + ) + print(" (pipeline_hash ignores source identity — only schema + topology matter)") + + fn_node_2.run() + fn_records_2 = fn_node_2.get_all_records() + print(f"\n After pipeline 2 run, shared DB records: {fn_records_2.num_rows}") + print(f" (pipeline 1's 3 records + pipeline 2's 2 new records = 5 total)") + + print(f"\n Pipeline 2 output:") + print(fn_node_2.as_table().to_pandas().to_string(index=False)) + + # ============================================================ + # STEP 3: OperatorNode — CacheMode + # ============================================================ + print("\n" + "=" * 70) + print("STEP 3: OperatorNode (operator pod caching)") + print("=" * 70) + + join_op = Join() + + # --- CacheMode.OFF (default): compute, no DB writes --- + print("\n --- CacheMode.OFF ---") + op_node_off = OperatorNode( + operator=join_op, + input_streams=[patients, labs], + pipeline_database=operator_db, + cache_mode=CacheMode.OFF, + ) + op_node_off.run() + off_records = operator_db.get_all_records(op_node_off.pipeline_path) + print(f" Computed rows: {op_node_off.as_table().num_rows}") + print( + f" DB records after OFF: " + f"{off_records.num_rows if off_records is not None else 'None (no writes)'}" + ) + + # --- CacheMode.LOG: compute AND write to DB --- + print("\n --- CacheMode.LOG ---") + op_node_log = OperatorNode( + operator=join_op, + input_streams=[patients, labs], + pipeline_database=operator_db, + cache_mode=CacheMode.LOG, + ) + op_node_log.run() + log_records = operator_db.get_all_records(op_node_log.pipeline_path) + print(f" Computed rows: {op_node_log.as_table().num_rows}") + print( + f" DB records after LOG: " + f"{log_records.num_rows if log_records is not None else 'None'}" + ) + print(f" Pipeline path: {op_node_log.pipeline_path}") + print(" (scoped to content_hash — each source combination gets its own table)") + + # Show content_hash scoping: different sources → different paths + op_node_b = OperatorNode( + operator=join_op, + input_streams=[patients_b, labs_b], + pipeline_database=operator_db, + cache_mode=CacheMode.LOG, + ) + print(f"\n Operator v1 pipeline_path: {op_node_log.pipeline_path}") + print(f" Operator v2 pipeline_path: {op_node_b.pipeline_path}") + print( + f" Different paths (content_hash scoping): " + f"{op_node_log.pipeline_path != op_node_b.pipeline_path}" + ) + + # --- CacheMode.REPLAY: skip computation, load from DB --- + print("\n --- CacheMode.REPLAY ---") + op_node_replay = OperatorNode( + operator=join_op, + input_streams=[patients, labs], + pipeline_database=operator_db, + cache_mode=CacheMode.REPLAY, + ) + op_node_replay.run() + print( + f" Replayed rows (from cache, no computation): " + f"{op_node_replay.as_table().num_rows}" + ) + + # --- REPLAY with no prior cache → empty stream --- + print("\n --- CacheMode.REPLAY with no prior cache ---") + op_node_empty = OperatorNode( + operator=join_op, + input_streams=[patients, labs], + pipeline_database=InMemoryArrowDatabase(), + cache_mode=CacheMode.REPLAY, + ) + op_node_empty.run() + empty_table = op_node_empty.as_table() + print(f" Empty cache → empty stream: {empty_table.num_rows} rows") + print(f" Schema preserved: {empty_table.column_names}") + +# ============================================================ +# SUMMARY +# ============================================================ +print("\n" + "=" * 70) +print("SUMMARY") +print("=" * 70) +print(""" + Source identity: + - Named sources (DeltaTable, CSV): source_id = canonical name (dir/file path) + → identity = (class, schema, name) — data-independent + - Unnamed sources (ArrowTableSource): source_id = table data hash + → identity = (class, schema, hash) — data-dependent + + Source pod (PersistentSource): + - Always-on caching, scoped to content_hash() + - Named sources: same name + same schema → same cache table + (data updates accumulate cumulatively, deduped by row hash) + - Transparent StreamProtocol — downstream is unaware of caching + + Function pod (FunctionNode): + - Cache scoped to pipeline_hash() (schema + topology only) + - Cross-source sharing: different source identities, same schema + → same pipeline_path → same cache table + - Rows distinguished by system tags (source_id + record_id) + + Operator pod (OperatorNode): + - Cache scoped to content_hash() (includes source identity) + - Different source combinations → different cache tables + - CacheMode.OFF: compute only (default) + - CacheMode.LOG: compute + persist + - CacheMode.REPLAY: load from cache, skip computation +""") diff --git a/demo_pipeline.py b/demo_pipeline.py new file mode 100644 index 00000000..74c132d1 --- /dev/null +++ b/demo_pipeline.py @@ -0,0 +1,356 @@ +""" +Pipeline demo: automatic persistent wrapping of all pipeline nodes. + +Demonstrates: +1. Building a pipeline with sources, operators (via convenience methods), + and function pods +2. Auto-compile on context exit — all nodes become persistent +3. Running the pipeline — data cached in the database +4. Accessing results by label +5. Persistence with DeltaTableDatabase — data survives across runs +6. Re-running a pipeline — only new data is computed +""" + +from __future__ import annotations + +import tempfile +from pathlib import Path + +import pyarrow as pa + +from orcapod.core.function_pod import FunctionPod +from orcapod.core.nodes import FunctionNode, OperatorNode +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.core.sources import ArrowTableSource +from orcapod.databases import DeltaTableDatabase, InMemoryArrowDatabase +from orcapod.core.nodes import SourceNode +from orcapod.pipeline import Pipeline + + +# --------------------------------------------------------------------------- +# Helper functions used in the pipeline +# --------------------------------------------------------------------------- + + +def risk_score(age: int, cholesterol: int) -> float: + """Simple risk = age * 0.5 + cholesterol * 0.3.""" + return age * 0.5 + cholesterol * 0.3 + + +def categorize(risk: float) -> str: + if risk < 80: + return "low" + elif risk < 120: + return "medium" + else: + return "high" + + +# --------------------------------------------------------------------------- +# Sources +# --------------------------------------------------------------------------- + +patients_table = pa.table( + { + "patient_id": pa.array(["p1", "p2", "p3"], type=pa.large_string()), + "age": pa.array([30, 45, 60], type=pa.int64()), + } +) + +labs_table = pa.table( + { + "patient_id": pa.array(["p1", "p2", "p3"], type=pa.large_string()), + "cholesterol": pa.array([180, 220, 260], type=pa.int64()), + } +) + +patients = ArrowTableSource(patients_table, tag_columns=["patient_id"]) +labs = ArrowTableSource(labs_table, tag_columns=["patient_id"]) + + +# ============================================================ +# PART 1: Pipeline with InMemoryArrowDatabase +# ============================================================ +print("=" * 70) +print("PART 1: Pipeline with InMemoryArrowDatabase") +print("=" * 70) + +db = InMemoryArrowDatabase() + +risk_fn = PythonPacketFunction(risk_score, output_keys="risk") +risk_pod = FunctionPod(packet_function=risk_fn) + +cat_fn = PythonPacketFunction(categorize, output_keys="category") +cat_pod = FunctionPod(packet_function=cat_fn) + +# --- Build the pipeline using convenience methods --- +pipeline = Pipeline(name="risk_pipeline", pipeline_database=db) + +with pipeline: + # .join() is a convenience method on any stream/source + joined = patients.join(labs, label="join_data") + risk_stream = risk_pod(joined, label="compute_risk") + cat_pod(risk_stream, label="categorize") + +# --- Inspect compiled nodes --- +print("\n Compiled nodes:") +for name, node in pipeline.compiled_nodes.items(): + print(f" {name}: {type(node).__name__}") + +print("\n Source nodes (SourceNode):") +for n in pipeline._node_graph.nodes(): + if isinstance(n, SourceNode): + print(f" label: {n.label}, stream: {type(n.stream).__name__}") + +# --- Access nodes by label --- +print(f"\n pipeline.join_data -> {type(pipeline.join_data).__name__}") +print(f" pipeline.compute_risk -> {type(pipeline.compute_risk).__name__}") +print(f" pipeline.categorize -> {type(pipeline.categorize).__name__}") + +# --- Node types --- +assert isinstance(pipeline.join_data, OperatorNode) +assert isinstance(pipeline.compute_risk, FunctionNode) +assert isinstance(pipeline.categorize, FunctionNode) +print("\n All node types verified.") + +# --- Run the pipeline --- +print("\n Running pipeline...") +pipeline.run() +print(" Done.") + +# --- Inspect results --- +risk_table = pipeline.compute_risk.as_table() +print(f"\n Risk scores:") +print(f" {risk_table.to_pandas()[['patient_id', 'risk']].to_string(index=False)}") + +cat_table = pipeline.categorize.as_table() +print(f"\n Categories:") +print(f" {cat_table.to_pandas()[['patient_id', 'category']].to_string(index=False)}") + +# --- Show what's in the database --- +fn_records = pipeline.compute_risk.get_all_records() +print(f" Function records (compute_risk): {fn_records.num_rows} rows") + +cat_records = pipeline.categorize.get_all_records() +print(f" Function records (categorize): {cat_records.num_rows} rows") + + +# ============================================================ +# PART 2: Persistence with DeltaTableDatabase +# ============================================================ +print("\n" + "=" * 70) +print("PART 2: Pipeline with DeltaTableDatabase (persistent storage)") +print("=" * 70) + +with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "pipeline_db" + + # --- First run: compute everything --- + print("\n --- First run ---") + delta_db = DeltaTableDatabase(base_path=db_path) + + pipe1 = Pipeline(name="persistent_demo", pipeline_database=delta_db) + with pipe1: + joined = patients.join(labs, label="joiner") + risk_pod(joined, label="scorer") + + pipe1.run() + + result = pipe1.scorer.as_table() + print(f" Computed {result.num_rows} risk scores:") + print(f" {result.to_pandas()[['patient_id', 'risk']].to_string(index=False)}") + + # Show files on disk + delta_tables = list(db_path.rglob("*.parquet")) + print(f"\n Parquet files on disk: {len(delta_tables)}") + for f in sorted(delta_tables): + print(f" {f.relative_to(db_path)}") + + # --- Second run: data already cached --- + print("\n --- Second run (same data -> reads from cache) ---") + delta_db_2 = DeltaTableDatabase(base_path=db_path) + + pipe2 = Pipeline(name="persistent_demo", pipeline_database=delta_db_2) + with pipe2: + joined = patients.join(labs, label="joiner") + risk_pod(joined, label="scorer") + + pipe2.run() + + result2 = pipe2.scorer.as_table() + print(f" Retrieved {result2.num_rows} risk scores (from cache):") + print(f" {result2.to_pandas()[['patient_id', 'risk']].to_string(index=False)}") + + # --- Third run: add new data -> only new rows computed --- + print("\n --- Third run (new patient added -> incremental computation) ---") + patients_v2 = ArrowTableSource( + pa.table( + { + "patient_id": pa.array( + ["p1", "p2", "p3", "p4"], type=pa.large_string() + ), + "age": pa.array([30, 45, 60, 25], type=pa.int64()), + } + ), + tag_columns=["patient_id"], + ) + labs_v2 = ArrowTableSource( + pa.table( + { + "patient_id": pa.array( + ["p1", "p2", "p3", "p4"], type=pa.large_string() + ), + "cholesterol": pa.array([180, 220, 260, 150], type=pa.int64()), + } + ), + tag_columns=["patient_id"], + ) + + delta_db_3 = DeltaTableDatabase(base_path=db_path) + + pipe3 = Pipeline(name="persistent_demo", pipeline_database=delta_db_3) + with pipe3: + joined = patients_v2.join(labs_v2, label="joiner") + risk_pod(joined, label="scorer") + + pipe3.run() + + result3 = pipe3.scorer.as_table() + print(f" Total risk scores after incremental run: {result3.num_rows}") + print(f" {result3.to_pandas()[['patient_id', 'risk']].to_string(index=False)}") + print(" (p4 was computed fresh; p1-p3 were already in the cache)") + + +# ============================================================ +# PART 3: Convenience methods in pipelines +# ============================================================ +print("\n" + "=" * 70) +print("PART 3: Convenience methods (.join, .select_packet_columns, .map_packets)") +print("=" * 70) + +db3 = InMemoryArrowDatabase() + +pipe = Pipeline(name="convenience_demo", pipeline_database=db3) + +with pipe: + # .join() on a source + joined = patients.join(labs, label="join_data") + # .select_packet_columns() to keep only "age" + ages_only = joined.select_packet_columns(["age"], label="select_age") + # .map_packets() to rename a column + renamed = ages_only.map_packets({"age": "patient_age"}, label="rename_col") + # function pod on the renamed stream + # (categorize expects "risk" but let's just show the chain works) + +print("\n Compiled nodes from chained convenience methods:") +for name, node in pipe.compiled_nodes.items(): + print(f" {name}: {type(node).__name__}") + +pipe.run() + +renamed_table = pipe.rename_col.as_table() +print(f"\n After select + rename:") +print(f" columns: {renamed_table.column_names}") +print(f" {renamed_table.to_pandas().to_string(index=False)}") + + +# ============================================================ +# PART 4: Separate function database +# ============================================================ +print("\n" + "=" * 70) +print("PART 4: Separate function_database for result isolation") +print("=" * 70) + +pipeline_db = InMemoryArrowDatabase() +function_db = InMemoryArrowDatabase() + +pipe = Pipeline( + name="isolated", + pipeline_database=pipeline_db, + function_database=function_db, +) + +with pipe: + joined = patients.join(labs, label="joiner") + risk_pod(joined, label="scorer") + +pipe.run() + +# Function results are stored in function_db, not pipeline_db +fn_node = pipe.scorer +print(f"\n pipeline_database and function_database are separate objects:") +print( + f" function result DB is function_db: " + f"{fn_node._packet_function._result_database is function_db}" +) + +# Show the record_path prefix includes the pipeline name +record_path = fn_node._packet_function.record_path +print(f"\n Function result record_path: {record_path}") + +# When function_database is None, results go under pipeline_name/_results +pipe_shared = Pipeline( + name="shared", + pipeline_database=pipeline_db, + function_database=None, # explicit None +) + +with pipe_shared: + joined = patients.join(labs, label="joiner") + risk_pod(joined, label="scorer") + +shared_record_path = pipe_shared.scorer._packet_function.record_path +print(f" Shared DB record_path: {shared_record_path}") +print( + f" Starts with ('shared', '_results'): " + f"{shared_record_path[:2] == ('shared', '_results')}" +) + + +# ============================================================ +# SUMMARY +# ============================================================ +print("\n" + "=" * 70) +print("SUMMARY") +print("=" * 70) +print(""" + Pipeline wraps ALL nodes as persistent variants automatically: + - Leaf streams -> SourceNode (graph vertex wrapper, no caching) + - Operator calls -> OperatorNode (DB-backed cache) + - Function pod calls -> FunctionNode (DB-backed cache) + + Building a pipeline (using stream convenience methods): + pipeline = Pipeline(name="my_pipe", pipeline_database=db) + with pipeline: + joined = src_a.join(src_b, label="my_join") + selected = joined.select_packet_columns(["col_a"], label="select") + pod(selected, label="my_func") + pipeline.run() # executes in topological order + + Available convenience methods on any stream/source: + stream.join(other) # Join + stream.semi_join(other) # SemiJoin + stream.map_tags({"old": "new"}) # MapTags + stream.map_packets({"a": "b"}) # MapPackets + stream.select_tag_columns([..]) # SelectTagColumns + stream.select_packet_columns(..) # SelectPacketColumns + stream.drop_tag_columns([..]) # DropTagColumns + stream.drop_packet_columns([..]) # DropPacketColumns + stream.batch(batch_size=N) # Batch + stream.polars_filter(col="val") # PolarsFilter + + Accessing results: + pipeline.my_join # -> OperatorNode + pipeline.my_func # -> FunctionNode + pipeline.my_func.as_table() # -> PyArrow Table with results + + Persistence: + - InMemoryArrowDatabase: fast, data lost when process exits + - DeltaTableDatabase: data persists to disk as Delta Lake tables + - Re-running with DeltaTableDatabase reads from cache; + new rows are computed incrementally + + Function database: + - function_database=None -> results stored under pipeline_name/_results/ + - function_database=db -> results stored in separate database +""") diff --git a/design/async-execution-implementation-plan.md b/design/async-execution-implementation-plan.md new file mode 100644 index 00000000..fb1f656b --- /dev/null +++ b/design/async-execution-implementation-plan.md @@ -0,0 +1,331 @@ +# Async Execution System — Implementation Plan + +**Design doc:** `design/async-execution-system.md` + +--- + +## Phase 1: Foundation (Channels, Protocols, Config) + +No existing code is modified. All new files. + +### Step 1.1 — Channel primitives + +**New file:** `src/orcapod/core/execution/channels.py` + +- `Channel[T]` — bounded async queue with close/done signaling +- `ReadableChannel[T]` — consumer side: `receive()`, `__aiter__`, `collect()` +- `WritableChannel[T]` — producer side: `send()`, `close()` +- `BroadcastChannel[T]` — fan-out: one writer, multiple independent readers +- `ChannelClosed` exception +- `create_channel(buffer_size: int) -> Channel` + +**Tests:** `tests/test_core/test_execution/test_channels.py` +- Single producer / single consumer +- Backpressure (full buffer blocks send) +- Close semantics (receive after close drains then raises) +- Broadcast (multiple readers get all items) +- Cancellation safety + +### Step 1.2 — Async execution protocol + +**New file:** `src/orcapod/protocols/core_protocols/async_execution.py` + +- `AsyncExecutableProtocol` — single `async_execute(inputs, output)` method +- `NodeConfigProtocol` — `max_concurrency` property + +**Modify:** `src/orcapod/protocols/core_protocols/__init__.py` +- Export new protocol + +### Step 1.3 — Configuration types + +**New file:** `src/orcapod/core/execution/config.py` + +- `ExecutorType` enum: `SYNCHRONOUS`, `ASYNC_CHANNELS` +- `PipelineConfig` frozen dataclass: `executor`, `channel_buffer_size`, `default_max_concurrency` +- `NodeConfig` frozen dataclass: `max_concurrency` +- `resolve_concurrency(node_config, pipeline_config) -> int | None` + +**Tests:** `tests/test_core/test_execution/test_config.py` +- NodeConfig overrides PipelineConfig default +- None means unlimited + +### Step 1.4 — Execution module init + +**New file:** `src/orcapod/core/execution/__init__.py` + +- Re-export public API: `Channel`, `ReadableChannel`, `WritableChannel`, + `PipelineConfig`, `NodeConfig`, `ExecutorType` + +--- + +## Phase 2: Default `async_execute` on Base Classes + +Add default barrier-mode `async_execute` to every base class. No behavioral change to existing +sync execution — this just makes every node async-capable. + +### Step 2.1 — Helper: materialize rows to stream + +**New file:** `src/orcapod/core/execution/materialization.py` + +- `materialize_to_stream(rows: list[tuple[TagProtocol, PacketProtocol]]) -> ArrowTableStream` + — converts a list of (tag, packet) pairs back into an ArrowTableStream +- `stream_to_rows(stream: StreamProtocol) -> list[tuple[TagProtocol, PacketProtocol]]` + — the inverse (thin wrapper around `iter_packets`) + +**Tests:** `tests/test_core/test_execution/test_materialization.py` +- Round-trip: stream → rows → stream preserves schema and data +- Empty stream round-trip + +### Step 2.2 — Default `async_execute` on `StaticOutputPod` + +**Modify:** `src/orcapod/core/static_output_pod.py` + +- Add `async_execute(self, inputs, output)` method to `StaticOutputPod`: + - Collects all input channels + - Materializes to streams + - Calls `self.static_process(*streams)` + - Emits results to output channel + - Closes output + +This gives ALL operators (Unary, Binary, NonZeroInput) a working async_execute by default. + +**Tests:** `tests/test_core/test_execution/test_barrier_default.py` +- Run a unary operator (e.g., Batch) through async_execute, compare output to static_process +- Run a binary operator (e.g., MergeJoin) through async_execute +- Run a multi-input operator (e.g., Join) through async_execute +- All should produce identical results to sync mode + +### Step 2.3 — `async_execute` on `_FunctionPodBase` + +**Modify:** `src/orcapod/core/function_pod.py` + +- Add `async_execute` to `_FunctionPodBase` (barrier mode by default) +- Add `async_execute` to `FunctionPod` (streaming mode with semaphore) +- Add `async_execute` to `FunctionNode` (streaming with cache check + semaphore) +- Add `node_config` property (defaults to `NodeConfig()`) + +**Tests:** `tests/test_core/test_execution/test_function_pod_async.py` +- FunctionPod streaming produces same results as sync +- FunctionNode with cache hits emits without semaphore +- max_concurrency=1 preserves ordering +- max_concurrency=N allows N concurrent invocations + +### Step 2.4 — `async_execute` on source nodes + +**Modify:** `src/orcapod/core/tracker.py` (SourceNode) + +- Add `async_execute` to `SourceNode`: iterates `self.stream.iter_packets()`, sends to output +- No input channels consumed + +**Tests:** `tests/test_core/test_execution/test_source_async.py` +- Source pushes all rows to output channel +- Empty source closes immediately + +--- + +## Phase 3: Orchestrator + +### Step 3.1 — DAG compilation for async execution + +**New file:** `src/orcapod/core/execution/dag.py` + +- `CompiledDAG` — nodes, edges, topological order, terminal node +- `compile_for_async(tracker: GraphTracker) -> CompiledDAG` + — takes an existing compiled GraphTracker and produces the async DAG structure + — identifies fan-out points (node output feeds multiple downstreams) for broadcast channels + +**Tests:** `tests/test_core/test_execution/test_dag.py` +- Linear pipeline: Source → Op → FunctionPod +- Diamond: Source → [Op1, Op2] → Join +- Fan-out detection + +### Step 3.2 — Async pipeline orchestrator + +**New file:** `src/orcapod/core/execution/orchestrator.py` + +- `AsyncPipelineOrchestrator` + - `run(graph, config) -> StreamProtocol` — entry point, calls `asyncio.run` + - `_run_async(graph, config)` — creates channels, launches tasks, collects result + - Error propagation via TaskGroup + - Timeout support (optional) + +**Tests:** `tests/test_core/test_execution/test_orchestrator.py` +- End-to-end: Source → Filter → FunctionPod via async orchestrator +- End-to-end: two Sources → Join → Map via async orchestrator +- Compare results to synchronous execution +- Error in one node cancels all others +- Backpressure: slow consumer throttles producer + +--- + +## Phase 4: Streaming Overrides for Concrete Operators + +Each step is independent — can be done in any order or in parallel. + +### Step 4.1 — Streaming column selection operators + +**Modify:** `src/orcapod/core/operators/column_selection.py` + +- Override `async_execute` on `SelectTagColumns`, `SelectPacketColumns`, + `DropTagColumns`, `DropPacketColumns` +- Each: iterate input, project/drop columns per row, emit + +**Tests:** `tests/test_core/test_execution/test_streaming_operators.py` +- Compare streaming async output to sync output for each operator +- Verify row-by-row emission (no buffering) + +### Step 4.2 — Streaming mappers + +**Modify:** `src/orcapod/core/operators/mappers.py` + +- Override `async_execute` on `MapTags`, `MapPackets` +- Each: iterate input, rename columns per row, emit + +**Tests:** added to `test_streaming_operators.py` + +### Step 4.3 — Streaming filter + +**Modify:** `src/orcapod/core/operators/filters.py` + +- Override `async_execute` on `PolarsFilter` +- Evaluate predicate per row, emit if passes + +**Tests:** added to `test_streaming_operators.py` + +### Step 4.4 — Incremental Join + +**Modify:** `src/orcapod/core/operators/join.py` + +- Override `async_execute` with symmetric hash join +- Concurrent consumption of all inputs via TaskGroup +- Per-row index probing and immediate emission +- System tag extension logic (reuse existing `_extend_system_tag_columns` logic) + +**Tests:** `tests/test_core/test_execution/test_incremental_join.py` +- Same result set as sync join (order may differ, compare as sets) +- Interleaved arrival from multiple inputs +- Single-input join (degenerates to pass-through) + +### Step 4.5 — Incremental MergeJoin + +**Modify:** `src/orcapod/core/operators/merge_join.py` + +- Override `async_execute` with symmetric hash join + list merge for colliding columns + +**Tests:** `tests/test_core/test_execution/test_incremental_merge_join.py` + +### Step 4.6 — Incremental SemiJoin + +**Modify:** `src/orcapod/core/operators/semijoin.py` + +- Override `async_execute`: buffer right side fully, then stream left + +**Tests:** `tests/test_core/test_execution/test_incremental_semijoin.py` + +--- + +## Phase 5: Integration and Wiring + +### Step 5.1 — Pipeline-level API + +**Determine integration point:** How does a user trigger async execution? + +Option A — `GraphTracker` gains a `run(config)` method: +```python +with GraphTracker() as tracker: + result = source | filter_op | func_pod +tracker.run(PipelineConfig(executor=ExecutorType.ASYNC_CHANNELS)) +``` + +Option B — A top-level `run_pipeline` function: +```python +result = run_pipeline(terminal_stream, config=PipelineConfig(...)) +``` + +The exact API will be determined during implementation. The orchestrator internals are +independent of this choice. + +### Step 5.2 — NodeConfig attachment + +Allow `NodeConfig` to be attached to operators/function pods: + +```python +func_pod = FunctionPod(my_func, node_config=NodeConfig(max_concurrency=4)) +filter_op = PolarsFilter(predicate, node_config=NodeConfig(max_concurrency=None)) +``` + +**Modify:** `StaticOutputPod.__init__`, `_FunctionPodBase.__init__` +- Accept optional `node_config: NodeConfig` parameter +- Default: `NodeConfig()` (inherit pipeline default) + +### Step 5.3 — End-to-end integration tests + +**New file:** `tests/test_core/test_execution/test_integration.py` + +- Full pipeline: Source → Filter → FunctionPod → Join → Map + - Run sync, run async, compare results +- Pipeline with mixed strategies: streaming filter + barrier batch + streaming map +- Pipeline with database-backed FunctionNode +- Concurrency behavior: verify max_concurrency limits are respected + +--- + +## Implementation Order and Dependencies + +``` +Phase 1 (Foundation) + ├── 1.1 Channels ──────────────────┐ + ├── 1.2 Protocol ──────────────────┤ + ├── 1.3 Config ────────────────────┤ + └── 1.4 Module init ──────────────┘ + │ +Phase 2 (Defaults) ▼ + ├── 2.1 Materialization helpers ───┐ + ├── 2.2 StaticOutputPod default ───┤ (depends on 1.x + 2.1) + ├── 2.3 FunctionPod async ─────────┤ + └── 2.4 SourceNode async ─────────┘ + │ +Phase 3 (Orchestrator) ▼ + ├── 3.1 DAG compilation ───────────┐ (depends on 2.x) + └── 3.2 Orchestrator ─────────────┘ + │ +Phase 4 (Streaming Overrides) ▼ + ├── 4.1 Column selection ──────────┐ + ├── 4.2 Mappers ───────────────────┤ (all independent, depend on 3.x) + ├── 4.3 Filter ────────────────────┤ + ├── 4.4 Join ──────────────────────┤ + ├── 4.5 MergeJoin ─────────────────┤ + └── 4.6 SemiJoin ──────────────────┘ + │ +Phase 5 (Integration) ▼ + ├── 5.1 Pipeline API ─────────────┐ + ├── 5.2 NodeConfig attachment ─────┤ (depends on 4.x) + └── 5.3 Integration tests ────────┘ +``` + +Phases 1–3 must be sequential. Phase 4 steps are independent of each other. +Phase 5 depends on everything above. + +--- + +## Risk Assessment + +| Risk | Mitigation | +|---|---| +| Row ordering differs between sync/async | Document clearly; `sort_by_tags` provides determinism | +| Incremental Join correctness | Extensive property-based tests comparing to sync | +| Deadlocks from channel misuse | Strict rule: every node MUST close output channel | +| Per-row Datagram operations are slow | Benchmark; fall back to barrier if perf regresses | +| Breaking existing tests | async_execute is additive; sync path unchanged | +| Fan-out channel memory | Bounded buffers + backpressure limit memory | + +--- + +## What's NOT in Scope + +- Distributed execution (network channels, Ray integration) — future work +- Adaptive concurrency tuning — future work +- Checkpointing / fault recovery — future work +- Modifications to `PacketFunctionExecutorProtocol` — orthogonal concern, unchanged +- Changes to hashing / identity — unchanged +- Changes to `CacheMode` semantics — unchanged diff --git a/design/async-execution-system.md b/design/async-execution-system.md new file mode 100644 index 00000000..ec3bfb14 --- /dev/null +++ b/design/async-execution-system.md @@ -0,0 +1,529 @@ +# Unified Async Channel Execution System + +**Status:** Proposed +**Date:** 2026-03-04 + +--- + +## Motivation + +The current execution model is synchronous and pull-based: each node materializes its full +output before the downstream node begins. This means a pipeline like +`Source → Filter → FunctionPod → Join → Map` processes in discrete stages — Filter waits for +all source rows, FunctionPod waits for all filtered rows, etc. + +This design proposes a **push-based async channel execution model** where every pipeline node +is a coroutine that consumes from input channels and produces to output channels. Rows flow +through the pipeline as soon as they're available, enabling: + +- **Streaming**: row-by-row operators (Filter, Map, Select, FunctionPod) emit immediately + without buffering +- **Incremental computation**: multi-input operators (Join) can emit matches as rows arrive + from any input, using techniques like symmetric hash join +- **Controlled concurrency**: per-node `max_concurrency` limits enable rate-limiting for + external API calls or GPU inference while allowing trivial operators to run unbounded +- **Backpressure**: bounded channels naturally throttle fast producers when downstream + consumers are slow + +Critically, the design is **backwards-compatible** — every existing synchronous operator works +unchanged via a barrier wrapper, and the executor type is selected at the pipeline level. + +--- + +## Core Design: One Interface for All Nodes + +### The Async Execute Protocol + +Every pipeline node — source, operator, or function pod — implements a single method: + +```python +@runtime_checkable +class AsyncExecutableProtocol(Protocol): + async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + ) -> None: + """ + Consume (tag, packet) pairs from input channels, produce to output channel. + MUST close output channel when done (signals completion to downstream). + """ + ... +``` + +The orchestrator sees a **homogeneous DAG** — it doesn't need to know whether a node is an +operator, function pod, or source. It just wires up channels and launches tasks. + +### Channel Abstraction + +Channels are bounded async queues with close semantics: + +```python +@dataclass +class Channel(Generic[T]): + """Bounded async channel with close/done signaling.""" + _queue: asyncio.Queue[T | _Sentinel] + _closed: asyncio.Event + + @property + def reader(self) -> ReadableChannel[T]: ... + + @property + def writer(self) -> WritableChannel[T]: ... + + +class ReadableChannel(Protocol[T]): + """Consumer side of a channel.""" + + async def receive(self) -> T: + """Receive next item. Raises ChannelClosed when done.""" + ... + + def __aiter__(self) -> AsyncIterator[T]: ... + async def __anext__(self) -> T: ... + + async def collect(self) -> list[T]: + """Drain all remaining items into a list.""" + ... + + +class WritableChannel(Protocol[T]): + """Producer side of a channel.""" + + async def send(self, item: T) -> None: + """Send an item. Blocks if channel buffer is full (backpressure).""" + ... + + async def close(self) -> None: + """Signal that no more items will be sent.""" + ... +``` + +Bounded channels provide natural backpressure: a fast producer blocks on `send()` when the +buffer is full, automatically throttling without explicit flow control. + +### Thread Safety + +Channels are backed by `asyncio.Queue`, which is **coroutine-safe but not thread-safe**. +This is sufficient because all channel operations happen on the event loop thread: + +- `async_execute` methods are coroutines running on the event loop +- Sync `PacketFunction`s run in thread pools via `loop.run_in_executor`, but the result + is awaited back on the event loop before `output.send()` is called — the channel is + never touched from a worker thread +- The `async def` signature on `send()`/`receive()` structurally prevents direct calls + from non-async (thread) contexts + +If a future executor needs to push results directly from worker threads (bypassing the event +loop), channels should be swapped to a dual sync/async queue (e.g., `janus`) or use +`loop.call_soon_threadsafe` to marshal back to the event loop. This is not needed for the +current design. + +--- + +## Three Execution Strategies + +All three strategies implement the same `async_execute` interface. The differences are purely +in **when** the node reads, **how much** it buffers, and **when** it emits. + +### 1. Streaming (Row-by-Row) + +**Applies to:** Filter, MapTags, MapPackets, Select/Drop columns, FunctionPod + +Zero buffering. Each input row is independently transformed and emitted immediately. + +```python +# Example: PolarsFilter +async def async_execute(self, inputs, output): + async for tag, packet in inputs[0]: + if self._evaluate_predicate(tag, packet): + await output.send((tag, packet)) + await output.close() + +# Example: FunctionPod with concurrency control +async def async_execute(self, inputs, output): + sem = asyncio.Semaphore(self.node_config.max_concurrency or _INF) + + async def process_one(tag, packet): + async with sem: + result = await self.packet_function.async_call(packet) + if result is not None: + await output.send((tag, result)) + + async with asyncio.TaskGroup() as tg: + async for tag, packet in inputs[0]: + tg.create_task(process_one(tag, packet)) + + await output.close() +``` + +### 2. Incremental (Stateful, Eager Emit) + +**Applies to:** Join, MergeJoin, SemiJoin + +Maintains internal state (hash indexes). Emits matches as soon as they're found. + +```python +# Example: Symmetric Hash Join +async def async_execute(self, inputs, output): + indexes: list[dict[JoinKey, list[Row]]] = [{} for _ in inputs] + + async def consume(i: int, channel): + async for tag, packet in channel: + key = self._extract_join_key(tag) + indexes[i].setdefault(key, []).append((tag, packet)) + + # Probe all OTHER indexes for matches + other_lists = [indexes[j].get(key, []) for j in range(len(inputs)) if j != i] + for combo in itertools.product(*other_lists): + joined = self._merge_rows((tag, packet), *combo) + await output.send(joined) + + async with asyncio.TaskGroup() as tg: + for i, ch in enumerate(inputs): + tg.create_task(consume(i, ch)) + + await output.close() +``` + +For SemiJoin (non-commutative), the right side is buffered first, then left rows are probed: + +```python +async def async_execute(self, inputs, output): + left, right = inputs + + # Phase 1: Build right-side index + right_keys = set() + async for tag, packet in right: + key = self._extract_join_key(tag) + right_keys.add(key) + + # Phase 2: Stream left, emit matches + async for tag, packet in left: + key = self._extract_join_key(tag) + if key in right_keys: + await output.send((tag, packet)) + + await output.close() +``` + +### 3. Barrier (Fully Synchronous, Wrapped) + +**Applies to:** Batch, or any operator that hasn't implemented `async_execute` + +Collects all input, runs existing `static_process`, emits results. This is the **default** +implementation on operator base classes — every existing operator works without modification. + +```python +async def async_execute(self, inputs, output): + # Phase 1: Collect all inputs (the barrier) + collected = [await ch.collect() for ch in inputs] + + # Phase 2: Materialize into streams, run sync logic + streams = [self._materialize(rows) for rows in collected] + result_stream = self.static_process(*streams) + + # Phase 3: Emit results asynchronously + for tag, packet in result_stream.iter_packets(): + await output.send((tag, packet)) + + await output.close() +``` + +The barrier is a **local** bottleneck — upstream streaming nodes still push rows into the +barrier's input channel as they're produced, and downstream nodes receive rows as soon as +the barrier emits them. + +--- + +## Default Implementations (Backwards Compatibility) + +Operator base classes provide a default `async_execute` that wraps `static_process` in the +barrier pattern. Existing operators work without any changes: + +```python +class UnaryOperator(StaticOutputPod): + """Default: barrier mode. Override async_execute for streaming.""" + + async def async_execute(self, inputs, output): + rows = await inputs[0].collect() + stream = self._materialize_to_stream(rows) + result = self.static_process(stream) + for tag, packet in result.iter_packets(): + await output.send((tag, packet)) + await output.close() + + +class BinaryOperator(StaticOutputPod): + async def async_execute(self, inputs, output): + left_rows, right_rows = await asyncio.gather( + inputs[0].collect(), inputs[1].collect() + ) + left_stream = self._materialize_to_stream(left_rows) + right_stream = self._materialize_to_stream(right_rows) + result = self.static_process(left_stream, right_stream) + for tag, packet in result.iter_packets(): + await output.send((tag, packet)) + await output.close() + + +class NonZeroInputOperator(StaticOutputPod): + async def async_execute(self, inputs, output): + all_rows = await asyncio.gather(*(ch.collect() for ch in inputs)) + streams = [self._materialize_to_stream(rows) for rows in all_rows] + result = self.static_process(*streams) + for tag, packet in result.iter_packets(): + await output.send((tag, packet)) + await output.close() +``` + +Concrete operators **opt into** better strategies by overriding `async_execute`. + +--- + +## FunctionPod and FunctionNode + +FunctionPod fits the streaming pattern naturally — it processes packets independently: + +```python +class FunctionPod: + async def async_execute(self, inputs, output): + sem = asyncio.Semaphore(self.node_config.max_concurrency or _INF) + + async def process_one(tag, packet): + async with sem: + result_packet = await self.packet_function.async_call(packet) + if result_packet is not None: + await output.send((tag, result_packet)) + + async with asyncio.TaskGroup() as tg: + async for tag, packet in inputs[0]: + tg.create_task(process_one(tag, packet)) + + await output.close() +``` + +FunctionNode adds DB-backed caching — cache hits emit immediately, misses go through the +semaphore: + +```python +class FunctionNode: + async def async_execute(self, inputs, output): + sem = asyncio.Semaphore(self.node_config.max_concurrency or _INF) + + async def process_one(tag, packet): + cache_key = self._compute_cache_key(packet) + cached = await self._db_lookup(cache_key) + if cached is not None: + await output.send((tag, cached)) + return + + async with sem: + result = await self.packet_function.async_call(packet) + await self._db_store(cache_key, result) + if result is not None: + await output.send((tag, result)) + + async with asyncio.TaskGroup() as tg: + async for tag, packet in inputs[0]: + tg.create_task(process_one(tag, packet)) + + await output.close() +``` + +### Sync PacketFunctions + +Existing synchronous `PacketFunction`s are bridged via `run_in_executor`: + +```python +class PythonPacketFunction: + async def direct_async_call(self, packet): + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + self._thread_pool, + self._func, + packet, + ) +``` + +CPU-bound functions run in a thread pool. Async-native functions (API calls, I/O) can +override `direct_async_call` directly. + +--- + +## Configuration + +### Two-Level Config + +```python +class ExecutorType(Enum): + SYNCHRONOUS = "synchronous" # Current behavior: static_process chain + ASYNC_CHANNELS = "async_channels" # New: async_execute with channels + +@dataclass(frozen=True) +class PipelineConfig: + executor: ExecutorType = ExecutorType.SYNCHRONOUS + channel_buffer_size: int = 64 + default_max_concurrency: int | None = None # pipeline-wide default + +@dataclass(frozen=True) +class NodeConfig: + max_concurrency: int | None = None # overrides pipeline default + # None = inherit from pipeline default + # 1 = sequential (rate-limited APIs, ordered output) + # N = up to N packets in-flight concurrently +``` + +### Concurrency Resolution + +```python +def resolve_concurrency(node_config: NodeConfig, pipeline_config: PipelineConfig) -> int | None: + if node_config.max_concurrency is not None: + return node_config.max_concurrency + return pipeline_config.default_max_concurrency +``` + +Examples: +- `max_concurrency=1`: sequential processing (rate-limited API, preserves ordering) +- `max_concurrency=8`: bounded parallelism (GPU inference, external service) +- `max_concurrency=None` (unlimited): trivial ops (column select, rename) + +--- + +## Orchestrator + +The orchestrator builds the DAG, creates channels, and launches all nodes concurrently: + +```python +class AsyncPipelineOrchestrator: + + def run(self, graph: CompiledGraph, config: PipelineConfig) -> StreamProtocol: + """Entry point — runs async pipeline, returns materialized result.""" + return asyncio.run(self._run_async(graph, config)) + + async def _run_async(self, graph, config): + buf = config.channel_buffer_size + + # Create a channel for each edge in the DAG + channels: dict[EdgeId, Channel] = { + edge: Channel(buffer_size=buf) for edge in graph.edges + } + + # Launch every node as a concurrent task + async with asyncio.TaskGroup() as tg: + for node in graph.nodes: + input_chs = [channels[e].reader for e in node.input_edges] + output_ch = channels[node.output_edge].writer + tg.create_task(node.async_execute(input_chs, output_ch)) + + # Collect terminal output + terminal_rows = await channels[graph.terminal_edge].collect() + return self._materialize(terminal_rows) +``` + +### Source Nodes + +Sources have no input channels — they just push their data onto the output channel: + +```python +class SourceNode: + async def async_execute(self, inputs, output): + # inputs is empty for sources + for tag, packet in self.stream.iter_packets(): + await output.send((tag, packet)) + await output.close() +``` + +### Fan-Out (Multiple Consumers) + +When a node's output feeds multiple downstream nodes, the channel is **broadcast** — each +downstream gets its own reader over a shared sequence. This avoids duplicating computation +while allowing each consumer to read at its own pace. + +--- + +## Operator Classification + +| Operator | Default Strategy | Async Override? | +|---|---|---| +| PolarsFilter | Barrier (inherited) | **Streaming** — evaluate predicate per row | +| MapTags / MapPackets | Barrier (inherited) | **Streaming** — rename per row | +| SelectTagColumns / SelectPacketColumns | Barrier (inherited) | **Streaming** — project per row | +| DropTagColumns / DropPacketColumns | Barrier (inherited) | **Streaming** — project per row | +| FunctionPod | N/A (new) | **Streaming** — transform packet per row | +| FunctionNode | N/A (new) | **Streaming** — cache check + transform per row | +| Join | Barrier (inherited) | **Incremental** — symmetric hash join | +| MergeJoin | Barrier (inherited) | **Incremental** — symmetric hash join with merge | +| SemiJoin | Barrier (inherited) | **Incremental** — buffer right, stream left | +| Batch | Barrier (inherited) | Barrier (inherent) — needs all rows for grouping | + +All operators work in barrier mode by default. Streaming/incremental overrides are added +incrementally — the system is correct at every step. + +--- + +## Interaction with Existing Execution Models + +### Synchronous Mode (ExecutorType.SYNCHRONOUS) + +Unchanged. The existing `static_process` / `DynamicPodStream` / `iter_packets` chain continues +to work exactly as before. `async_execute` is never called. + +### Async Mode (ExecutorType.ASYNC_CHANNELS) + +The orchestrator calls `async_execute` on every node. The existing `static_process` is used +by the barrier wrapper as an implementation detail — it's not called directly by the +orchestrator. + +### PacketFunctionExecutorProtocol + +The existing executor protocol (`execute` / `async_execute` for individual packets) remains +unchanged. It controls **how a single packet function invocation runs** (local, Ray, etc.). +The new `async_execute` on nodes controls **how the node participates in the pipeline DAG**. +These are orthogonal concerns: + +- `PacketFunctionExecutorProtocol.async_execute(fn, packet)` → single invocation strategy +- `FunctionPod.async_execute(inputs, output)` → pipeline-level data flow + +### GraphTracker + +The `GraphTracker` continues to build the DAG via `record_*_invocation` calls. The compiled +graph it produces is what the `AsyncPipelineOrchestrator` consumes. The tracker doesn't need +to know about async execution — it only records topology. + +--- + +## Row Ordering Considerations + +Streaming and incremental strategies may change row ordering compared to synchronous mode: + +- **Streaming with concurrency**: `max_concurrency > 1` on FunctionPod means packets may + complete out of order. If ordering matters, set `max_concurrency=1`. +- **Incremental Join**: rows are emitted as matches are found, which depends on arrival order + from upstream. The result set is identical but row order may differ. +- **Barrier**: row order matches synchronous mode exactly. + +The `sort_by_tags` option in `ColumnConfig` provides deterministic ordering when needed, +independent of execution strategy. + +--- + +## Error Propagation + +When a node raises an exception inside `async_execute`: + +1. The `TaskGroup` propagates the exception, cancelling all other tasks +2. Channel close semantics ensure no deadlocks — cancelled producers don't block consumers +3. The orchestrator surfaces the original exception to the caller + +This is handled naturally by Python's `asyncio.TaskGroup` semantics. + +--- + +## Future Extensions + +- **Distributed execution**: Replace local channels with network channels (e.g., gRPC streams) + while keeping the same `async_execute` interface +- **Adaptive concurrency**: Auto-tune `max_concurrency` based on throughput/latency metrics +- **Checkpointing**: Persist channel state for fault recovery in long-running pipelines +- **Backpressure metrics**: Expose channel fill levels for monitoring and debugging diff --git a/docs/api/databases.md b/docs/api/databases.md new file mode 100644 index 00000000..728a14e7 --- /dev/null +++ b/docs/api/databases.md @@ -0,0 +1,9 @@ +# Databases + +Database backends for persistent storage of pipeline computation results. + +::: orcapod.databases.DeltaTableDatabase + +::: orcapod.databases.InMemoryArrowDatabase + +::: orcapod.databases.NoOpArrowDatabase diff --git a/docs/api/function-pods.md b/docs/api/function-pods.md new file mode 100644 index 00000000..fc4a27b6 --- /dev/null +++ b/docs/api/function-pods.md @@ -0,0 +1,13 @@ +# Function Pods + +Function pods wrap Python functions to transform individual packets in a stream. + +::: orcapod.core.function_pod.FunctionPod + +::: orcapod.core.function_pod.function_pod + +::: orcapod.core.packet_function.PythonPacketFunction + +::: orcapod.core.packet_function.CachedPacketFunction + +::: orcapod.core.packet_function.PacketFunctionBase diff --git a/docs/api/index.md b/docs/api/index.md new file mode 100644 index 00000000..3d054439 --- /dev/null +++ b/docs/api/index.md @@ -0,0 +1,26 @@ +# API Reference + +This section contains the full API reference for Orcapod, auto-generated from +source code docstrings. + +## Package structure + +The top-level `orcapod` namespace exposes the most commonly used entry points +directly: + +| Symbol | Description | +|--------|-------------| +| [`FunctionPod`](function-pods.md) | Wraps a Python function to transform packets in a stream | +| [`function_pod`](function-pods.md) | Decorator that attaches a `FunctionPod` to a callable | +| [`Pipeline`](pipeline.md) | Top-level orchestration for composing and executing pipelines | + +Everything else lives in subpackages: + +| Subpackage | Description | +|------------|-------------| +| [`orcapod.sources`](sources.md) | Source classes for ingesting external data into pipelines | +| [`orcapod.operators`](operators.md) | Structural stream transformations (join, filter, select, batch, etc.) | +| [`orcapod.databases`](databases.md) | Persistent storage backends for computation results | +| [`orcapod.nodes`](nodes.md) | DB-backed pipeline elements that persist their results | +| [`orcapod.streams`](streams.md) | Immutable (Tag, Packet) sequences backed by PyArrow tables | +| [`orcapod.types`](types.md) | Core type definitions: `Schema`, `ColumnConfig`, `ContentHash` | diff --git a/docs/api/nodes.md b/docs/api/nodes.md new file mode 100644 index 00000000..c9f36a95 --- /dev/null +++ b/docs/api/nodes.md @@ -0,0 +1,9 @@ +# Nodes + +Nodes are DB-backed pipeline elements that persist their computation results. + +::: orcapod.core.nodes.FunctionNode + +::: orcapod.core.nodes.OperatorNode + +::: orcapod.core.nodes.SourceNode diff --git a/docs/api/operators.md b/docs/api/operators.md new file mode 100644 index 00000000..89527a81 --- /dev/null +++ b/docs/api/operators.md @@ -0,0 +1,25 @@ +# Operators + +Operators perform structural transformations on streams without inspecting or synthesizing packet values. + +::: orcapod.core.operators.Join + +::: orcapod.core.operators.MergeJoin + +::: orcapod.core.operators.SemiJoin + +::: orcapod.core.operators.Batch + +::: orcapod.core.operators.SelectTagColumns + +::: orcapod.core.operators.SelectPacketColumns + +::: orcapod.core.operators.DropTagColumns + +::: orcapod.core.operators.DropPacketColumns + +::: orcapod.core.operators.MapTags + +::: orcapod.core.operators.MapPackets + +::: orcapod.core.operators.PolarsFilter diff --git a/docs/api/pipeline.md b/docs/api/pipeline.md new file mode 100644 index 00000000..ad1f2b37 --- /dev/null +++ b/docs/api/pipeline.md @@ -0,0 +1,5 @@ +# Pipeline + +Top-level orchestration for composing and executing Orcapod data pipelines. + +::: orcapod.pipeline.Pipeline diff --git a/docs/api/sources.md b/docs/api/sources.md new file mode 100644 index 00000000..7895999d --- /dev/null +++ b/docs/api/sources.md @@ -0,0 +1,27 @@ +# Sources + +Source classes provide the entry point for external data into Orcapod pipelines. +All sources convert their input to a PyArrow Table and use `SourceStreamBuilder` for +enrichment (provenance columns, system tags, hashing). + +::: orcapod.core.sources.ArrowTableSource + +::: orcapod.core.sources.DictSource + +::: orcapod.core.sources.ListSource + +::: orcapod.core.sources.DataFrameSource + +::: orcapod.core.sources.CSVSource + +::: orcapod.core.sources.DeltaTableSource + +::: orcapod.core.sources.DerivedSource + +::: orcapod.core.sources.CachedSource + +::: orcapod.core.sources.SourceProxy + +::: orcapod.core.sources.SourceRegistry + +::: orcapod.core.sources.RootSource diff --git a/docs/api/streams.md b/docs/api/streams.md new file mode 100644 index 00000000..aff3daac --- /dev/null +++ b/docs/api/streams.md @@ -0,0 +1,7 @@ +# Streams + +Streams are immutable sequences of (Tag, Packet) pairs backed by PyArrow tables. + +::: orcapod.core.streams.ArrowTableStream + +::: orcapod.core.streams.StreamBase diff --git a/docs/api/types.md b/docs/api/types.md new file mode 100644 index 00000000..4dcc76a3 --- /dev/null +++ b/docs/api/types.md @@ -0,0 +1,9 @@ +# Types + +Core type definitions used throughout Orcapod. + +::: orcapod.types.Schema + +::: orcapod.types.ColumnConfig + +::: orcapod.types.ContentHash diff --git a/docs/concepts/function-pods.md b/docs/concepts/function-pods.md new file mode 100644 index 00000000..247e8ab3 --- /dev/null +++ b/docs/concepts/function-pods.md @@ -0,0 +1,162 @@ +# Function Pods + +Function pods are packet-level transforms -- they take each packet in a +[stream](streams.md), apply a Python function to its values, and produce a new packet with the +function's outputs. Unlike [operators](operators.md), function pods never inspect or modify +tags. They are the primary mechanism for adding computation to an Orcapod pipeline: data +cleaning, feature extraction, model inference, or any transformation that produces new values +from existing ones. + +## The `function_pod` decorator + +The most common way to create a function pod is with the `@function_pod` decorator. It wraps +a plain Python function so it can be applied to streams: + +```python +from orcapod import function_pod + +@function_pod(output_keys="bmi") +def compute_bmi(weight: float, height: float) -> float: + return weight / (height ** 2) +``` + +Key points: + +- **`output_keys`** names the output packet column(s). A single string means one output + column; a list of strings means the function returns multiple values. +- **Function parameters** must match the input stream's packet column names. Orcapod uses + the function signature to determine which columns to read. +- **Type annotations** on parameters are used to validate schema compatibility. + +The decorated function still works as a normal Python function. The pod is accessible via the +`.pod` attribute: + +```python +# Call as a normal function +result = compute_bmi(weight=25.3, height=0.12) + +# Access the pod for pipeline use +pod = compute_bmi.pod +``` + +## `FunctionPod` -- lazy in-memory execution + +`FunctionPod` is the pod class created by the decorator. When you call `.process()` on it, it +returns a `FunctionPodStream` -- a lazy stream that applies the function to each packet on +demand: + +```python +from orcapod import function_pod +from orcapod.sources import DictSource + +@function_pod(output_keys="bmi") +def compute_bmi(weight: float, height: float) -> float: + return weight / (height ** 2) + +source = DictSource( + data=[ + {"subject_id": "mouse_01", "weight": 25.3, "height": 0.12}, + {"subject_id": "mouse_02", "weight": 22.1, "height": 0.10}, + ], + tag_columns=["subject_id"], +) + +# Apply the function pod to the source stream +result = compute_bmi.pod(source) # shorthand for compute_bmi.pod.process(source) + +# Inspect the output schema -- tags pass through, packets are replaced +tag_schema, packet_schema = result.output_schema() +print("Tag schema:", dict(tag_schema)) +# Tag schema: {'subject_id': } +print("Packet schema:", dict(packet_schema)) +# Packet schema: {'bmi': } + +# Iterate over results +for tag, packet in result.iter_packets(): + print(f" {tag.as_dict()} -> {packet.as_dict()}") +# {'subject_id': 'mouse_01'} -> {'bmi': 1756.9444444444446} +# {'subject_id': 'mouse_02'} -> {'bmi': 2209.9999999999995} +``` + +The function pod preserves tags and replaces packet columns with the function's output. If the +input stream has multiple packet columns but the function only needs some of them, Orcapod +extracts the matching columns by name. + +!!! tip + All standard pods support `__call__` as a shorthand for `.process()`, so + `compute_bmi.pod(source)` is equivalent to `compute_bmi.pod.process(source)`. + +## `FunctionNode` -- DB-backed cached execution + +For persistent, cached execution, use `FunctionNode`. It wraps a function pod and an input +stream, storing results in a database. This enables two-phase iteration: on subsequent runs, +cached results are returned immediately, and only new inputs are computed. + +```python +from orcapod import function_pod +from orcapod.sources import DictSource +from orcapod.nodes import FunctionNode +from orcapod.databases import InMemoryArrowDatabase + +@function_pod(output_keys="bmi") +def compute_bmi(weight: float, height: float) -> float: + return weight / (height ** 2) + +source = DictSource( + data=[ + {"subject_id": "mouse_01", "weight": 25.3, "height": 0.12}, + {"subject_id": "mouse_02", "weight": 22.1, "height": 0.10}, + ], + tag_columns=["subject_id"], +) + +db = InMemoryArrowDatabase() +node = FunctionNode( + function_pod=compute_bmi.pod, + input_stream=source, + pipeline_database=db, + result_database=db, +) + +# Run the node -- computes and stores all results +node.run() + +# Iterate over cached results +for tag, packet in node.iter_packets(): + print(f" {tag.as_dict()} -> {packet.as_dict()}") +``` + +`FunctionNode` also provides: + +- **`as_source()`** -- returns a `DerivedSource` backed by the node's stored results, which + can be used as input to downstream pipelines +- **`get_all_records()`** -- returns the stored PyArrow Table directly + +## Multiple input streams + +If you pass multiple streams to a function pod, they are automatically joined (using +[Join](operators.md)) before the function is applied: + +```python +result = compute_bmi.pod(weight_stream, height_stream) +``` + +The join happens on shared tag columns, and the merged packet columns are fed to the function. + +## PacketFunction internals + +Under the hood, the `function_pod` decorator creates a `PythonPacketFunction`, which wraps a +Python callable with input/output schema metadata. When a result database is provided, the +packet function is further wrapped in a `CachedPacketFunction` that checks the database before +calling the underlying function. + +These are implementation details -- the `@function_pod` decorator and `FunctionNode` are the +primary user-facing APIs. + +## How it connects to other concepts + +- Function pods consume and produce [Streams](streams.md) +- [Sources](sources.md) produce the input streams that function pods transform +- [Operators](operators.md) handle structural transforms (the complement of function pods) +- Function pods participate in the [identity chain](identity.md) -- each pod's hash includes + the function's identity and its upstream pipeline hashes diff --git a/docs/concepts/identity.md b/docs/concepts/identity.md new file mode 100644 index 00000000..aa4b6286 --- /dev/null +++ b/docs/concepts/identity.md @@ -0,0 +1,123 @@ +# Identity & Hashing + +Every pipeline element in Orcapod -- [sources](sources.md), [streams](streams.md), +[operators](operators.md), and [function pods](function-pods.md) -- carries two parallel +identity hashes. These hashes enable Orcapod to deduplicate computations, scope database +storage, and detect when a pipeline's structure or data has changed. + +## Two identity chains + +### `content_hash()` -- recursive source-inclusive identity + +The content hash captures schema, topology, and the **identity of the sources** feeding the +pipeline. It is computed recursively: each element's content hash depends on its own identity +plus the content hashes of all its upstream elements, all the way back to the sources. + +What "source identity" means depends on the source type. For in-memory sources like +`ArrowTableSource` or `DictSource`, the content hash includes the actual data values. For +storage-backed sources like `DeltaTableSource`, the content hash is derived from the source's +canonical identity (e.g., its path and metadata) rather than the raw data. The key point is +that the content hash changes whenever a different source is used, even if the schema is the +same. + +**Used for:** deduplication and memoization. When a `FunctionNode` processes a packet, it +checks the packet's content hash against its database. If the hash already exists, the cached +result is returned without recomputation. + +### `pipeline_hash()` -- schema and topology only + +The pipeline hash captures the pipeline's **structure** -- schemas, function identities, and +how elements are connected -- but deliberately ignores source identity. Two pipeline elements +with identical schemas and the same computational graph have the same pipeline hash, even if +they are fed by completely different sources. + +**Used for:** database scoping. Pipeline hash determines the database table path where results +are stored. This means that two `FunctionNode` instances with the same function and the same +input schema share the same database table -- regardless of which source produced the data. +This is a powerful feature: it means that running the same function on new data automatically +benefits from results already cached for previous data with the same schema. + +## How it works in practice + +Consider two sources with the same schema but different data: + +``` +source_a = DictSource(data=[{"x": 1, "y": 2}], tag_columns=["x"]) +source_b = DictSource(data=[{"x": 10, "y": 20}], tag_columns=["x"]) +``` + +- `source_a.content_hash() != source_b.content_hash()` -- different source identity +- `source_a.pipeline_hash() == source_b.pipeline_hash()` -- same schema and structure + +If both sources feed into the same function via `FunctionNode`, the nodes share a database +table (same pipeline hash), but each packet is stored and retrieved by its content hash. + +## The Merkle chain + +Pipeline hashes form a **Merkle chain** -- each element's pipeline hash commits to its own +identity plus the pipeline hashes of all its upstream elements. + +### Base case: sources + +A `RootSource`'s pipeline identity is simply its `(tag_schema, packet_schema)`. Sources with +the same column names and types have the same pipeline hash, regardless of their data. + +### Recursive case: downstream elements + +Each downstream element (operator, function pod, or node) computes its pipeline hash from: + +1. Its own identity (e.g., the function's name, version, and output schema for a function pod; + the operator class name for an operator) +2. The pipeline hashes of all its upstream streams + +This creates a chain: changing any element's structure (renaming a column, modifying a +function, adding an operator) changes the pipeline hash of that element and all downstream +elements, while leaving upstream hashes unchanged. + +## The resolver pattern + +Orcapod uses a resolver pattern to determine which hash method to call on different objects: + +- Objects implementing `PipelineElementProtocol` (sources, operators, function pods, streams) + route through `pipeline_hash()` +- Other `ContentIdentifiable` objects (like raw data values) route through `content_hash()` + +This distinction matters when computing pipeline identity structures: a function pod's +pipeline hash depends on its upstream stream's pipeline hash (structural), not its content +hash (data-inclusive). + +## Identity in practice + +### Memoization + +When a `FunctionNode` iterates over its input stream, it follows a two-phase process: + +1. **Phase 1 (cached):** Read all existing records from the shared pipeline database table + and yield them immediately. +2. **Phase 2 (compute):** For each input packet whose content hash is not already in the + database, run the function, store the result, and yield it. + +This means that adding new data to a source only triggers computation for the new rows -- +previously computed results are served from the cache. + +### DB path scoping + +The pipeline hash determines the database table path. Two `FunctionNode` instances that +apply the same function to sources with the same schema will write to and read from the same +table. This is intentional: it maximizes cache reuse across pipeline runs. + +### Change detection + +If you modify a function (change its code, rename its output, bump its version), its identity +changes, which changes the pipeline hash of its node and all downstream nodes. This +automatically creates new database table paths, preventing stale cached results from being +returned for the modified pipeline. + +## How it connects to other concepts + +- [Sources](sources.md) form the base case of the Merkle chain (schema-only for pipeline hash, + source identity for content hash) +- [Streams](streams.md) carry both hashes and propagate them through the pipeline +- [Operators](operators.md) contribute their own identity to the chain +- [Function Pods](function-pods.md) and `FunctionNode` use content hashes for memoization + and pipeline hashes for DB scoping diff --git a/docs/concepts/operators.md b/docs/concepts/operators.md new file mode 100644 index 00000000..5319e56d --- /dev/null +++ b/docs/concepts/operators.md @@ -0,0 +1,191 @@ +# Operators + +Operators are structural transforms that reshape [streams](streams.md) without inspecting or +synthesizing packet values. They join, filter, batch, rename, and select columns -- operations +that affect the *structure* of the data (which rows exist, which columns are present, how +columns are named) but never compute new values from packet content. This is the key +distinction from [function pods](function-pods.md), which do the opposite: they transform +packet values but never touch tags or stream structure. + +## The operator / function pod boundary + +This separation is a core Orcapod design principle: + +| | Operator | Function Pod | +|---|---|---| +| Inspects packet content | Never | Yes | +| Inspects / uses tags | Yes | No | +| Can rename columns | Yes | No | +| Synthesizes new values | No | Yes | +| Stream arity | Configurable (1, 2, or N inputs) | Single in, single out | + +This boundary ensures that structural operations (joins, filters) and value computations +(transformations, model inference) are cleanly separated, making pipelines easier to reason +about and optimize. + +## Operator categories + +### `UnaryOperator` -- single input + +Takes one stream, produces one stream. Used for filtering, column selection, renaming, and +batching. + +### `BinaryOperator` -- two inputs + +Takes exactly two streams. Used for `MergeJoin` and `SemiJoin`. + +### `NonZeroInputOperator` -- one or more inputs + +Takes one or more streams. Used for `Join`, which performs an N-ary inner join. + +## Available operators + +### Join + +N-ary inner join on shared tag columns. Requires that input streams have non-overlapping +packet columns (raises `InputValidationError` on collision). Join is **commutative** -- the +order of input streams does not affect the result. + +```python +from orcapod.sources import DictSource +from orcapod.operators import Join + +subjects = DictSource( + data=[ + {"subject_id": "mouse_01", "age": 12}, + {"subject_id": "mouse_02", "age": 8}, + ], + tag_columns=["subject_id"], +) + +measurements = DictSource( + data=[ + {"subject_id": "mouse_01", "weight": 25.3}, + {"subject_id": "mouse_02", "weight": 22.1}, + ], + tag_columns=["subject_id"], +) + +join = Join() +joined = join.process(subjects, measurements) +print(joined.as_table().to_pandas()) +# subject_id age weight +# 0 mouse_01 12 25.3 +# 1 mouse_02 8 22.1 +``` + +### MergeJoin + +Binary join that handles colliding packet columns by merging their values into sorted +`list[T]`. Both inputs must have the same type for any colliding packet columns. MergeJoin +is **commutative** -- the order of the two input streams does not affect the result. + +### SemiJoin + +Binary join that filters the left stream to only include rows whose tags match the right +stream. The right stream's packet columns are discarded. SemiJoin is **not commutative** -- +the order of inputs matters. The first stream is the one being filtered; the second stream +provides the set of matching tags. + +### Batch + +Groups all rows into a single row (or fixed-size batches), converting column types from `T` +to `list[T]`. Useful for aggregation-style processing. + +```python +from orcapod.sources import DictSource +from orcapod.operators import Batch + +source = DictSource( + data=[ + {"subject_id": "mouse_01", "age": 12}, + {"subject_id": "mouse_02", "age": 8}, + ], + tag_columns=["subject_id"], +) + +batch = Batch() +batched = batch.process(source) +for tag, packet in batched.iter_packets(): + print("Tags:", tag.as_dict()) + # Tags: {'subject_id': ['mouse_01', 'mouse_02']} + print("Packet:", packet.as_dict()) + # Packet: {'age': [12, 8]} +``` + +Pass `batch_size=N` to create fixed-size batches instead of grouping everything: + +```python +batch = Batch(batch_size=10, drop_partial_batch=False) +``` + +### Column selection + +Four operators for including or excluding columns: + +- **`SelectTagColumns(columns=["col1", "col2"])`** -- keep only the specified tag columns +- **`SelectPacketColumns(columns=["col1", "col2"])`** -- keep only the specified packet columns +- **`DropTagColumns(columns=["col1"])`** -- remove the specified tag columns +- **`DropPacketColumns(columns=["col1"])`** -- remove the specified packet columns + +```python +from orcapod.operators import SelectPacketColumns + +select = SelectPacketColumns(columns=["weight"]) +result = select.process(source) +print(result.keys()[1]) # ('weight',) +``` + +### Column renaming + +- **`MapTags(mapping={"old_name": "new_name"})`** -- rename tag columns +- **`MapPackets(mapping={"old_name": "new_name"})`** -- rename packet columns + +### PolarsFilter + +Filter rows using Polars expressions: + +```python +import polars as pl +from orcapod.operators import PolarsFilter + +filt = PolarsFilter(predicates=[pl.col("age") > 10]) +filtered = filt.process(source) +for tag, pkt in filtered.iter_packets(): + print(f"{tag.as_dict()} -> {pkt.as_dict()}") +# {'subject_id': 'mouse_01'} -> {'age': 12, 'weight': 25.3} +# {'subject_id': 'mouse_03'} -> {'age': 15, 'weight': 27.8} +``` + +You can also filter by exact values using `constraints`: + +```python +filt = PolarsFilter(constraints={"subject_id": "mouse_01"}) +``` + +## Using operators + +All operators follow the same interface. Call `.process()` with one or more input streams: + +```python +operator = Join() +result_stream = operator.process(stream_a, stream_b) +``` + +All standard operators also support `__call__` as a shorthand for `.process()`, so you can +write: + +```python +result_stream = Join()(stream_a, stream_b) +``` + +The result is a new [stream](streams.md) that you can inspect, iterate, or pass to further +operators or function pods. + +## How it connects to other concepts + +- Operators consume and produce [Streams](streams.md) +- [Sources](sources.md) produce the initial streams that operators transform +- [Function Pods](function-pods.md) handle value-level transforms (the complement of operators) +- Operators participate in the [identity chain](identity.md) -- each operator's hash includes + its own identity plus its upstream hashes diff --git a/docs/concepts/sources.md b/docs/concepts/sources.md new file mode 100644 index 00000000..ffb849a2 --- /dev/null +++ b/docs/concepts/sources.md @@ -0,0 +1,119 @@ +# Sources + +Sources are the entry points for external data into an Orcapod pipeline. Every pipeline begins +with one or more sources that load raw data -- from Python dicts, lists, CSV files, Delta Lake +tables, or Pandas DataFrames -- and present it as an immutable +[stream](streams.md) of (Tag, Packet) pairs. Sources also attach provenance metadata +(source-info columns and system tag columns) so that every downstream value can be traced back +to its origin. + +## Key classes + +### `RootSource` (abstract base) + +All sources inherit from `RootSource`. A root source is a pure stream with no upstream +dependencies -- it sits at the root of the computational graph. Key properties: + +- `source.producer` returns `None` (no upstream pod) +- `source.upstreams` is always an empty tuple +- `source.source_id` is a canonical name used for provenance tracking and the source registry + +### Concrete source types + +All sources follow the same pattern: convert input data to a PyArrow Table, then pass it +through `SourceStreamBuilder` which handles enrichment (provenance columns, system tags, +hashing) and produces the final immutable stream. + +| Source | Input type | Notes | +|---|---|---| +| `ArrowTableSource` | PyArrow `Table` | Accepts an Arrow table directly | +| `DictSource` | `list[dict]` | Each dict becomes one (Tag, Packet) pair | +| `ListSource` | `list[Any]` | Each element stored under a named packet column | +| `DataFrameSource` | Pandas `DataFrame` | Converts via Arrow | +| `CSVSource` | File path (string) | Reads CSV into Arrow | +| `DeltaTableSource` | File path (string) | Reads Delta Lake table | + +### `DerivedSource` + +A `DerivedSource` reads the computed results of a [FunctionNode or OperatorNode](function-pods.md) +from its database and presents them as a new source. This is useful for chaining pipeline +stages: run a node, then use its output as input to a new pipeline. + +## How sources add provenance columns + +When you create a source, Orcapod automatically adds two kinds of hidden columns to track +data lineage: + +**Source-info columns** (prefix `_source_`) store a provenance token for each packet column. +For example, a packet column `weight` gets a companion `_source_weight` column. These tokens +identify which source originally produced each value. + +**System tag columns** (prefix `_tag::`) track which source contributed each row. These +columns are used internally during [joins](operators.md) to maintain provenance through +multi-stream operations. + +These columns are hidden by default. You can reveal them using `ColumnConfig`: + +```python +# Show source-info columns +table = source.as_table(columns={"source": True}) + +# Show system tag columns +tag_schema, packet_schema = source.output_schema(columns={"system_tags": True}) + +# Show everything +table = source.as_table(all_info=True) +``` + +## Code example + +Create a `DictSource`, inspect its schema, and iterate over its stream: + +```python +from orcapod.sources import DictSource + +source = DictSource( + data=[ + {"subject_id": "mouse_01", "age": 12, "weight": 25.3}, + {"subject_id": "mouse_02", "age": 8, "weight": 22.1}, + {"subject_id": "mouse_03", "age": 15, "weight": 27.8}, + ], + tag_columns=["subject_id"], +) + +# Inspect the schema +tag_schema, packet_schema = source.output_schema() +print("Tag schema:", dict(tag_schema)) +# Tag schema: {'subject_id': } +print("Packet schema:", dict(packet_schema)) +# Packet schema: {'age': , 'weight': } + +# Get column names +tag_keys, packet_keys = source.keys() +print("Tag keys:", tag_keys) # ('subject_id',) +print("Packet keys:", packet_keys) # ('age', 'weight') + +# Iterate over (Tag, Packet) pairs +for tag, packet in source.iter_packets(): + print(f" Tag: {tag.as_dict()}, Packet: {packet.as_dict()}") +# Tag: {'subject_id': 'mouse_01'}, Packet: {'age': 12, 'weight': 25.3} +# Tag: {'subject_id': 'mouse_02'}, Packet: {'age': 8, 'weight': 22.1} +# Tag: {'subject_id': 'mouse_03'}, Packet: {'age': 15, 'weight': 27.8} + +# Convert to a PyArrow table +table = source.as_table() +print(table.to_pandas()) +# subject_id age weight +# 0 mouse_01 12 25.3 +# 1 mouse_02 8 22.1 +# 2 mouse_03 15 27.8 +``` + +## How it connects to other concepts + +- Sources produce [Streams](streams.md) -- immutable sequences of (Tag, Packet) pairs +- Streams flow into [Operators](operators.md) for structural transforms (joins, filters, + column selection) +- Streams flow into [Function Pods](function-pods.md) for value-level transforms +- Every source has a `content_hash()` and `pipeline_hash()` -- see + [Identity & Hashing](identity.md) for how these work diff --git a/docs/concepts/streams.md b/docs/concepts/streams.md new file mode 100644 index 00000000..e72dab79 --- /dev/null +++ b/docs/concepts/streams.md @@ -0,0 +1,172 @@ +# Streams + +A stream is an immutable sequence of (Tag, Packet) pairs backed by a PyArrow Table. Streams +are the universal data currency in Orcapod -- every [source](sources.md) produces a stream, +every [operator](operators.md) consumes and produces streams, and every +[function pod](function-pods.md) transforms packets within a stream. Immutability guarantees +that once a stream is created, its data cannot change, which is essential for reproducible +pipelines. + +## Tag columns vs Packet columns + +Every stream divides its columns into two groups: + +**Tag columns** are join keys and metadata. They identify *which* record you are looking at +(e.g., `subject_id`, `session_date`). Operators like [Join](operators.md) match rows across +streams using shared tag columns. + +**Packet columns** are the data payload. They hold the actual values being processed +(e.g., `age`, `weight`, `spike_count`). [Function pods](function-pods.md) read packet +columns as function inputs and write new packet columns as outputs. + +This separation is enforced throughout the framework: + +- Operators inspect and restructure tags but never look inside packets +- Function pods inspect and transform packets but never look at tags + +## Key classes + +### `ArrowTableStream` + +The primary stream implementation. Wraps a PyArrow Table with designated tag and packet +columns. Created internally by sources and operators -- you rarely construct one directly. + +### `StreamBase` + +Abstract base class providing the stream interface. Both `ArrowTableStream` and stream-like +objects (sources, function pod streams, nodes) inherit from it. + +## Core methods + +Every stream exposes four key methods: + +### `output_schema()` + +Returns the `(tag_schema, packet_schema)` tuple describing column names and their Python types: + +```python +tag_schema, packet_schema = stream.output_schema() +print(dict(tag_schema)) # {'subject_id': } +print(dict(packet_schema)) # {'age': , 'weight': } +``` + +### `keys()` + +Returns column names as `(tag_keys, packet_keys)`: + +```python +tag_keys, packet_keys = stream.keys() +# tag_keys = ('subject_id',) +# packet_keys = ('age', 'weight') +``` + +### `iter_packets()` + +Iterates over (Tag, Packet) pairs. Each Tag and Packet is an immutable datagram that you can +inspect with `.as_dict()`: + +```python +for tag, packet in stream.iter_packets(): + print(tag.as_dict()) # {'subject_id': 'mouse_01'} + print(packet.as_dict()) # {'age': 12, 'weight': 25.3} +``` + +### `as_table()` + +Returns the full stream as a PyArrow Table, which integrates with Pandas, Polars, and other +Arrow-compatible tools: + +```python +table = stream.as_table() +df = table.to_pandas() +``` + +## Controlling column visibility with `ColumnConfig` + +By default, streams only expose user-facing tag and packet columns. Orcapod also maintains +hidden columns for provenance tracking and metadata. Use `ColumnConfig` (or the `all_info` +shortcut) to control which column groups are included. + +| Config field | What it reveals | Column prefix | +|---|---|---| +| `system_tags` | System tag columns (provenance tracking) | `_tag::` | +| `source` | Source-info columns (per-packet provenance tokens) | `_source_` | +| `context` | Data context column | `_context_key` | +| `content_hash` | Content hash column | `_content_hash` | +| `sort_by_tags` | Sort rows by tag columns | (ordering only) | + +Pass config as a dict or a `ColumnConfig` object: + +```python +from orcapod.sources import DictSource + +source = DictSource( + data=[ + {"subject_id": "mouse_01", "age": 12, "weight": 25.3}, + {"subject_id": "mouse_02", "age": 8, "weight": 22.1}, + ], + tag_columns=["subject_id"], +) + +# Default: user-facing columns only +table = source.as_table() +print(table.column_names) +# ['subject_id', 'age', 'weight'] + +# Include source-info columns +table = source.as_table(columns={"source": True}) +print(table.column_names) +# ['subject_id', 'age', 'weight', '_source_age', '_source_weight'] + +# Include everything +table = source.as_table(all_info=True) +print(table.column_names) +# ['subject_id', 'age', 'weight', '_tag_source_id::...', '_tag_record_id::...', +# '_content_hash', '_context_key', '_source_age', '_source_weight'] +``` + +## Code example + +Inspect a stream produced by a source: + +```python +from orcapod.sources import DictSource + +source = DictSource( + data=[ + {"subject_id": "mouse_01", "age": 12, "weight": 25.3}, + {"subject_id": "mouse_02", "age": 8, "weight": 22.1}, + {"subject_id": "mouse_03", "age": 15, "weight": 27.8}, + ], + tag_columns=["subject_id"], +) + +# Schema inspection +tag_schema, packet_schema = source.output_schema() +print("Tag schema:", dict(tag_schema)) +# Tag schema: {'subject_id': } +print("Packet schema:", dict(packet_schema)) +# Packet schema: {'age': , 'weight': } + +# Iterate over (Tag, Packet) pairs +for tag, packet in source.iter_packets(): + print(f" {tag.as_dict()} -> {packet.as_dict()}") +# {'subject_id': 'mouse_01'} -> {'age': 12, 'weight': 25.3} +# {'subject_id': 'mouse_02'} -> {'age': 8, 'weight': 22.1} +# {'subject_id': 'mouse_03'} -> {'age': 15, 'weight': 27.8} + +# Convert to a PyArrow table (interops with Pandas) +table = source.as_table() +print(table.to_pandas()) +# subject_id age weight +# 0 mouse_01 12 25.3 +# 1 mouse_02 8 22.1 +# 2 mouse_03 15 27.8 +``` + +## How it connects to other concepts + +- [Sources](sources.md) produce streams from external data +- [Operators](operators.md) consume one or more streams and produce a new stream +- [Function Pods](function-pods.md) transform packet values within a stream +- Every stream carries [identity hashes](identity.md) for deduplication and DB-scoping diff --git a/docs/getting-started.md b/docs/getting-started.md new file mode 100644 index 00000000..6e8f8e68 --- /dev/null +++ b/docs/getting-started.md @@ -0,0 +1,216 @@ +# Getting Started + +This guide walks you through the core workflow in Orcapod: creating a data source, +inspecting its stream, applying a transformation with a function pod, and examining +the results. + +## Creating a source + +A **source** is the entry point for data in an Orcapod pipeline. The simplest way to +get started is with `DictSource`, which accepts a list of dictionaries: + +```python +from orcapod.sources import DictSource + +source = DictSource( + data=[ + {"experiment": "exp_001", "temperature": 20.5, "pressure": 1.01}, + {"experiment": "exp_002", "temperature": 22.3, "pressure": 0.98}, + {"experiment": "exp_003", "temperature": 19.8, "pressure": 1.05}, + ], + tag_columns=["experiment"], + source_id="lab_results", +) +``` + +There are two important concepts here: + +- **Tag columns** (`tag_columns`) are the keys that identify each row -- like primary keys + in a database or independent variables in an experiment. Here, `experiment` uniquely + identifies each measurement. +- **Packet columns** are everything else -- the actual data payload. In this example, + `temperature` and `pressure` are the packet columns. + +The `source_id` is a human-readable name used for provenance tracking. + +!!! note + Orcapod provides several source types beyond `DictSource`: `ListSource`, `CSVSource`, + `ArrowTableSource`, `DataFrameSource`, and `DeltaTableSource`. They all produce the + same kind of immutable stream. See [Sources](concepts/sources.md) for details. + +## Inspecting the stream + +In Orcapod, a source *is* a stream. You can inspect it immediately without any extra +conversion step. + +### Schema + +Use `output_schema()` to see the tag and packet column types: + +```python +tag_schema, packet_schema = source.output_schema() +print(tag_schema) +# Schema({'experiment': }) +print(packet_schema) +# Schema({'temperature': , 'pressure': }) +``` + +### Column names + +Use `keys()` to get just the column names: + +```python +tag_keys, packet_keys = source.keys() +print(tag_keys) +# ('experiment',) +print(packet_keys) +# ('temperature', 'pressure') +``` + +### Iterating over rows + +Use `iter_packets()` to walk through each (Tag, Packet) pair: + +```python +for tag, packet in source.iter_packets(): + print(f"Tag: {tag.as_dict()}, Packet: {packet.as_dict()}") +# Tag: {'experiment': 'exp_001'}, Packet: {'temperature': 20.5, 'pressure': 1.01} +# Tag: {'experiment': 'exp_002'}, Packet: {'temperature': 22.3, 'pressure': 0.98} +# Tag: {'experiment': 'exp_003'}, Packet: {'temperature': 19.8, 'pressure': 1.05} +``` + +### Getting the full table + +Use `as_table()` to get a PyArrow Table, which you can convert to pandas: + +```python +table = source.as_table() +print(table.to_pandas()) +# experiment temperature pressure +# 0 exp_001 20.5 1.01 +# 1 exp_002 22.3 0.98 +# 2 exp_003 19.8 1.05 +``` + +## Applying a function pod + +A **function pod** transforms packet data row by row. Use the `@function_pod` decorator +to turn a plain Python function into a reusable, trackable transformation: + +```python +from orcapod import function_pod + +@function_pod(output_keys=["temp_fahrenheit", "is_high_pressure"]) +def analyze_conditions(temperature: float, pressure: float) -> tuple[float, bool]: + temp_f = temperature * 9.0 / 5.0 + 32.0 + is_high = pressure > 1.0 + return temp_f, is_high +``` + +A few things to note about function pods: + +- **Parameter names match packet column names.** The function's parameter names + (`temperature`, `pressure`) must match the packet column names from the input stream. + Orcapod uses type annotations to validate compatibility at process time. +- **`output_keys` names the output columns.** Since the function returns a tuple of two + values, `output_keys` must be a list of two names. For a function returning a single + value, pass a single string (e.g., `output_keys="result"`). +- **The decorator creates a `.pod` attribute.** You call `.pod.process()` to apply + the function to a stream. + +Now apply it: + +```python +result = analyze_conditions.pod(source) +``` + +!!! tip + All standard pods support `__call__` as a shorthand for `.process()`, so + `pod(stream)` is equivalent to `pod.process(stream)`. + +The `result` is a new stream. Tags are preserved from the input; the packet columns +are replaced with the function's outputs. + +## Inspecting the result + +The result stream supports the same inspection methods as the source: + +```python +tag_schema, packet_schema = result.output_schema() +print(tag_schema) +# Schema({'experiment': }) +print(packet_schema) +# Schema({'temp_fahrenheit': , 'is_high_pressure': }) +``` + +The tag schema is unchanged -- function pods never modify tags. The packet schema +now reflects the function's output types. + +Iterate over the results: + +```python +for tag, packet in result.iter_packets(): + print(f"Tag: {tag.as_dict()}, Packet: {packet.as_dict()}") +# Tag: {'experiment': 'exp_001'}, Packet: {'temp_fahrenheit': 68.9, 'is_high_pressure': True} +# Tag: {'experiment': 'exp_002'}, Packet: {'temp_fahrenheit': 72.14, 'is_high_pressure': False} +# Tag: {'experiment': 'exp_003'}, Packet: {'temp_fahrenheit': 67.64, 'is_high_pressure': True} +``` + +Or view it as a table: + +```python +print(result.as_table().to_pandas()) +# experiment temp_fahrenheit is_high_pressure +# 0 exp_001 68.90 True +# 1 exp_002 72.14 False +# 2 exp_003 67.64 True +``` + +## Putting it all together + +Here is the complete example in one block: + +```python +from orcapod import function_pod +from orcapod.sources import DictSource + +# Create a source from raw data +source = DictSource( + data=[ + {"experiment": "exp_001", "temperature": 20.5, "pressure": 1.01}, + {"experiment": "exp_002", "temperature": 22.3, "pressure": 0.98}, + {"experiment": "exp_003", "temperature": 19.8, "pressure": 1.05}, + ], + tag_columns=["experiment"], + source_id="lab_results", +) + +# Define a transformation +@function_pod(output_keys=["temp_fahrenheit", "is_high_pressure"]) +def analyze_conditions(temperature: float, pressure: float) -> tuple[float, bool]: + temp_f = temperature * 9.0 / 5.0 + 32.0 + is_high = pressure > 1.0 + return temp_f, is_high + +# Apply and view results +result = analyze_conditions.pod(source) +print(result.as_table().to_pandas()) +# experiment temp_fahrenheit is_high_pressure +# 0 exp_001 68.90 True +# 1 exp_002 72.14 False +# 2 exp_003 67.64 True +``` + +## Next steps + +Now that you have the basics, explore these topics: + +- [Sources](concepts/sources.md) -- learn about the different source types and how + provenance tracking works. +- [Streams](concepts/streams.md) -- understand the immutable (Tag, Packet) stream model. +- [Function Pods](concepts/function-pods.md) -- advanced function pod usage, including + caching with databases. +- [Operators](concepts/operators.md) -- structural transforms like Join, Batch, and Filter + that work on tags and stream structure without inspecting packet content. +- [Identity & Hashing](concepts/identity.md) -- how Orcapod tracks content identity and + pipeline structure for reproducibility. diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 00000000..b08847a8 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,60 @@ +# Orcapod + +Orcapod is an intuitive and powerful Python library for building highly reproducible +scientific data pipelines. It provides a structured way to load, transform, and track data +through typed, immutable streams -- ensuring that every step of your analysis is traceable, +content-addressable, and reproducible by design. + +## Installation + +Install from [PyPI](https://pypi.org/project/orcapod/) using [uv](https://docs.astral.sh/uv/): + +```bash +uv add orcapod +``` + +For the latest development version, install directly from GitHub: + +```bash +uv add git+https://github.com/walkerlab/orcapod-python.git +``` + +## Quick example + +Create a data source, apply a transformation, and inspect the results -- all in a few lines: + +```python +from orcapod import function_pod +from orcapod.sources import DictSource + +# 1. Load data into a source +source = DictSource( + data=[ + {"name": "Alice", "age": 30}, + {"name": "Bob", "age": 25}, + {"name": "Charlie", "age": 35}, + ], + tag_columns=["name"], + source_id="people", +) + +# 2. Define a transform with the function_pod decorator +@function_pod(output_keys="birth_year") +def compute_birth_year(age: int) -> int: + return 2026 - age + +# 3. Apply the function pod and inspect the output +result = compute_birth_year.pod(source) +for tag, packet in result.iter_packets(): + print(f"{tag.as_dict()} -> {packet.as_dict()}") +# {'name': 'Alice'} -> {'birth_year': 1996} +# {'name': 'Bob'} -> {'birth_year': 2001} +# {'name': 'Charlie'} -> {'birth_year': 1991} +``` + +## Next steps + +- [Getting Started](getting-started.md) -- a hands-on walkthrough of sources, streams, + function pods, and operators. +- [Concepts](concepts/sources.md) -- deeper explanations of Orcapod's core abstractions. +- [API Reference](api/index.md) -- complete reference for all public classes and functions. diff --git a/examples/async_vs_sync_pipeline.py b/examples/async_vs_sync_pipeline.py new file mode 100644 index 00000000..73968000 --- /dev/null +++ b/examples/async_vs_sync_pipeline.py @@ -0,0 +1,186 @@ +"""Async vs sync pipeline execution — 2x2 comparison matrix. + +Demonstrates the interplay of two independent concurrency axes: + +1. **Pipeline executor** — sync (sequential node execution) vs async + (concurrent node execution via channels). + +2. **Packet function** — sync (GIL-holding busy-wait) vs async + (non-blocking ``asyncio.sleep``). + +The sync function uses a pure-Python busy-wait loop that holds the GIL, +preventing thread-pool concurrency. This makes the difference between +async+sync and async+async clearly visible: + ++---------------------+----------------------------+----------------------------+ +| | sync function (holds GIL) | async function | ++---------------------+----------------------------+----------------------------+ +| sync executor | fully sequential | sequential (async fn | +| | | called via sync fallback) | ++---------------------+----------------------------+----------------------------+ +| async executor | branches overlap, but | branches overlap AND | +| | packets serialize on GIL | packets run concurrently | +| | (thread pool can't help) | (native coroutines) | ++---------------------+----------------------------+----------------------------+ + +Usage: + uv run python examples/async_vs_sync_pipeline.py +""" + +from __future__ import annotations + +import asyncio +import time + +import pyarrow as pa + +from orcapod.sources import ArrowTableSource +from orcapod.core.function_pod import FunctionPod +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.databases import InMemoryArrowDatabase +from orcapod.pipeline import Pipeline +from orcapod.types import ExecutorType, NodeConfig, PipelineConfig + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +SLEEP_SECONDS = 0.1 # per-packet delay (kept short since busy-wait burns CPU) +NUM_PACKETS = 6 + +# --------------------------------------------------------------------------- +# Source data +# --------------------------------------------------------------------------- + +SOURCE_TABLE = pa.table( + { + "id": pa.array(list(range(NUM_PACKETS)), type=pa.int64()), + "x": pa.array(list(range(NUM_PACKETS)), type=pa.int64()), + } +) + +# --------------------------------------------------------------------------- +# Domain functions — async and sync variants +# --------------------------------------------------------------------------- + + +def _busy_wait(seconds: float) -> None: + """Burn CPU in a pure-Python loop that never releases the GIL.""" + end = time.perf_counter() + seconds + while time.perf_counter() < end: + pass + + +async def async_slow_double(x: int) -> int: + """Simulate an async I/O-bound operation (e.g. API call).""" + await asyncio.sleep(SLEEP_SECONDS) + return x * 2 + + +def sync_slow_double(x: int) -> int: + """Simulate a GIL-holding blocking operation (e.g. CPU-bound work).""" + _busy_wait(SLEEP_SECONDS) + return x * 2 + + +# --------------------------------------------------------------------------- +# Pipeline builder +# --------------------------------------------------------------------------- + + +def build_pipeline(use_async_fn: bool) -> Pipeline: + """Build a pipeline with two independent branches from the same source. + + Pipeline:: + + source ──┬── slow_double (branch_a) + └── slow_double (branch_b) + """ + db = InMemoryArrowDatabase() + pipeline = Pipeline(name="demo", pipeline_database=db) + + fn = async_slow_double if use_async_fn else sync_slow_double + with pipeline: + source = ArrowTableSource(SOURCE_TABLE, tag_columns=["id"]) + pf_a = PythonPacketFunction(fn, output_keys="result", function_name="branch_a") + pf_b = PythonPacketFunction(fn, output_keys="result", function_name="branch_b") + FunctionPod( + packet_function=pf_a, + node_config=NodeConfig(max_concurrency=NUM_PACKETS), + )(source, label="branch_a") + FunctionPod( + packet_function=pf_b, + node_config=NodeConfig(max_concurrency=NUM_PACKETS), + )(source, label="branch_b") + + return pipeline + + +def run_case(label: str, use_async_fn: bool, use_async_executor: bool) -> float: + """Run a single combination and return elapsed time.""" + pipeline = build_pipeline(use_async_fn=use_async_fn) + t0 = time.perf_counter() + if use_async_executor: + pipeline.run(config=PipelineConfig(executor=ExecutorType.ASYNC_CHANNELS)) + else: + pipeline.run() + elapsed = time.perf_counter() - t0 + + a = pipeline.branch_a.get_all_records() + b = pipeline.branch_b.get_all_records() + assert a is not None and b is not None + total_rows = a.num_rows + b.num_rows + print(f" {label:44s} {elapsed:5.2f}s ({total_rows} rows)") + return elapsed + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main() -> None: + total = NUM_PACKETS * 2 # packets per branch * 2 branches + seq_time = SLEEP_SECONDS * total + + print("=" * 64) + print("Pipeline execution: 2x2 comparison matrix") + print("=" * 64) + print(f" {NUM_PACKETS} packets x 2 branches, {SLEEP_SECONDS}s sleep each") + print(f" Sequential baseline: {seq_time:.1f}s\n") + + print(" Pipeline topology:") + print(" source ──┬── slow_double (branch_a)") + print(" └── slow_double (branch_b)\n") + + t1 = run_case("sync executor + sync function :", False, False) + t2 = run_case("sync executor + async function:", True, False) + t3 = run_case("async executor + sync function :", False, True) + t4 = run_case("async executor + async function:", True, True) + + print() + print(" Analysis:") + print( + f" sync+sync {t1:.2f}s — fully sequential ({NUM_PACKETS}x2 x {SLEEP_SECONDS}s)" + ) + print( + f" sync+async {t2:.2f}s — still sequential (sync executor runs nodes one by one)" + ) + print(f" async+sync {t3:.2f}s — branches overlap, but GIL-holding busy-wait") + print(f" serializes packets even across threads") + print(f" async+async {t4:.2f}s — branches overlap AND packets overlap") + print( + f" (native coroutines yield at await points, enabling I/O overlap)" + ) + print() + print(f" Key insight: async+sync is much slower than async+async because") + print(f" the sync function holds the GIL, so run_in_executor threads") + print(f" cannot actually run in parallel. Native async coroutines yield") + print(f" control at each 'await', enabling cooperative I/O overlap.") + print() + print(f" Speedup (sync+sync vs async+async): {t1 / t4:.1f}x") + print(f" Speedup (async+sync vs async+async): {t3 / t4:.1f}x") + + +if __name__ == "__main__": + main() diff --git a/examples/save_and_load_pipelines.py b/examples/save_and_load_pipelines.py new file mode 100644 index 00000000..66051d27 --- /dev/null +++ b/examples/save_and_load_pipelines.py @@ -0,0 +1,34 @@ +from orcapod import Pipeline, function_pod +from orcapod import sources, databases + +database = databases.DeltaTableDatabase("./local_database") +source1 = sources.DictSource( + [{"id": 0, "x": 5}, {"id": 1, "x": 10}, {"id": 2, "x": 15}], + tag_columns=["id"], + label="source1", +) +source1 = source1.cached(database) +source2 = sources.DictSource( + [{"id": 0, "y": 3}, {"id": 2, "y": 6}, {"id": 4, "y": 9}], + tag_columns=["id"], + label="source2", +) +source2 = source2.cached(database) + + +pipeline = Pipeline("my_pipeline", database) + + +@function_pod("sum") +def take_sum(x: int, y: int) -> int: + return x + y + + +with pipeline: + result = take_sum.pod(source1.join(source2)) + +path = "./pipeline.json" +pipeline.save(path) + + +pipeline2 = Pipeline.load(path) diff --git a/function-execution-improvements-plan.md b/function-execution-improvements-plan.md new file mode 100644 index 00000000..7f39fd2b --- /dev/null +++ b/function-execution-improvements-plan.md @@ -0,0 +1,300 @@ +# Function Execution Chain Improvements — Design & Implementation Plan + +## Overview + +This document captures design decisions for improving the executor integration in the +packet function / function pod / function node execution chain. The changes address four +areas: + +1. **`with_options` semantics** — executors become immutable; `with_options()` always returns + a new instance. +2. **`execution_engine_opts` ownership** — removed from FunctionNode; owned exclusively by + the pipeline's executor-assignment logic. +3. **`CachedFunctionPod`** — a new pod-level caching wrapper complementing the existing + `CachedPacketFunction` (packet-level caching). +4. **Type-safe executor dispatch via `Generic[E]` + `__init_subclass__`** — eliminates + redundant `isinstance` checks in the hot path by resolving the executor protocol once at + class definition time. + +--- + +## 1. `with_options` Always Returns a New Instance + +### Current state + +`PacketFunctionExecutorBase.with_options()` returns `self` by default. `RayExecutor` +overrides it to return a new instance. This is inconsistent — callers cannot rely on +`with_options()` being side-effect-free without checking the concrete type. + +### Design decision + +`with_options()` **must always return a new executor instance**, even when no options change. +This makes executors effectively immutable value objects after construction — the same +executor can be safely shared across nodes, and `with_options()` produces a node-specific +variant without mutating the original. + +### Changes + +- **`PacketFunctionExecutorBase.with_options()`**: Default implementation returns + `copy.copy(self)` (shallow clone) instead of `self`. Subclasses that carry mutable state + (e.g. Ray handles) override to produce a properly configured new instance. +- **`PacketFunctionExecutorProtocol.with_options()`**: Docstring updated to specify "returns + a **new** executor instance". +- **`LocalExecutor.with_options()`**: Returns a new `LocalExecutor()`. Trivial since it + carries no state. + +--- + +## 2. Remove `execution_engine_opts` from FunctionNode + +### Current state + +`FunctionNode.__init__` stores `self.execution_engine_opts: dict[str, Any] | None = None`. +This field is set externally (by the pipeline) and later read back during executor +assignment. The node becomes an awkward intermediary — it holds configuration that logically +belongs to the pipeline's executor-assignment step. + +### Design decision + +`execution_engine_opts` is **removed from FunctionNode entirely**. The pipeline's +`apply_executor` (or equivalent) logic is the sole owner: it reads per-node options from +the pipeline config, calls `executor.with_options(**merged_opts)`, and sets the resulting +executor directly on the packet function. The node never sees raw option dicts. + +### Changes + +- **`FunctionNode.__init__`**: Remove `self.execution_engine_opts` attribute. +- **`FunctionNode.from_descriptor`**: Stop reading/writing `execution_engine_opts` from + descriptors. (Backward-compatible break is acceptable per project policy — pre-v0.1.0.) +- **Pipeline executor assignment** (in `pipeline/` module): Merge pipeline-level and + per-node options, call `executor.with_options(**merged)`, then assign the resulting + executor to `node.executor = configured_executor`. + +--- + +## 3. `CachedFunctionPod` — Pod-Level Caching + +### Current state + +Caching exists only at the packet-function level (`CachedPacketFunction`), which wraps +`call()` / `async_call()` with DB lookup/insert. This works but cannot leverage tag +information (which is invisible to packet functions). + +### Design decision + +Add a **`CachedFunctionPod`** that wraps a `FunctionPod` and intercepts at the +`process_packet(tag, packet)` level. This complements `CachedPacketFunction`: + +| Layer | `CachedPacketFunction` | `CachedFunctionPod` | +|-------|------------------------|---------------------| +| Intercepts at | `call(packet)` | `process_packet(tag, packet)` | +| Has tag access | No | Yes | +| Cache key includes | Packet content hash | Tag + packet content hash | +| Delegates to | Wrapped `PacketFunction.call()` | Inner `FunctionPod.process_packet()` | + +Both are useful: `CachedPacketFunction` deduplicates purely on data content; +`CachedFunctionPod` can incorporate tag metadata into cache decisions. + +### Implementation sketch + +```python +class CachedFunctionPod(WrappedFunctionPod): + """Pod-level caching wrapper that intercepts process_packet().""" + + def __init__( + self, + function_pod: FunctionPodProtocol, + result_database: ArrowDatabaseProtocol, + record_path_prefix: tuple[str, ...] = (), + **kwargs, + ) -> None: + super().__init__(function_pod, **kwargs) + self._result_database = result_database + self._record_path_prefix = record_path_prefix + + def process_packet( + self, tag: TagProtocol, packet: PacketProtocol + ) -> tuple[TagProtocol, PacketProtocol | None]: + # Cache key incorporates both tag and packet content + cache_key = self._compute_cache_key(tag, packet) + cached = self._lookup(cache_key) + if cached is not None: + return tag, cached + tag, output = self._function_pod.process_packet(tag, packet) + if output is not None: + self._store(cache_key, tag, output) + return tag, output +``` + +### Changes + +- **New file**: `src/orcapod/core/cached_function_pod.py` containing `CachedFunctionPod`. +- **`function_pod` decorator**: Add `pod_cache_database` parameter that wraps the pod in + `CachedFunctionPod` when provided (distinct from `result_database` which wraps the packet + function in `CachedPacketFunction`). + +--- + +## 4. Type-Safe Executor Dispatch via `Generic[E]` + `__init_subclass__` + +### Problem + +Currently, each `PacketFunctionBase` subclass that cares about executor-specific capabilities +must do `isinstance` checks in the hot path (`call()` / `direct_call()`). This is both +verbose and error-prone — forgetting to check means silent misuse. + +In Rust, an `enum Executor { Python(PythonExecutor), Container(ContainerExecutor) }` with +`match` would give exhaustive, zero-cost dispatch. Python has no sum types with exhaustiveness +checking, but we can get close. + +### Design decision + +Use `Generic[E]` on `PacketFunctionBase` combined with `__init_subclass__` to resolve the +concrete executor protocol **once at class definition time**. The single `isinstance` check +moves to `set_executor()` (assignment boundary), and the hot path (`call()`) is clean. + +### Mechanism + +```python +from typing import Generic, TypeVar + +E = TypeVar("E", bound=PacketFunctionExecutorProtocol) + +class PacketFunctionBase(TraceableBase, Generic[E]): + _resolved_executor_protocol: ClassVar[type] # auto-set by __init_subclass__ + _executor: E | None = None + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + for base in cls.__orig_bases__: + origin = typing.get_origin(base) + if origin is PacketFunctionBase: + args = typing.get_args(base) + if args and not isinstance(args[0], TypeVar): + cls._resolved_executor_protocol = args[0] + return + + def set_executor(self, executor: PacketFunctionExecutorProtocol) -> None: + """Single isinstance check at assignment boundary.""" + proto = getattr(type(self), '_resolved_executor_protocol', None) + if proto is not None and not isinstance(executor, proto): + raise TypeError( + f"{type(self).__name__} requires {proto.__name__}, " + f"got {type(executor).__name__}" + ) + self._executor = executor # type: ignore[assignment] +``` + +Subclasses declare the executor type **once** via the generic parameter: + +```python +class PythonPacketFunction(PacketFunctionBase[PythonExecutorProtocol]): + # No _executor_protocol ClassVar needed — __init_subclass__ extracts it + # from Generic[PythonExecutorProtocol] automatically. + ... +``` + +### Why this works + +- **`__orig_bases__`** is set by Python's type machinery on every class that inherits from + a `Generic`. It contains the parameterized base (e.g. + `PacketFunctionBase[PythonExecutorProtocol]`). +- **`typing.get_args()`** extracts the type parameters. +- **`__init_subclass__`** runs at class definition time (import), not at instance creation. + Zero per-instance overhead. +- The `isinstance(args[0], TypeVar)` guard skips intermediate abstract subclasses that + haven't bound `E` yet (e.g. `PacketFunctionWrapper(PacketFunctionBase[E])`). + +### Requirements + +- Executor protocols used as type parameters **must** be decorated with + `@runtime_checkable` (already the case for `PacketFunctionExecutorProtocol`). +- Any new executor protocol (e.g. `PythonExecutorProtocol`) needs `@runtime_checkable` too. + +### Hot path after this change + +```python +def call(self, packet: PacketProtocol) -> PacketProtocol | None: + if self._executor is not None: + # self._executor is statically typed as E (e.g. PythonExecutorProtocol). + # No isinstance check needed — validated at set_executor() time. + return self._executor.execute(self, packet) + return self.direct_call(packet) +``` + +### Changes + +- **`PacketFunctionBase`**: Add `Generic[E]`, `__init_subclass__` resolver, + `set_executor()` method. Existing `executor` property setter delegates to `set_executor()`. +- **`PythonPacketFunction`**: Change to `PacketFunctionBase[PythonExecutorProtocol]` + (or a more specific `PythonExecutorProtocol` if we introduce one). +- **`PacketFunctionWrapper`**: Change to `PacketFunctionBase[E]` (remains generic, passes + through). +- **`CachedPacketFunction`**: Inherits from `PacketFunctionWrapper` — no changes needed + since executor delegation already targets the wrapped leaf function. +- **Executor protocols**: Ensure `@runtime_checkable` on all protocols that will be used as + generic parameters. + +--- + +## 5. `FunctionPod.packet_function` Remains a Read-Only Property + +### Decision + +`FunctionPod.packet_function` stays as a property on the protocol (as currently implemented). +It is understood to be **read-only** — callers should not replace the packet function after +pod construction. The property exists for introspection and executor wiring, not mutation. + +No code changes needed — this is a documentation/convention clarification. + +--- + +## Implementation Plan + +### Phase 1: Executor immutability + remove `execution_engine_opts` + +1. Update `PacketFunctionExecutorBase.with_options()` default to return a shallow copy. +2. Update `LocalExecutor.with_options()` to return `LocalExecutor()`. +3. Verify `RayExecutor.with_options()` already returns a new instance (it does). +4. Remove `self.execution_engine_opts` from `FunctionNode.__init__`. +5. Remove `execution_engine_opts` from `FunctionNode.from_descriptor` read-only state. +6. Update pipeline executor-assignment logic to merge options externally and pass the + configured executor in. +7. Update affected tests. + +### Phase 2: Type-safe executor dispatch + +1. Add `Generic[E]` and `__init_subclass__` to `PacketFunctionBase`. +2. Update `executor` setter to delegate to `set_executor()` with `__init_subclass__`-resolved + protocol check. +3. Parameterize `PythonPacketFunction` as `PacketFunctionBase[PacketFunctionExecutorProtocol]` + (or a narrower protocol if we introduce executor-type-specific protocols later). +4. Parameterize `PacketFunctionWrapper` as `PacketFunctionBase[E]`. +5. Ensure all executor protocols are `@runtime_checkable`. +6. Update tests to verify type checking at assignment time. + +### Phase 3: `CachedFunctionPod` + +1. Create `src/orcapod/core/cached_function_pod.py`. +2. Implement `CachedFunctionPod(WrappedFunctionPod)` with tag-aware cache key computation. +3. Add `pod_cache_database` parameter to `function_pod` decorator. +4. Add tests for pod-level vs packet-level caching interaction. + +### Phase 4: Documentation and cleanup + +1. Update `orcapod-design.md` with the new execution chain design. +2. Update `CLAUDE.md` architecture section if needed. +3. Check `DESIGN_ISSUES.md` for any resolved issues. + +--- + +## Open Questions + +- **Executor-type-specific protocols**: Currently all packet functions accept the base + `PacketFunctionExecutorProtocol`. If we later want `PythonPacketFunction` to only accept + executors with a `PythonExecutorProtocol` (which might require `execute(func, args)` + rather than `execute(pf, packet)`), the `Generic[E]` mechanism already supports this — + just parameterize with the narrower protocol. +- **`CachedFunctionPod` cache key design**: The exact composition of the cache key (which + tag columns to include, whether to include system tags) needs detailed design during + implementation. A reasonable default is tag content hash + packet content hash. diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 00000000..44420314 --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,52 @@ +site_name: Orcapod +site_description: Intuitive and powerful library for highly reproducible scientific data pipelines +repo_url: https://github.com/walkerlab/orcapod-python + +theme: + name: material + features: + - navigation.tabs + - navigation.sections + - navigation.expand + - search.suggest + - content.code.copy + +plugins: + - search + - mkdocstrings: + handlers: + python: + paths: [src] + options: + docstring_style: google + show_source: true + show_bases: true + members_order: source + +markdown_extensions: + - admonition + - pymdownx.highlight + - pymdownx.superfences + - pymdownx.details + - toc: + permalink: true + +nav: + - Home: index.md + - Getting Started: getting-started.md + - Concepts: + - Sources: concepts/sources.md + - Streams: concepts/streams.md + - Operators: concepts/operators.md + - Function Pods: concepts/function-pods.md + - Identity & Hashing: concepts/identity.md + - API Reference: + - Overview: api/index.md + - Sources: api/sources.md + - Streams: api/streams.md + - Operators: api/operators.md + - Function Pods: api/function-pods.md + - Databases: api/databases.md + - Nodes: api/nodes.md + - Pipeline: api/pipeline.md + - Types: api/types.md diff --git a/orcapod-design.md b/orcapod-design.md new file mode 100644 index 00000000..7684a2cc --- /dev/null +++ b/orcapod-design.md @@ -0,0 +1,730 @@ +# OrcaPod — Design Specification + +--- + +## Core Abstractions + +### Datagram + +The **datagram** is the universal immutable data container in OrcaPod. A datagram holds named columns with explicit type information and supports lazy conversion between Python dict and Apache Arrow representations. Datagrams come in two specialized forms: + +- **Tag** — metadata columns attached to a packet for routing, filtering, and annotation. Tags carry additional **system tags** — framework-managed hidden provenance columns that are excluded from content identity by default. + +- **Packet** — data columns carrying the computational payload. Packets carry additional **source info** — per-column provenance tokens tracing each value back to its originating source and record. + +Datagrams are always constructed from either a Python dict or an Arrow table/record batch. The alternative representation is computed lazily and cached. Content hashing always uses the Arrow representation; value access always uses the Python dict. + +### Stream + +A **stream** is a sequence of (Tag, Packet) pairs over a shared schema. Streams define two column groups — tag columns and packet columns — and provide lazy iteration, table materialization, and schema introspection. Streams are the fundamental data-flow abstraction: every source emits one, every operator consumes and produces them, and every function pod iterates over them. + +The concrete implementation is `ArrowTableStream`, backed by an immutable PyArrow Table with explicit tag/packet column assignment. + +### Source + +A **source** acts as a stream from external data with no upstream dependencies, forming the base case of the pipeline graph. Sources establish provenance: each row gets a source-info token and system tag columns encoding the source's identity. + +- **Root source** — loads data from the external world (file, database, in-memory table). All root sources delegate to `ArrowTableSource`, which wraps the data in an `ArrowTableStream` with provenance annotations. Concrete subclasses include `CSVSource`, `DeltaTableSource`, `DataFrameSource`, `DictSource`, and `ListSource`. + +- **Derived source** — wraps the computed output of a `FunctionNode` or `OperatorNode`, reading from their pipeline database. Represents an explicit materialization declaration — an intermediate result given durable identity in the pipeline database, detached from the upstream topology that produced it. + +Every source has a `source_id` — a canonical registry name used to register the source in a `SourceRegistry` so that provenance tokens in downstream data can be resolved back to the originating source. If not explicitly provided, `source_id` defaults to a truncated content hash. + +### Function Pod + +A **function pod** wraps a **packet function** — a stateless computation that consumes a single packet and produces an output packet. Function pods never inspect tags or stream structure; they operate purely on packet content. When given multiple input streams, a function pod joins them via a configurable multi-stream handler (defaulting to `Join`) before iterating. + +Packet functions support pluggable executors (see **Packet Function Executor System**). When an executor is set, `call()` routes through `executor.execute()` and `async_call()` routes through `executor.async_execute()`. When no executor is set, the function's native `direct_call()` / `direct_async_call()` is invoked directly. For `PythonPacketFunction`, `direct_async_call` runs the synchronous function in a thread pool via `asyncio.run_in_executor`. + +Two execution models exist: + +- **FunctionPod + FunctionPodStream** — lazy, in-memory evaluation. The function pod processes each (tag, packet) pair from the input stream on demand, caching results by index. When the attached executor declares `supports_concurrent_execution = True`, `iter_packets()` materializes all remaining inputs and dispatches them concurrently via `asyncio.gather` over `async_call`, yielding results in order. + +- **FunctionNode** — database-backed evaluation with incremental computation. Execution proceeds in two phases: + 1. **Phase 1**: yield cached results from the pipeline database for inputs whose hashes are already stored. + 2. **Phase 2**: compute results for any remaining input packets, store them in the database, and yield. + + Pipeline database scoping uses `pipeline_hash()` (schema+topology only), so FunctionNodes with identical functions and schema-compatible sources share the same database table. + +### Operator + +An **operator** is a structural pod that transforms streams without synthesizing new packet values. Every packet value in an operator's output must be traceable to a concrete value already present in the input packets — operators perform joins, merges, splits, selections, column renames, batching, and tag operations within this constraint. + +Operators are subclasses of `StaticOutputPod` organized by input arity: + +| Base Class | Arity | Examples | +|---|---|---| +| `UnaryOperator` | Exactly 1 input | Batch, SelectTagColumns, DropPacketColumns, MapTags, MapPackets, PolarsFilter | +| `BinaryOperator` | Exactly 2 inputs | MergeJoin, SemiJoin | +| `NonZeroInputOperator` | 1 or more inputs | Join | + +Each operator declares its **argument symmetry** — whether inputs commute (`frozenset`, order-invariant) or have fixed positions (`tuple`, order-dependent). This determines how upstream hashes are combined for pipeline identity. + +The `OperatorNode` is the database-backed counterpart, analogous to `FunctionNode` for function pods. It applies the operator, materializes the output with per-row record hashes, and stores the result in the pipeline database. + +Every operator inherits a default barrier-mode `async_execute` from its base class (collect all inputs, run `static_process`, emit results). Subclasses can override for streaming or incremental strategies (see **Execution Models**). + +--- + +## Operator Catalog + +### Join +Variable-arity inner join on shared tag columns. Non-overlapping packet columns are required — colliding packet columns raise `InputValidationError`. Tag schema is the union of all input tag schemas; packet schema is the union. Inputs are canonically ordered by `pipeline_hash` for deterministic system tag column naming. Commutative (declared via `frozenset` argument symmetry). + +### MergeJoin +Binary inner join that handles colliding packet columns by merging their values into sorted `list[T]`. Colliding columns must have identical types. Non-colliding columns are kept as scalars. Corresponding source-info columns are reordered to match the sort order of their packet column. Commutative — commutativity comes from sorting merged values, not from ordering input streams. + +### SemiJoin +Binary semi-join: returns entries from the left stream that match on overlapping columns in the right stream. Output schema matches the left stream exactly. Non-commutative. + +### Batch +Groups rows into batches of a configurable size. All column types become `list[T]`. Optionally drops incomplete final batches. + +### SelectTagColumns / SelectPacketColumns +Keep only specified tag or packet columns. Optional `strict` mode raises on missing columns. + +### DropTagColumns / DropPacketColumns +Remove specified tag or packet columns. `DropPacketColumns` also removes associated source-info columns. + +### MapTags / MapPackets +Rename tag or packet columns via a name mapping. `MapPackets` automatically renames associated source-info columns. Optional `drop_unmapped` mode removes columns not in the mapping. + +### PolarsFilter +Applies Polars filtering predicates to rows. Output schema is unchanged from input. + +--- + +## Schema as a First-Class Citizen + +Every stream exposes `output_schema()` returning `(tag_schema, packet_schema)` as `Schema` objects — immutable mappings from field names to Python types with support for optional fields. Schema is embedded explicitly at every level rather than resolved against a central registry, making streams fully self-describing. + +The `ColumnConfig` dataclass controls what metadata columns are included in schema and data output: + +| Field | Controls | +|---|---| +| `meta` | System metadata columns (`__` prefix) | +| `context` | Data context column | +| `source` | Source-info provenance columns (`_source_` prefix) | +| `system_tags` | System tag columns (`_tag_` prefix) | +| `content_hash` | Per-row content hash column | +| `sort_by_tags` | Whether to sort output by tag columns | + +Operators predict their output schema — including system tag column names — without performing the actual computation. + +--- + +## Tags + +Tags are key-value pairs attached to every packet providing human-friendly metadata for navigation, filtering, and annotation. They are: + +- **Non-authoritative** — never used for cache lookup or pod identity computation +- **Auto-propagated** — tags flow forward through the pipeline automatically +- **The basis for joins** — operator pods join streams by matching tag keys, never by inspecting packet content + +**Tag merging in joins:** +- **Shared tag keys** — act as the join predicate; values must match for packets to be joined +- **Non-shared tag keys** — propagate freely into the joined output's tags + +--- + +## Operator / Function Pod Boundary + +This is a strict and critical separation: + +| | Operator | Function Pod | +|---|---|---| +| Inspects packet content | Never | Yes | +| Inspects / uses tags | Yes | No | +| Can rename columns | Yes | No | +| Stream arity | Configurable (unary/binary/N-ary) | Single stream in, single stream out | +| Cached by content hash | No | Yes | +| Synthesizes new values | No | Yes | + +Column renaming by operators allows join conflicts to be avoided without contaminating source info — the column name changes but the source info pointer remains intact, always traceable to the original producing pod. + +--- + +## Identity and Hashing + +OrcaPod maintains two parallel identity chains implemented as recursive Merkle-like hash trees: + +### Content Hash (`content_hash()`) + +Data-inclusive identity capturing the precise semantic content of an object: + +| Component | What Gets Hashed | +|---|---| +| RootSource | Class name + tag columns + table content hash | +| PacketFunction | URI (canonical name + output schema hash + version + type ID) | +| FunctionPodStream | Function pod + argument symmetry of inputs | +| Operator | Operator class + identity structure | +| ArrowTableStream | Producer + upstreams (or table content if no producer) | +| Datagram | Arrow table content | +| DerivedSource | Origin node's content hash | + +Content hashes use a `BaseSemanticHasher` that recursively expands structures, dispatches to type-specific handlers, and terminates at `ContentHash` leaves (preventing hash-of-hash inflation). + +### Pipeline Hash (`pipeline_hash()`) + +Schema-and-topology-only identity used for database path scoping. Excludes data content so that different sources with identical schemas share database tables: + +| Component | What Gets Hashed | +|---|---| +| RootSource | `(tag_schema, packet_schema)` — base case | +| PacketFunction | Raw packet function object (via content hash) | +| FunctionPodStream | Function pod + input stream pipeline hashes | +| Operator | Operator class + argument symmetry (pipeline hashes of inputs) | +| ArrowTableStream | Producer + upstreams pipeline hashes (or schema if no producer) | +| DerivedSource | Inherited from RootSource: `(tag_schema, packet_schema)` | + +Pipeline hash uses a **resolver pattern** — a callback that routes `PipelineElementProtocol` objects through `pipeline_hash()` and other `ContentIdentifiable` objects through `content_hash()` — ensuring the correct identity chain is used for nested objects within a single hash computation. + +### ContentHash Type + +All hashes are represented as `ContentHash` — a frozen dataclass pairing a method identifier (e.g., `"object_v0.1"`, `"arrow_v2.1"`) with raw digest bytes. The method name enables detecting version mismatches across hash configurations. Conversions: `.to_hex()`, `.to_int()`, `.to_uuid()`, `.to_base64()`, `.to_string()`. + +### Argument Symmetry and Upstream Commutativity + +Each pod declares how upstream hashes are combined: + +- **Commutative** (`frozenset`) — upstream hashes sorted before combining. Used when input order is semantically irrelevant (Join, MergeJoin). +- **Non-commutative** (`tuple`) — upstream hashes combined in declared order. Used when input position is significant (SemiJoin). +- **Partial symmetry** — nesting expresses mixed constraints, e.g. `(frozenset([a, b]), c)`. + +--- + +## Packet Function Signatures + +Every packet function has a unique signature reflecting its input/output schemas and implementation. The function's URI encodes: + +``` +(canonical_function_name, output_schema_hash, major_version, packet_function_type_id) +``` + +For Python functions specifically, the identity structure includes the function's bytecode hash, input parameters signature, and Git version information. + +--- + +## Source Info + +Every packet column carries a **source info** string — a provenance pointer to the source and record that produced the value: + +``` +{source_id}::{record_id}::{column_name} +``` + +Where: +- `source_id` — canonical identifier of the originating source (defaults to content hash) +- `record_id` — row identifier, either positional (`row_0`) or column-based (`user_id=abc123`) +- `column_name` — the original column name + +Source info columns are stored with a `_source_` prefix and are excluded from content hashing and standard output by default. They are included when `ColumnConfig(source=True)` is set. + +Source info is **immutable through the pipeline** — set once when a source creates the data and preserved through all downstream operator transformations including column renames. + +--- + +## System Tags + +System tags are **framework-managed, hidden provenance columns** automatically attached to every packet. Unlike user tags, they are authoritative and guaranteed to maintain perfect traceability from any result row back to its original source rows. + +### Flat Column Design + +System tags store `source_id` and `record_id` as **separate flat columns** rather than a combined string value. This is a deliberate design choice driven by the caching strategy (see **Caching Strategy** section below). + +In function pod cache tables, which are scoped to a structural pipeline hash and thus shared across different source combinations, filtering by source identity is a first-class operation. Storing `source_id` and `record_id` as separate columns makes this a straightforward equality predicate (`WHERE _tag_source_id::schema1 = 'X'`) with clean standard indexing, rather than a prefix match or string parse against a combined value. + +This is safe because within any given cache table, the system tag schema is fixed — every row has the same set of system tag fields, determined by the pipeline structure. The column count grows with pipeline depth (more join stages produce more system tag column pairs), but this growth is per-table-schema, not within a table. Different pipeline structures produce different tables with different column layouts, which is the expected and correct behavior. + +### Source System Tags + +Each source automatically adds a pair of system tag columns using the `_tag_` prefix convention: + +``` +_tag_source_id::{schema_hash} — the source's canonical source_id +_tag_record_id::{schema_hash} — the row identifier within that source +``` + +Where `schema_hash` is derived from the source's `(tag_schema, packet_schema)`. The `::` delimiter separates segments of the system tag column name, maintaining consistency with the extension pattern used downstream. + +Example at the root level: + +``` +_tag_source_id::schema1 (e.g., value: "customers_2024") +_tag_record_id::schema1 (e.g., value: "row_42" or "user_id=abc123") +``` + +### Three Evolution Rules + +**1. Name-Preserving (~90% of operations)** +Single-stream operations (filter, select, rename, batch, map). System tag column names and values pass through unchanged. + +**2. Name-Extending (multi-input operations)** +Joins and merges. Each incoming system tag column name is extended by appending `::node_pipeline_hash:canonical_position`. The `::` delimiter separates each extension segment, and `:` separates the pipeline hash from the canonical position within a segment. Canonical position assignment respects commutativity — for commutative operations, inputs are sorted by `pipeline_hash` to ensure identical column names regardless of wiring order. + +For example, joining two streams that each carry `_tag_source_id::schema1` / `_tag_record_id::schema1`, through a join with pipeline hash `abc123`: + +``` +_tag_source_id::schema1::abc123:0 _tag_record_id::schema1::abc123:0 (first stream by canonical position) +_tag_source_id::schema1::abc123:1 _tag_record_id::schema1::abc123:1 (second stream by canonical position) +``` + +A subsequent join (pipeline hash `def456`) over those results would further extend: + +``` +_tag_source_id::schema1::abc123:0::def456:0 +_tag_record_id::schema1::abc123:0::def456:0 +``` + +The full column name is a chain of `::` delimited segments tracing the provenance path: `_tag_{field}::{source_schema_hash}::{join1_hash}:{position}::{join2_hash}:{position}::...` + +**3. Type-Evolving (aggregation operations)** +Batch and similar grouping operations. Column names are unchanged but types evolve: `str → list[str]` as values collect all contributing source row IDs. Both `source_id` and `record_id` columns evolve independently. + +### System Tag Value Sorting + +For commutative operators (Join, MergeJoin), system tag values from same-`pipeline_hash` streams are sorted per row after the join. This ensures `Op(A, B)` and `Op(B, A)` produce identical system tag columns and values. + +### Schema Prediction + +Operators predict output system tag column names at schema time — without performing the actual computation — by computing `pipeline_hash` values and canonical positions. This is exposed via `output_schema(columns={"system_tags": True})`. + +--- + +## Caching Strategy + +OrcaPod uses a differentiated caching strategy across its three pod types — source, function, and operator — reflecting the distinct computational semantics of each. The guiding principle is that caching behavior should follow naturally from whether the computation is **cumulative**, **independent**, or **holistic**. + +### Source Pod Caching + +**Cache table identity:** Canonical source identity (content hash). + +Each source gets its own dedicated cache table. Sources are provenance roots — there is no upstream system tag mechanism to disambiguate rows from different sources within a shared table. A cached source table represents a cumulative record of all packets ever observed from that specific source. + +**Behavior:** +- Cache is **always on** by default. +- Each packet yielded by the source is stored in the cache table keyed by its content-addressable hash. +- On access, the source pod yields the **merged content of the cache and any new packets** from the live source. +- **Deduplication is performed at the source pod level** during merge, using content-addressable packet hashes. This ensures the yielded stream represents the complete known universe from the source with no redundancy. + +**Semantic guarantee:** The cache is a **correct cumulative record**. The union of cache + live packets is the full set of data ever available from that source. + +### Function Pod Caching + +Function pod caching is split into two tiers: + +1. **Packet-level cache (global):** Maps input packet hash → output packet. Shared globally across all pipelines, enabling identical function calls to reuse results regardless of context. +2. **Tag-level cache (per structural pipeline):** Maps tag → input packet hash. Scoped to the structural pipeline hash. + +**Tag-level cache table identity:** Structural pipeline hash (`pipeline_hash()`). + +A single cache table is used for all runs of structurally identical pipelines (same tag and packet schemas at source, followed by the same sequence of operator and function pods), regardless of which specific source combinations were involved. This is safe because function pods operate on individual packets independently — each cached mapping is self-contained and valid regardless of what other rows exist in the table. + +**Why structural hash, not content hash:** +- System tags already carry full provenance, including source identity as separate queryable columns. Rows from different source combinations are distinguishable within a shared table via equality predicates on `source_id` columns (e.g., `WHERE _tag_source_id::schema1 = 'X'`). +- A shared table provides a natural **cross-source view** — comparing how the same analytical pipeline behaves across different source populations without needing cross-table joins. +- Content-hash scoping would duplicate disambiguation that system tags already provide, violating the principle against redundant mechanisms. + +**Behavior:** +- Cache is **always on** by default. +- On a pipeline run, incoming packets are scoped to the current source combination (determined by upstream source pods). +- The function pod checks the tag-level cache for existing mappings among the incoming tag-packets. +- **Cache hits** (from this or any prior run over the same structural pipeline) are yielded directly. Cross-source sharing falls out naturally because packet-level computation is source-independent. +- **Cache misses** trigger computation; results are stored in both the packet-level and tag-level caches. + +**Semantic guarantee:** The cache is a **correct reusable lookup**. Every entry is independently valid. The table as a whole is a historical record of all computations processed through this function within this structural pipeline context. + +**User guidance:** If a user finds the mixture of results from different source combinations within one table to be unpredictable or undesirable, they should separate pipeline identity explicitly (e.g., by parameterizing the pipeline to produce distinct structural hashes). + +### Operator Pod Caching + +**Cache table identity:** Content hash (structural pipeline hash + identity hashes of all upstream sources). + +Each unique combination of pipeline structure and source identities gets its own cache table. This reflects the fact that operator results are holistic — they depend on the entire input stream, not individual packets. + +**Why content hash, not structural hash:** +Operators compute over the stream (joins, aggregations, window functions). Their outputs are meaningful only as a complete set given a specific input. Unlike function pods, operator results cannot be safely mixed across source combinations within a shared table because the distributive property does not hold for most operators. For example, with a join: `(X ⋈ Y) ∪ (X' ⋈ Y') ≠ (X ∪ X') ⋈ (Y ∪ Y')`. The shared table would miss cross-terms `X ⋈ Y'` and `X' ⋈ Y`. Cache invalidation is also cleaner per-table (drop/mark stale) rather than selectively purging rows by system tag. + +**Critical correctness caveat:** +Even scoped to content hash, operator caches are **not guaranteed to be complete** with respect to the full picture of all packets ever yielded by the sources. Because sources may use canonical identity for their content hash, the same source identity may yield different packet sets over time. The cache accumulates result rows across runs: + +- Run 1: `X ⋈ Y` is cached. +- Run 2: Sources yield `X'` and `Y'`. The operator computes `X' ⋈ Y'` and appends new rows to cache. +- The cache now contains `(X ⋈ Y) ∪ (X' ⋈ Y')`, which is **not** equivalent to `(X ∪ X') ⋈ (Y ∪ Y')`. + +The operator cache is strictly an **append-only historical record**, not a cumulative materialization. Identical output rows across runs naturally deduplicate (keyed by `hash(tag + packet + system_tag)`). Run-level grouping and tracking is managed separately outside the cache mechanism. + +**Behavior:** +- Cache is **off by default**. Operator computation is always triggered fresh in a typical run. +- Cache can be **explicitly opted into** for historical logging purposes. Even when enabled, the operator still recomputes — the cache serves as a record, not a substitute. +- A separate, explicit configuration is required to **skip computation and flow the historical cache** to the rest of the pipeline. This is only appropriate when the user intentionally wants to use the historical record (e.g., for auditing or comparing run-over-run results), not as a performance optimization. + +**Three-tier opt-in model:** + +| Mode | Cache writes | Computation | Use case | +|------|-------------|-------------|----------| +| Default (off) | No | Always | Normal pipeline execution | +| Logging | Yes | Always | Audit trail, run-over-run comparison | +| Historical replay | Yes (prior) | Skipped | Explicitly flowing prior results downstream | + +**Semantic guarantee:** The cache is a **historical record**. It records what was produced, not what would be produced now. Identical output rows across runs are deduplicated. It must never be silently substituted for fresh computation. + +### Caching Summary + +| Property | Source Pod | Function Pod | Operator Pod | +|----------|-----------|--------------|--------------| +| Cache table scope | Canonical source identity | Structural pipeline hash | Content hash (structure + sources) | +| Default state | Always on | Always on | Off | +| Semantic role | Cumulative record | Reusable lookup | Historical record | +| Correctness | Always correct | Always correct | Per-run snapshots only | +| Cross-source sharing | N/A (one source per table) | Yes, via system tag columns | No (separate tables) | +| Computation on cache hit | Dedup and merge | Skip (use cached result) | Recompute by default | + +The overall gradient: sources are always cached and always correct, function pods are always cached and always reusable, operators are optionally logged and never silently substituted. Each level directly follows from whether the computation is cumulative, independent, or holistic. + +--- + +## Pipeline Database Scoping + +Function pods and operators use `pipeline_hash()` to scope their database tables: + +### FunctionNode Pipeline Path + +``` +{pipeline_path_prefix} / {function_name} / {output_schema_hash} / v{major_version} / {function_type_id} / node:{pipeline_hash} +``` + +### OperatorNode Pipeline Path + +``` +{pipeline_path_prefix} / {operator_class} / {operator_content_hash} / node:{pipeline_hash} +``` + +### Multi-Source Table Sharing + +Sources with identical schemas produce identical `pipeline_hash` values. When processed through the same pipeline structure, they share database tables automatically. Different source instances (e.g., `customers_2023`, `customers_2024`) coexist in the same table, differentiated by system tag values and record hashes. This enables natural cross-source analytics without separate table management. + +--- + +## Derived Sources and Pipeline Composition + +Derived sources bridge pipeline stages by materializing intermediate results: + +- **Construction**: `function_node.as_source()` or `operator_node.as_source()` returns a `DerivedSource` that reads from the node's pipeline database. +- **Identity**: Content hash ties to the origin node's content hash; pipeline hash is schema-only (inherited from `RootSource`). +- **Use case**: Downstream pipelines reference the derived source directly, independent of the upstream topology that produced it. + +Derived sources serve two purposes: +1. **Semantic materialization** — domain-meaningful intermediate constructs (e.g., a daily top-3 selection, a trial, a session) are given durable identity in the pipeline database. +2. **Pipeline decoupling** — once materialized, downstream pipelines can evolve independently of upstream topology. + +--- + +## Provenance Graph + +Data provenance focuses on **data-generating entities only** — sources and function pods. Since operators never synthesize new packet values, they leave no computational footprint on the data itself. + +The provenance graph is a **bipartite graph of sources and function pods**, with edges encoded as source info pointers per output field. Operator pod topology is captured implicitly in system tag column names and the pipeline Merkle chain but operators do not appear as nodes in the provenance graph. + +This means: +- **Operators can be refactored** without invalidating data provenance +- **Provenance queries are simpler** — tracing a result requires only following source info pointers between function pod table entries +- **Provenance is robust** — lineage is told by what generated and transformed the data, not by how it was routed + +--- + +## Execution Models + +OrcaPod supports two complementary execution strategies — **synchronous pull-based** and **asynchronous push-based** — that produce semantically identical results. The choice of strategy is an execution concern, not a data-identity concern: neither content hashes nor pipeline hashes depend on how the pipeline was executed. + +### Synchronous Execution (Pull-Based) + +The default model. Callers invoke `process()` on a pod, which returns a stream. Iteration over the stream triggers computation lazily. + +Three variants exist within the synchronous model: + +**1. Lazy In-Memory (FunctionPod → FunctionPodStream)** +The function pod processes each packet on demand via `iter_packets()`. Results are cached by index in memory. No database persistence. Suitable for exploration and one-off computations. + +**2. Static with Recomputation (StaticOutputPod → DynamicPodStream)** +The operator's `static_process` produces a complete output stream. `DynamicPodStream` wraps it with timestamp-based staleness detection and automatic recomputation when upstreams change. + +**3. Database-Backed Incremental (FunctionNode / OperatorNode → PersistentFunctionNode / PersistentOperatorNode)** +Results are persisted in a pipeline database. Incremental computation: only process inputs whose hashes are not already in the database. Per-row record hashes enable deduplication. Suitable for production pipelines with expensive computations. `PersistentFunctionNode` extends `FunctionNode` with result caching via `CachedPacketFunction` and two-phase iteration (Phase 1: yield cached results, Phase 2: compute missing). `PersistentOperatorNode` extends `OperatorNode` with three-tier caching (off / log / replay). + +**Concurrent execution within sync mode:** +When a `PacketFunctionExecutor` with `supports_concurrent_execution = True` is attached (e.g. `RayExecutor`), `FunctionPodStream.iter_packets()` materializes all remaining input packets and dispatches them concurrently via the executor's `async_execute`, collecting results in order. This provides data-parallel speedup without leaving the synchronous call model. + +### Asynchronous Execution (Push-Based Channels) + +Every pipeline node — source, operator, or function pod — implements the `AsyncExecutableProtocol`: + +```python +async def async_execute( + inputs: Sequence[ReadableChannel[tuple[Tag, Packet]]], + output: WritableChannel[tuple[Tag, Packet]], +) -> None +``` + +Nodes consume `(Tag, Packet)` pairs from input channels and produce them to an output channel. This enables push-based, streaming execution where data flows through the pipeline as soon as it's available, with backpressure propagated via bounded channel buffers. + +**FunctionPod async strategy:** Streaming mode — each input `(tag, packet)` is processed independently with semaphore-controlled concurrency. Uses `asyncio.TaskGroup` for structured concurrency. + +#### Operator Async Strategies + +Each operator overrides `async_execute` with the most efficient streaming pattern its semantics permit. The default fallback (inherited from `StaticOutputPod`) is barrier mode: collect all inputs via `asyncio.gather`, materialize to `ArrowTableStream`, call `static_process`, and emit results. Operators override this default when a more incremental strategy is possible. + +| Strategy | Description | Operators | +|---|---|---| +| **Per-row streaming** | Transform each `(Tag, Packet)` independently as it arrives; zero buffering beyond the current row | SelectTagColumns, SelectPacketColumns, DropTagColumns, DropPacketColumns, MapTags, MapPackets | +| **Accumulate-and-emit** | Buffer rows up to `batch_size`, emit full batches immediately, flush partial at end | Batch (`batch_size > 0`) | +| **Build-probe** | Collect one side fully (build), then stream the other through a hash lookup (probe) | SemiJoin | +| **Symmetric hash join** | Read both sides concurrently, buffer + index both, emit matches as they're found | Join (2 inputs) | +| **Barrier mode** | Collect all inputs, run `static_process`, emit results | PolarsFilter, MergeJoin, Batch (`batch_size = 0`), Join (N > 2 inputs) | + +#### Per-Row Streaming (Unary Column/Map Operators) + +For operators that transform each row independently (column selection, column dropping, column renaming), the async path iterates `async for tag, packet in inputs[0]` and applies the transformation per row. Column metadata (which columns to drop, the rename map, etc.) is computed lazily on the first row and cached for subsequent rows. This avoids materializing the entire input into an Arrow table, enabling true pipeline-level streaming where upstream producers and downstream consumers run concurrently. + +#### Accumulate-and-Emit (Batch) + +When `batch_size > 0`, Batch accumulates rows into a buffer and emits a batched result stream each time the buffer reaches `batch_size`. Any partial batch at the end is emitted unless `drop_partial_batch` is set. When `batch_size = 0` (meaning "batch everything into one group"), the operator must see all input before producing output, so it falls back to barrier mode. + +#### Build-Probe (SemiJoin) + +SemiJoin is non-commutative: the left side is filtered by the right side. The async implementation collects the right (build) side fully, constructs a hash set of its key tuples, then streams the left (probe) side through the lookup — emitting each left row whose keys appear in the right set. This is the same pattern as Kafka's KStream-KTable join: the table side is materialized, the stream side drives output. + +#### Symmetric Hash Join + +The 2-input Join uses a symmetric hash join — the same algorithm used by Apache Kafka for KStream-KStream joins and by Apache Flink for regular streaming joins. Both input channels are drained concurrently into a shared `asyncio.Queue`. For each arriving row: + +1. Buffer the row on its side and index it by the shared key columns. +2. Probe the opposite side's index for matching keys. +3. Emit all matches immediately. + +When the first rows from both sides have arrived, the shared key columns are determined (intersection of tag column names). Any rows that arrived before shared keys were known are re-indexed and cross-matched in a one-time reconciliation step. + +**Comparison with industry stream processors:** + +| Aspect | Kafka Streams (KStream-KStream) | Apache Flink (Regular Join) | OrcaPod | +|---|---|---|---| +| Algorithm | Symmetric windowed hash join | Symmetric hash join with state TTL | Symmetric hash join | +| Windowing | Required (sliding window bounds state) | Optional (TTL evicts old state) | Not needed (finite streams) | +| State backend | RocksDB state stores for fault tolerance | RocksDB / heap state with checkpointing | In-memory buffers | +| State cleanup | Window expiry evicts old records | TTL or watermark eviction | Natural termination — inputs are finite | +| N-way joins | Chained pairwise joins | Chained pairwise joins | 2-way: symmetric hash; N > 2: barrier + Arrow join | + +The symmetric hash join is optimal for our use case: it emits results with minimum latency (as soon as a match exists on both sides) and requires no windowing complexity since OrcaPod streams are finite. For N > 2 inputs, the operator falls back to barrier mode with Arrow-level join execution, which is efficient for bounded data and avoids the complexity of chaining pairwise streaming joins. + +**Why not build-probe for Join?** Since Join is commutative and input sizes are unknown upfront, there is no principled way to choose which side to build vs. probe. Symmetric hash join avoids this asymmetry. SemiJoin, being non-commutative, has a natural build (right) and probe (left) side. + +**Why barrier for PolarsFilter and MergeJoin?** PolarsFilter requires a Polars DataFrame context for predicate evaluation, which needs full materialization. MergeJoin's column-merging semantics (colliding columns become sorted `list[T]`) require seeing all rows to produce correctly typed output columns. + +### Sync / Async Equivalence + +Both execution paths produce identical output given identical inputs. The sync path is simpler to debug and compose; the async path enables pipeline-level parallelism and streaming. The `PipelineConfig.executor` field selects between them: + +| `ExecutorType` | Mechanism | Use case | +|---|---|---| +| `SYNCHRONOUS` | `process()` chain with pull-based materialization | Interactive exploration, debugging | +| `ASYNC_CHANNELS` | `async_execute()` with push-based channels | Production pipelines, concurrent I/O | + +--- + +## Channel System + +Channels are the communication primitive for push-based async execution. They are bounded async queues with explicit close/done signaling and backpressure. + +### Channel + +A `Channel[T]` is a bounded async buffer (default capacity 64) with separate reader and writer views: + +- **`WritableChannel`** — `send(item)` blocks when the buffer is full (backpressure). `close()` signals that no more items will be sent. +- **`ReadableChannel`** — `receive()` blocks until an item is available. Raises `ChannelClosed` when the channel is closed and drained. Supports `async for` iteration and `collect()` to drain into a list. + +### BroadcastChannel + +A `BroadcastChannel[T]` fans out items from a single writer to multiple independent readers. Each `add_reader()` creates a reader with its own queue, so downstream consumers read at their own pace without interfering. + +### Backpressure + +Backpressure propagates naturally: when a downstream reader is slow, the writer blocks on `send()` once the buffer fills. This prevents unbounded memory growth and creates natural flow control through the pipeline graph. + +--- + +## Packet Function Executor System + +Executors decouple **what** a packet function computes from **where** and **how** it runs. Every `PacketFunctionBase` has an optional `executor` slot. When set, `call()` and `async_call()` route through the executor instead of calling the function directly. + +### Routing + +``` +packet_function.call(packet) + ├── executor is set → executor.execute(packet_function, packet) + └── executor is None → packet_function.direct_call(packet) + +packet_function.async_call(packet) + ├── executor is set → executor.async_execute(packet_function, packet) + └── executor is None → packet_function.direct_async_call(packet) +``` + +Executors call `direct_call()` / `direct_async_call()` internally, which are the native computation methods that subclasses implement. This two-level routing ensures executors can wrap the computation without infinite recursion. + +### Executor Types + +| Executor | `executor_type_id` | Supported Types | Concurrent | Description | +|---|---|---|---|---| +| `LocalExecutor` | `"local"` | All | No | Runs in-process. Default. | +| `RayExecutor` | `"ray.v0"` | `"python.function.v0"` | Yes | Dispatches to a Ray cluster. Configurable CPUs/GPUs/resources. | + +### Type Safety + +Each executor declares `supported_function_type_ids()`. Setting an incompatible executor raises `ValueError` at assignment time, not at execution time. An empty set means "supports all types" (used by `LocalExecutor`). + +### Identity Separation + +Executors are **not** part of content or pipeline identity. The same function produces the same hash regardless of whether it runs locally or on Ray. Executor metadata is captured separately via `get_execution_data()` for observability but does not affect hashing or caching. + +### Concurrency Configuration + +Two-level configuration controls per-node concurrency in async mode: + +- **`PipelineConfig`** — pipeline-level defaults: `executor` type, `channel_buffer_size`, `default_max_concurrency`. +- **`NodeConfig`** — per-node override: `max_concurrency`. `None` inherits from pipeline config. `1` forces sequential execution (useful for rate-limited APIs or order-preserving operations). + +`resolve_concurrency(node_config, pipeline_config)` returns the effective limit. In `FunctionPod.async_execute`, this limit governs an `asyncio.Semaphore` controlling how many packets are in-flight concurrently. + +--- + +## Pipeline Compilation and Orchestration + +### Graph Tracking + +All pod invocations are automatically recorded by a global `BasicTrackerManager`. When a `StaticOutputPod.process()` or `FunctionPod.process()` is called, the tracker manager broadcasts the invocation to all registered trackers. This enables transparent DAG construction — the user writes normal imperative code, and the computation graph is captured behind the scenes. + +`GraphTracker` is the base tracker implementation. It maintains: +- A **node lookup table** (`_node_lut`) mapping content hashes to `FunctionNode`, `OperatorNode`, or `SourceNode` objects. +- An **upstream map** (`_upstreams`) mapping stream content hashes to stream objects. +- A directed **edge list** (`_graph_edges`) recording (upstream_hash → downstream_hash) relationships. + +`GraphTracker.compile()` builds a `networkx.DiGraph`, topologically sorts it, and wraps unregistered leaf hashes in `SourceNode` objects, producing a complete typed DAG. + +### Pipeline + +`Pipeline` extends `GraphTracker` with persistence. Its lifecycle has three phases: + +**1. Recording phase (context manager).** Within a `with pipeline:` block, the pipeline registers itself as an active tracker. All pod invocations are captured as non-persistent nodes. + +**2. Compilation phase (`compile()`).** On context exit (if `auto_compile=True`), `compile()` walks the graph in topological order and replaces every node with its persistent variant: + +| Non-persistent | Persistent | Scoped by | +|---|---|---| +| Leaf stream | `PersistentSourceNode` | Stream content hash | +| `FunctionNode` | `PersistentFunctionNode` | Pipeline hash (schema+topology) | +| `OperatorNode` | `PersistentOperatorNode` | Content hash (structure+sources) | + +All persistent nodes share the same `pipeline_database` with the pipeline's name as path prefix. An optional separate `function_database` can be provided for function pod result caches. + +Compilation is **incremental**: re-entering the context, adding more operations, and compiling again preserves existing persistent nodes. Labels are disambiguated by content hash on collision. + +**3. Execution phase (`run()`).** Executes all compiled nodes in topological order by calling `node.run()` on each, then flushes all databases. Compiled nodes are accessible by label as attributes on the pipeline instance (e.g., `pipeline.compute_grades`). + +### Persistent Nodes + +| Node type | Behavior | +|---|---| +| `PersistentSourceNode` | Materializes the wrapped stream into a cache DB with per-row deduplication via content hash. On subsequent access, returns the union of cached + live data. | +| `PersistentFunctionNode` | DB-backed two-phase iteration: Phase 1 yields cached results from the pipeline database, Phase 2 computes only missing inputs. Uses `CachedPacketFunction` for packet-level result caching. | +| `PersistentOperatorNode` | DB-backed with three-tier cache mode: OFF (default, always recompute), LOG (compute and write to DB), REPLAY (skip computation, load from DB). | + +### Pipeline Composition + +Pipelines can be composed across boundaries: +- **Cross-pipeline references** — Pipeline B can use Pipeline A's compiled nodes as input streams. +- **Chain detachment** via `.as_source()` — `PersistentFunctionNode.as_source()` and `PersistentOperatorNode.as_source()` return a `DerivedSource` that reads from the pipeline database, breaking the upstream Merkle chain. Downstream pipelines reference the derived source directly, independent of the upstream topology that produced it. + +--- + +## Fused Pod Pattern + +### Motivation + +The strict operator / function pod boundary is central to OrcaPod's provenance guarantees: operators never synthesize values (provenance transparent), function pods always synthesize values (provenance tracked). This two-category model keeps provenance tracking simple and robust. + +However, certain common patterns require combining both behaviors in a single logical operation. The most common is **enrichment** — running a function on a packet and appending the computed columns to the original packet rather than replacing it. The naïve decomposition into `FunctionPod + Join` works but incurs unnecessary overhead: an intermediate stream is materialized only to be immediately joined back, and the join must re-match tags that trivially correspond because they came from the same input row. + +### Fused Pods as Optimization, Not Extension + +A **fused pod** is an implementation-level pod type that combines the behaviors of multiple existing pod types into a single pass, without introducing a new provenance category. Its correctness is verified by checking equivalence with its decomposition. + +The key invariant: **every column in a fused pod's output maps to exactly one existing provenance category.** + +- **Preserved columns** (from upstream) — provenance transparent, source-info passes through unchanged. This is the operator-like component. +- **Computed columns** (from the wrapped PacketFunction) — provenance tracked, source-info references the PacketFunction. This is the function-pod-like component. + +There is no third kind of output column. The theoretical provenance model stays clean (Source, Operator, FunctionPod), and fused pods are justified as performance/ergonomic optimizations whose provenance semantics are *derived from* the existing model rather than extending it. + +This is analogous to how a database query optimizer fuses filter+project into a single scan without changing the relational algebra semantics. + +### AddResult + +The first planned fused pod. Wraps a `PacketFunction` and merges the function output back into the original packet: + +```python +grade_pf = PythonPacketFunction(compute_letter_grade, output_keys="letter_grade") +enriched = AddResult(grade_pf).process(stream) +# enriched has all original columns + "letter_grade" +``` + +Equivalent decomposition: `FunctionPod(pf).process(stream)` → `Join()(stream, computed)`. + +Efficiency gains: no intermediate stream materialization, no redundant tag matching, no broadcast/rejoin wiring. The async path streams row-by-row like FunctionPod. + +Implementation constraints: +- `output_schema()` returns `(input_tag_schema, input_packet_schema | function_output_schema)`. +- Raises `InputValidationError` if function output keys collide with existing packet column names. +- `pipeline_hash` commits to the wrapped PacketFunction's identity plus the upstream's pipeline hash (as if the decomposition were performed). +- Source-info on computed columns references the PacketFunction. Source-info on preserved columns passes through unchanged. + +--- + +## Data Context + +Every object is associated with a `DataContext` providing: + +| Component | Purpose | +|---|---| +| `semantic_hasher` | Recursive, type-aware object hashing for content/pipeline identity | +| `arrow_hasher` | Arrow table/record batch hashing | +| `type_converter` | Python ↔ Arrow type conversion | +| `context_key` | Identifier for this context configuration | + +The data context ensures consistent hashing and type conversion across the pipeline. It is propagated through construction and accessible via the `DataContextMixin`. + +--- + +## Verification + +The ability to rerun and verify the exact chain of computation is a core feature. A pipeline run in verify mode recomputes every step and checks output hashes against stored results, producing a reproducibility certificate. + +Function pods carry a determinism declaration: +- **Deterministic pods** — verified by exact hash equality +- **Non-deterministic pods** — verified by an associated equivalence measure + +Equivalence measures are externally associated with function pods — not with schemas — because the same data type can require different notions of closeness in different computational contexts. + +--- + +## Separation of Concerns + +A consistent architectural principle: **computational identity is separated from computational semantics**. + +The content-addressed computation layer handles identity — pure, self-contained, uncontaminated by higher-level concerns. External associations carry richer semantic context: + +| Association | Informs | +|---|---| +| Schema linkage | Pipeline assembler / wiring validation | +| Equivalence measures | Verifier | +| Confidence levels | Registry / ecosystem tooling | + +None of these influence actual pod execution. diff --git a/plan.md b/plan.md new file mode 100644 index 00000000..5ac73a63 --- /dev/null +++ b/plan.md @@ -0,0 +1,830 @@ +# Plan: Unified `process_packet` / `async_process_packet` + Node `async_execute` + +## Goal + +Establish `process_packet` and `async_process_packet` as **the** universal per-packet +interface across FunctionPod, FunctionPodStream, FunctionNode, and PersistentFunctionNode. +All iteration paths — sequential, concurrent, and async — route through these methods. +Add `async_execute` to all four Node classes. Add cache-aware `async_call` to +`CachedPacketFunction`. Remove `_execute_concurrent` module-level helper. + +--- + +## What exists today + +### Class hierarchy + +``` +_FunctionPodBase (TraceableBase) + ├── process_packet(tag, packet) → calls packet_function.call(packet) + ├── FunctionPod + │ ├── process() → FunctionPodStream + │ └── async_execute() → calls packet_function.async_call(packet) DIRECTLY + │ + FunctionPodStream (StreamBase) + │ ├── _iter_packets_sequential() → calls _function_pod.process_packet(tag, packet) ✓ + │ └── _iter_packets_concurrent() → calls _execute_concurrent(packet_function, ...) DIRECTLY + │ + FunctionNode (StreamBase) + │ ├── _iter_packets_sequential() → calls _packet_function.call(packet) DIRECTLY + │ ├── _iter_packets_concurrent() → calls _execute_concurrent(_packet_function, ...) DIRECTLY + │ └── (no async_execute) + │ + PersistentFunctionNode (FunctionNode) + ├── process_packet(tag, packet) → calls _packet_function.call(packet, skip_cache_*=...) + │ then add_pipeline_record(...) + ├── iter_packets() → Phase 1: replay from DB + │ Phase 2: calls self.process_packet(tag, packet) ✓ + └── (no async_execute) + +OperatorNode (StreamBase) + ├── run() → calls _operator.process(*streams) + └── (no async_execute) + +PersistentOperatorNode (OperatorNode) + ├── _compute_and_store() → calls _operator.process() + bulk DB write + ├── _replay_from_cache() → loads from DB + └── (no async_execute) +``` + +### Module-level helpers + +```python +def _executor_supports_concurrent(packet_function) -> bool: + """True if the pf's executor supports concurrent execution.""" + +def _execute_concurrent(packet_function, packets) -> list[PacketProtocol | None]: + """Submit all packets concurrently via asyncio.gather(pf.async_call(...)). + Falls back to sequential pf.call() if already inside a running event loop.""" +``` + +### Problems + +1. **FunctionPod.async_execute** bypasses `process_packet` — calls `packet_function.async_call` + directly (line 317). +2. **FunctionPodStream._iter_packets_concurrent** bypasses `process_packet` — calls + `_execute_concurrent(packet_function, ...)` directly (line 472). +3. **FunctionNode._iter_packets_sequential** bypasses any process_packet — calls + `_packet_function.call(packet)` directly (line 831). +4. **FunctionNode._iter_packets_concurrent** same — calls `_execute_concurrent` directly + (line 852). +5. **CachedPacketFunction.async_call** inherits from `PacketFunctionWrapper` — completely + **bypasses the cache** (no lookup, no recording). +6. **No `async_process_packet`** exists anywhere. +7. **No `async_execute`** on any Node class. +8. **`_execute_concurrent`** is a module-level function that takes a raw `packet_function` + and list of bare `packets` — no way to route through `process_packet`. + +--- + +## Design principles + +### A. `process_packet` / `async_process_packet` is the single per-packet entry point + +Every class in the function pod hierarchy defines these two methods. **All** iteration and +execution paths go through them — sequential, concurrent, and async. No direct +`packet_function.call()` or `packet_function.async_call()` calls outside of these methods. + +``` +_FunctionPodBase.process_packet(tag, pkt) → packet_function.call(pkt) +_FunctionPodBase.async_process_packet(tag, pkt) → await packet_function.async_call(pkt) + +FunctionNode.process_packet(tag, pkt) → self._function_pod.process_packet(tag, pkt) +FunctionNode.async_process_packet(tag, pkt) → await self._function_pod.async_process_packet(tag, pkt) + +PersistentFunctionNode.process_packet(tag, pkt) → cache check → self._function_pod.process_packet → pipeline record +PersistentFunctionNode.async_process_packet(tag, pkt) → cache check → await self._function_pod.async_process_packet → pipeline record +``` + +Wait — there's a subtlety with PersistentFunctionNode. Today its `process_packet` calls +`self._packet_function.call(packet, skip_cache_lookup=..., skip_cache_insert=...)` directly, +where `self._packet_function` is a `CachedPacketFunction` (which wraps the original pf). +It does NOT delegate to the pod's `process_packet`. That's because PersistentFunctionNode +needs to pass `skip_cache_*` kwargs that the base `process_packet` doesn't accept. + +The cleanest structure: + +``` +PersistentFunctionNode.process_packet(tag, pkt) + → self._packet_function.call(pkt, skip_cache_*=...) # CachedPacketFunction (sync) + → self.add_pipeline_record(...) # pipeline DB (sync) + +PersistentFunctionNode.async_process_packet(tag, pkt) + → await self._packet_function.async_call(pkt, skip_cache_*=...) # CachedPacketFunction (async) + → self.add_pipeline_record(...) # pipeline DB (sync) +``` + +This is the same as today for the sync path. The `CachedPacketFunction` handles the result +cache internally. The `PersistentFunctionNode` handles pipeline records. Neither delegates +to the pod's `process_packet` — the pod is bypassed because the `CachedPacketFunction` +replaced the raw packet function in `__init__`. + +### B. Concurrent iteration routes through `async_process_packet` + +The concurrent path is inherently async — it uses `asyncio.gather`. So it naturally routes +through `async_process_packet`. The fallback path (when already inside an event loop) routes +through `process_packet` (sync). + +For **FunctionPodStream**, the target is the pod: +```python +# concurrent +await self._function_pod.async_process_packet(tag, pkt) +# fallback +self._function_pod.process_packet(tag, pkt) +``` + +For **FunctionNode**, the target is `self` — so overrides (PersistentFunctionNode) kick in: +```python +# concurrent +await self.async_process_packet(tag, pkt) +# fallback +self.process_packet(tag, pkt) +``` + +This means PersistentFunctionNode's concurrent path **automatically** gets cache checks + +pipeline records via polymorphism. No special handling needed. + +### C. `_execute_concurrent` is removed + +The module-level `_execute_concurrent(packet_function, packets)` helper is removed. Its +logic (asyncio.gather with event-loop fallback) is inlined into `_iter_packets_concurrent` +methods, but now routes through `process_packet` / `async_process_packet` instead of raw +`packet_function.call` / `packet_function.async_call`. + +The `_executor_supports_concurrent` helper stays — it's just a predicate check. + +### D. Sync and async are cleanly separated execution modes + +- Sync: `iter_packets()` / `as_table()` / `run()` +- Async: `async_execute(inputs, output)` + +They don't populate each other's caches. DB persistence (for Persistent variants) provides +durability that works across both modes. + +### E. OperatorNode delegates to operator, PersistentOperatorNode intercepts for storage + +Operators are opaque stream transformers — no per-packet hook. `OperatorNode` passes through +directly. `PersistentOperatorNode` uses an intermediate channel + `TaskGroup` to forward +results downstream immediately while collecting them for post-hoc DB storage. + +### F. DB operations stay synchronous + +The `ArrowDatabaseProtocol` is sync. All DB reads/writes within async methods are sync calls. +Acceptable because DB is typically in-process and fast. Async DB protocol is deferred. + +--- + +## Implementation steps + +### Step 1: Add `async_process_packet` to `_FunctionPodBase` + +**File:** `src/orcapod/core/function_pod.py` + +Add alongside existing `process_packet` (after line 180): + +```python +async def async_process_packet( + self, tag: TagProtocol, packet: PacketProtocol +) -> tuple[TagProtocol, PacketProtocol | None]: + """Async counterpart of ``process_packet``.""" + return tag, await self.packet_function.async_call(packet) +``` + +### Step 2: Fix `FunctionPod.async_execute` to use `async_process_packet` + +**File:** `src/orcapod/core/function_pod.py` + +Change the `process_one` inner function (lines 315-322): + +```python +async def process_one(tag: TagProtocol, packet: PacketProtocol) -> None: + try: + tag, result_packet = await self.async_process_packet(tag, packet) + if result_packet is not None: + await output.send((tag, result_packet)) + finally: + if sem is not None: + sem.release() +``` + +### Step 3: Fix `FunctionPodStream._iter_packets_concurrent` to use `async_process_packet` + +**File:** `src/orcapod/core/function_pod.py` + +Replace the `_execute_concurrent` call (lines 454-482) with direct `async_process_packet` +routing: + +```python +def _iter_packets_concurrent( + self, +) -> Iterator[tuple[TagProtocol, PacketProtocol]]: + """Collect remaining inputs, execute concurrently, and yield results in order.""" + input_iter = self._cached_input_iterator + + all_inputs: list[tuple[int, TagProtocol, PacketProtocol]] = [] + to_compute: list[tuple[int, TagProtocol, PacketProtocol]] = [] + for i, (tag, packet) in enumerate(input_iter): + all_inputs.append((i, tag, packet)) + if i not in self._cached_output_packets: + to_compute.append((i, tag, packet)) + self._cached_input_iterator = None + + if to_compute: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is not None: + # Already in event loop — fall back to sequential sync + results = [ + self._function_pod.process_packet(tag, pkt) + for _, tag, pkt in to_compute + ] + else: + # No event loop — run concurrently via asyncio.run + async def _gather() -> list[tuple[TagProtocol, PacketProtocol | None]]: + return list( + await asyncio.gather( + *[ + self._function_pod.async_process_packet(tag, pkt) + for _, tag, pkt in to_compute + ] + ) + ) + + results = asyncio.run(_gather()) + + for (i, _, _), (tag, output_packet) in zip(to_compute, results): + self._cached_output_packets[i] = (tag, output_packet) + + for i, *_ in all_inputs: + tag, packet = self._cached_output_packets[i] + if packet is not None: + yield tag, packet +``` + +**Note:** The method signature drops the `packet_function` parameter — it no longer needs +it since it routes through `self._function_pod`. + +The `iter_packets` method that calls this also needs updating — remove the `pf` argument: + +```python +def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: + if self.is_stale: + self.clear_cache() + if self._cached_input_iterator is not None: + if _executor_supports_concurrent(self._function_pod.packet_function): + yield from self._iter_packets_concurrent() + else: + yield from self._iter_packets_sequential() + else: + for i in range(len(self._cached_output_packets)): + tag, packet = self._cached_output_packets[i] + if packet is not None: + yield tag, packet +``` + +### Step 4: Fix `FunctionNode._iter_packets_sequential` to use `process_packet` + +**File:** `src/orcapod/core/function_pod.py` + +Change line 831 from: +```python +output_packet = self._packet_function.call(packet) +self._cached_output_packets[i] = (tag, output_packet) +``` +to: +```python +tag, output_packet = self.process_packet(tag, packet) +self._cached_output_packets[i] = (tag, output_packet) +``` + +### Step 5: Fix `FunctionNode._iter_packets_concurrent` to use `async_process_packet` + +**File:** `src/orcapod/core/function_pod.py` + +Same transformation as Step 3, but routing through `self` instead of `self._function_pod`: + +```python +def _iter_packets_concurrent( + self, +) -> Iterator[tuple[TagProtocol, PacketProtocol]]: + """Collect remaining inputs, execute concurrently, and yield results in order.""" + input_iter = self._cached_input_iterator + + all_inputs: list[tuple[int, TagProtocol, PacketProtocol]] = [] + to_compute: list[tuple[int, TagProtocol, PacketProtocol]] = [] + for i, (tag, packet) in enumerate(input_iter): + all_inputs.append((i, tag, packet)) + if i not in self._cached_output_packets: + to_compute.append((i, tag, packet)) + self._cached_input_iterator = None + + if to_compute: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is not None: + # Already in event loop — fall back to sequential sync + results = [ + self.process_packet(tag, pkt) + for _, tag, pkt in to_compute + ] + else: + # No event loop — run concurrently via asyncio.run + async def _gather() -> list[tuple[TagProtocol, PacketProtocol | None]]: + return list( + await asyncio.gather( + *[ + self.async_process_packet(tag, pkt) + for _, tag, pkt in to_compute + ] + ) + ) + + results = asyncio.run(_gather()) + + for (i, _, _), (tag, output_packet) in zip(to_compute, results): + self._cached_output_packets[i] = (tag, output_packet) + + for i, *_ in all_inputs: + tag, packet = self._cached_output_packets[i] + if packet is not None: + yield tag, packet +``` + +**Critical difference from Step 3:** Uses `self.process_packet` / `self.async_process_packet` +instead of `self._function_pod.*`. This means when `PersistentFunctionNode` inherits this +method, it automatically routes through its overridden `process_packet` / +`async_process_packet` which include cache checks + pipeline record storage. + +### Step 6: Remove `_execute_concurrent` + +**File:** `src/orcapod/core/function_pod.py` + +Delete the `_execute_concurrent` function (lines 52-82). Its logic is now inlined into the +`_iter_packets_concurrent` methods. + +### Step 7: Add `process_packet` and `async_process_packet` to `FunctionNode` + +**File:** `src/orcapod/core/function_pod.py` + +FunctionNode currently has no `process_packet`. Add delegation to the function pod: + +```python +def process_packet( + self, tag: TagProtocol, packet: PacketProtocol +) -> tuple[TagProtocol, PacketProtocol | None]: + """Process a single packet by delegating to the function pod.""" + return self._function_pod.process_packet(tag, packet) + +async def async_process_packet( + self, tag: TagProtocol, packet: PacketProtocol +) -> tuple[TagProtocol, PacketProtocol | None]: + """Async counterpart of ``process_packet``.""" + return await self._function_pod.async_process_packet(tag, packet) +``` + +### Step 8: Add `FunctionNode.async_execute` + +**File:** `src/orcapod/core/function_pod.py` + +Sequential streaming through `async_process_packet`: + +```python +async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], +) -> None: + """Streaming async execution — process each packet via async_process_packet.""" + try: + async for tag, packet in inputs[0]: + tag, result_packet = await self.async_process_packet(tag, packet) + if result_packet is not None: + await output.send((tag, result_packet)) + finally: + await output.close() +``` + +### Step 9: Add async cache-aware `async_call` to `CachedPacketFunction` + +**File:** `src/orcapod/core/packet_function.py` + +Override `async_call` to mirror the sync `call()` logic (lines 508-533): + +```python +async def async_call( + self, + packet: PacketProtocol, + *, + skip_cache_lookup: bool = False, + skip_cache_insert: bool = False, +) -> PacketProtocol | None: + """Async counterpart of ``call`` with cache check and recording.""" + output_packet = None + if not skip_cache_lookup: + logger.info("Checking for cache...") + output_packet = self.get_cached_output_for_packet(packet) + if output_packet is not None: + logger.info(f"Cache hit for {packet}!") + if output_packet is None: + output_packet = await self._packet_function.async_call(packet) + if output_packet is not None: + if not skip_cache_insert: + self.record_packet(packet, output_packet) + output_packet = output_packet.with_meta_columns( + **{self.RESULT_COMPUTED_FLAG: True} + ) + return output_packet +``` + +### Step 10: Add `async_process_packet` to `PersistentFunctionNode` + +**File:** `src/orcapod/core/function_pod.py` + +PersistentFunctionNode already has `process_packet` (line 1027-1066) which calls +`self._packet_function.call(packet, skip_cache_*=...)` (where `_packet_function` is a +`CachedPacketFunction`) then `self.add_pipeline_record(...)`. Add the async counterpart: + +```python +async def async_process_packet( + self, + tag: TagProtocol, + packet: PacketProtocol, + skip_cache_lookup: bool = False, + skip_cache_insert: bool = False, +) -> tuple[TagProtocol, PacketProtocol | None]: + """Async counterpart of ``process_packet``. + + Uses the CachedPacketFunction's async_call for computation + result caching. + Pipeline record storage is synchronous (DB protocol is sync). + """ + output_packet = await self._packet_function.async_call( + packet, + skip_cache_lookup=skip_cache_lookup, + skip_cache_insert=skip_cache_insert, + ) + + if output_packet is not None: + result_computed = bool( + output_packet.get_meta_value( + self._packet_function.RESULT_COMPUTED_FLAG, False + ) + ) + self.add_pipeline_record( + tag, + packet, + packet_record_id=output_packet.datagram_id, + computed=result_computed, + ) + + return tag, output_packet +``` + +### Step 11: Add `PersistentFunctionNode.async_execute` (two-phase) + +**File:** `src/orcapod/core/function_pod.py` + +Overrides `FunctionNode.async_execute`: + +```python +async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], +) -> None: + """Two-phase async execution: replay cached, then compute missing.""" + try: + # Phase 1: emit existing results from DB + existing = self.get_all_records(columns={"meta": True}) + computed_hashes: set[str] = set() + if existing is not None and existing.num_rows > 0: + tag_keys = self._input_stream.keys()[0] + hash_col = constants.INPUT_PACKET_HASH_COL + computed_hashes = set( + cast(list[str], existing.column(hash_col).to_pylist()) + ) + data_table = existing.drop([hash_col]) + existing_stream = ArrowTableStream(data_table, tag_columns=tag_keys) + for tag, packet in existing_stream.iter_packets(): + await output.send((tag, packet)) + + # Phase 2: process packets not already in the DB + async for tag, packet in inputs[0]: + input_hash = packet.content_hash().to_string() + if input_hash in computed_hashes: + continue + tag, output_packet = await self.async_process_packet(tag, packet) + if output_packet is not None: + await output.send((tag, output_packet)) + finally: + await output.close() +``` + +### Step 12: Add `OperatorNode.async_execute` + +**File:** `src/orcapod/core/operator_node.py` + +Direct pass-through: + +```python +async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], +) -> None: + """Delegate to operator's async_execute.""" + await self._operator.async_execute(inputs, output) +``` + +### Step 13: Extract `_store_output_stream` from `PersistentOperatorNode._compute_and_store` + +**File:** `src/orcapod/core/operator_node.py` + +```python +def _store_output_stream(self, stream: StreamProtocol) -> None: + """Materialize stream and store in the pipeline database with per-row dedup.""" + output_table = stream.as_table( + columns={"source": True, "system_tags": True}, + ) + + arrow_hasher = self.data_context.arrow_hasher + record_hashes = [] + for batch in output_table.to_batches(): + for i in range(len(batch)): + record_hashes.append( + arrow_hasher.hash_table(batch.slice(i, 1)).to_hex() + ) + + output_table = output_table.add_column( + 0, + self.HASH_COLUMN_NAME, + pa.array(record_hashes, type=pa.large_string()), + ) + + self._pipeline_database.add_records( + self.pipeline_path, + output_table, + record_id_column=self.HASH_COLUMN_NAME, + skip_duplicates=True, + ) + + self._cached_output_table = output_table.drop(self.HASH_COLUMN_NAME) +``` + +Refactor `_compute_and_store`: + +```python +def _compute_and_store(self) -> None: + self._cached_output_stream = self._operator.process(*self._input_streams) + if self._cache_mode == CacheMode.OFF: + self._update_modified_time() + return + self._store_output_stream(self._cached_output_stream) + self._update_modified_time() +``` + +### Step 14: Add `PersistentOperatorNode.async_execute` + +**File:** `src/orcapod/core/operator_node.py` + +```python +async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], +) -> None: + """Async execution with cache mode handling. + + REPLAY: emit from DB, close output. + OFF: delegate to operator, forward results. + LOG: delegate to operator, forward + collect results, then store in DB. + """ + try: + if self._cache_mode == CacheMode.REPLAY: + self._replay_from_cache() + assert self._cached_output_stream is not None + for tag, packet in self._cached_output_stream.iter_packets(): + await output.send((tag, packet)) + return # finally block closes output + + # OFF or LOG: delegate to operator, forward results downstream + intermediate = Channel[tuple[TagProtocol, PacketProtocol]]() + collected: list[tuple[TagProtocol, PacketProtocol]] = [] + + async def forward() -> None: + async for item in intermediate.reader: + collected.append(item) + await output.send(item) + + async with asyncio.TaskGroup() as tg: + tg.create_task( + self._operator.async_execute(inputs, intermediate.writer) + ) + tg.create_task(forward()) + + # TaskGroup has completed — all results are in `collected` + # Store if LOG mode (sync DB write, post-hoc) + if self._cache_mode == CacheMode.LOG and collected: + stream = StaticOutputPod._materialize_to_stream(collected) + self._cached_output_stream = stream + self._store_output_stream(stream) + + self._update_modified_time() + finally: + await output.close() +``` + +### Step 15: Add imports + +**`src/orcapod/core/operator_node.py`** — add: +```python +import asyncio +from collections.abc import Sequence + +from orcapod.channels import Channel, ReadableChannel, WritableChannel +from orcapod.core.static_output_pod import StaticOutputPod +``` + +**`src/orcapod/core/function_pod.py`** — already has all needed imports. + +### Step 16: Update regression test for `_execute_concurrent` removal + +**File:** `tests/test_core/test_regression_fixes.py` + +`TestExecuteConcurrentInRunningLoop` imports and tests `_execute_concurrent` directly. +Since we're removing that function, this test class needs to be rewritten to test the +behavior through the actual classes: + +- Test that `FunctionPodStream._iter_packets_concurrent` falls back to sequential + `process_packet` when called inside a running event loop. +- Test that `FunctionNode._iter_packets_concurrent` does the same. + +The tested behavior (event-loop fallback) is preserved — it's just now method-internal +rather than in a standalone helper. + +### Step 17: Tests for new functionality + +**File:** `tests/test_channels/test_node_async_execute.py` (new) + +``` +TestProtocolConformance + - test_function_node_satisfies_async_executable_protocol + - test_persistent_function_node_satisfies_async_executable_protocol + - test_operator_node_satisfies_async_executable_protocol + - test_persistent_operator_node_satisfies_async_executable_protocol + +TestCachedPacketFunctionAsync + - test_async_call_cache_miss_computes_and_records + - test_async_call_cache_hit_returns_cached + - test_async_call_skip_cache_lookup + - test_async_call_skip_cache_insert + +TestProcessPacketRouting + - test_function_pod_stream_sequential_uses_process_packet + - test_function_pod_stream_concurrent_uses_async_process_packet + - test_function_node_sequential_uses_process_packet + - test_function_node_concurrent_uses_async_process_packet + - test_persistent_function_node_concurrent_uses_overridden_async_process_packet + - test_concurrent_fallback_in_event_loop_uses_sync_process_packet + +TestFunctionNodeAsyncExecute + - test_basic_streaming_matches_sync + - test_empty_input_closes_cleanly + - test_none_packets_filtered_out + +TestPersistentFunctionNodeAsyncExecute + - test_no_cache_processes_all_inputs + - test_phase1_emits_cached_results + - test_phase2_skips_cached_computes_new + - test_pipeline_records_created_for_new_packets + - test_result_cache_populated_for_new_packets + +TestOperatorNodeAsyncExecute + - test_unary_op_delegation (SelectPacketColumns) + - test_binary_op_delegation (SemiJoin) + - test_nary_op_delegation (Join) + - test_results_match_sync_run + +TestPersistentOperatorNodeAsyncExecute + - test_off_mode_computes_no_db_write + - test_log_mode_computes_and_stores + - test_log_mode_results_match_sync + - test_replay_mode_emits_from_db + - test_replay_empty_db_returns_empty + +TestEndToEnd + - test_source_to_persistent_function_node_pipeline + - test_source_to_persistent_operator_node_pipeline +``` + +### Step 18: Run full test suite + +```bash +uv run pytest tests/ -x +``` + +--- + +## Summary of all changes + +### Call chains after changes + +**Sync sequential path:** +``` +FunctionPodStream._iter_packets_sequential + → self._function_pod.process_packet(tag, pkt) # already correct + → packet_function.call(pkt) + +FunctionNode._iter_packets_sequential + → self.process_packet(tag, pkt) # CHANGED: was _packet_function.call(pkt) + → self._function_pod.process_packet(tag, pkt) + → packet_function.call(pkt) + +PersistentFunctionNode._iter_packets_sequential (inherited from FunctionNode) + → self.process_packet(tag, pkt) # polymorphism kicks in + → CachedPacketFunction.call(pkt, skip_cache_*=...) # cache check + compute + record + → self.add_pipeline_record(...) # pipeline DB +``` + +**Sync concurrent path:** +``` +FunctionPodStream._iter_packets_concurrent + → asyncio.run(gather( + self._function_pod.async_process_packet(tag, pkt) ... # CHANGED: was _execute_concurrent + )) + OR (if event loop running): + self._function_pod.process_packet(tag, pkt) ... # fallback + +FunctionNode._iter_packets_concurrent + → asyncio.run(gather( + self.async_process_packet(tag, pkt) ... # CHANGED: was _execute_concurrent + )) + OR (if event loop running): + self.process_packet(tag, pkt) ... # fallback + +PersistentFunctionNode._iter_packets_concurrent (inherited from FunctionNode) + → asyncio.run(gather( + self.async_process_packet(tag, pkt) ... # polymorphism kicks in + → await CachedPacketFunction.async_call(pkt) # cache + compute + → self.add_pipeline_record(...) # pipeline DB + )) +``` + +**Async execution path:** +``` +FunctionPod.async_execute + → await self.async_process_packet(tag, pkt) # CHANGED: was packet_function.async_call + → await packet_function.async_call(pkt) + +FunctionNode.async_execute # NEW + → await self.async_process_packet(tag, pkt) + → await self._function_pod.async_process_packet(tag, pkt) + → await packet_function.async_call(pkt) + +PersistentFunctionNode.async_execute # NEW (two-phase) + Phase 1: emit from DB + Phase 2: + → await self.async_process_packet(tag, pkt) # polymorphic override + → await CachedPacketFunction.async_call(pkt) # cache + compute + → self.add_pipeline_record(...) # pipeline DB (sync) + +OperatorNode.async_execute # NEW + → await operator.async_execute(inputs, output) + +PersistentOperatorNode.async_execute # NEW + REPLAY: emit from DB + OFF/LOG: + TaskGroup: + operator.async_execute(inputs, intermediate.writer) + forward(intermediate.reader → output + collect) + if LOG: _store_output_stream(materialize(collected)) # sync DB write +``` + +### Files modified + +| File | Changes | +|------|---------| +| `src/orcapod/core/packet_function.py` | Add `CachedPacketFunction.async_call` override with cache logic | +| `src/orcapod/core/function_pod.py` | (1) Add `_FunctionPodBase.async_process_packet` | +| | (2) Fix `FunctionPod.async_execute` to use `async_process_packet` | +| | (3) Rewrite `FunctionPodStream._iter_packets_concurrent` — route through `_function_pod.async_process_packet` / `process_packet`, drop `packet_function` param | +| | (4) Update `FunctionPodStream.iter_packets` — remove `pf` arg to `_iter_packets_concurrent` | +| | (5) Fix `FunctionNode._iter_packets_sequential` to use `self.process_packet` | +| | (6) Rewrite `FunctionNode._iter_packets_concurrent` — route through `self.async_process_packet` / `self.process_packet` | +| | (7) Add `FunctionNode.process_packet` + `async_process_packet` (delegate to pod) | +| | (8) Add `FunctionNode.async_execute` | +| | (9) Add `PersistentFunctionNode.async_process_packet` (cache + pipeline records) | +| | (10) Add `PersistentFunctionNode.async_execute` (two-phase) | +| | (11) Remove `_execute_concurrent` module-level helper | +| `src/orcapod/core/operator_node.py` | (1) Add imports | +| | (2) Add `OperatorNode.async_execute` (pass-through) | +| | (3) Extract `PersistentOperatorNode._store_output_stream` | +| | (4) Refactor `PersistentOperatorNode._compute_and_store` | +| | (5) Add `PersistentOperatorNode.async_execute` (TaskGroup + post-hoc storage) | +| `tests/test_core/test_regression_fixes.py` | Rewrite `TestExecuteConcurrentInRunningLoop` — test through classes instead of removed helper | +| `tests/test_channels/test_node_async_execute.py` | New test file | diff --git a/pyproject.toml b/pyproject.toml index eb38abae..ad9c04ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,13 @@ dependencies = [ "deltalake>=1.0.2", "graphviz>=0.21", "gitpython>=3.1.45", + "starfix>=0.1.3", + "pygraphviz>=1.14", + "tzdata>=2024.1", + "uuid-utils>=0.11.1", + "s3fs>=2025.12.0", + "pymongo>=4.15.5", + "basedpyright>=1.38.1", ] readme = "README.md" requires-python = ">=3.11.0" @@ -48,21 +55,25 @@ version_file = "src/orcapod/_version.py" [dependency-groups] dev = [ "httpie>=3.2.4", + "hypothesis>=6.0", "hydra-core>=1.3.2", "imageio>=2.37.0", "ipykernel>=6.29.5", "ipywidgets>=8.1.7", "jsonschema>=4.25.0", "minio>=7.2.16", + "pre-commit>=4.4.0", + "pre-commit-hooks>=6.0.0", "pyarrow-stubs>=20.0.0.20250716", - "pygraphviz>=1.14", "pyiceberg>=0.9.1", "pyright>=1.1.404", "pytest>=8.3.5", + "pytest-asyncio>=1.3.0", "pytest-cov>=6.1.1", "ray[default]==2.48.0", "redis>=6.2.0", - "ruff>=0.11.11", - "sphinx>=8.2.3", + "ruff>=0.14.4", "tqdm>=4.67.1", + "mkdocs-material>=9.7.5", + "mkdocstrings[python]>=1.0.3", ] diff --git a/sample.py b/sample.py new file mode 100644 index 00000000..33ced88e --- /dev/null +++ b/sample.py @@ -0,0 +1,544 @@ +""" +Arrow Schema BSON Serialization + +Implements the Arrow Schema Canonical Serialization Specification v2.0.0 +for deterministic, cross-language schema hashing. + +Requirements: + pip install pyarrow pymongo + +Usage: + import pyarrow as pa + from arrow_schema_bson import serialize_schema, deserialize_schema + + schema = pa.schema([ + pa.field("id", pa.int64(), nullable=False), + pa.field("name", pa.utf8(), nullable=True), + ]) + + bson_bytes = serialize_schema(schema) + reconstructed = deserialize_schema(bson_bytes) +""" + +from collections import OrderedDict +from typing import Any + +import bson +import pyarrow as pa + + +def sort_keys_recursive(obj: Any) -> Any: + """Recursively sort all dictionary keys alphabetically.""" + if isinstance(obj, dict): + return OrderedDict((k, sort_keys_recursive(v)) for k, v in sorted(obj.items())) + elif isinstance(obj, list): + return [sort_keys_recursive(x) for x in obj] + return obj + + +def serialize_type(arrow_type: pa.DataType) -> dict[str, Any]: + """Convert an Arrow DataType to a canonical type descriptor.""" + + # Null + if pa.types.is_null(arrow_type): + return {"name": "null"} + + # Boolean + if pa.types.is_boolean(arrow_type): + return {"name": "bool"} + + # Integers + if pa.types.is_integer(arrow_type): + return { + "bitWidth": arrow_type.bit_width, + "isSigned": pa.types.is_signed_integer(arrow_type), + "name": "int", + } + + # Floating point + if pa.types.is_floating(arrow_type): + precision_map = {16: "HALF", 32: "SINGLE", 64: "DOUBLE"} + return { + "name": "floatingpoint", + "precision": precision_map[arrow_type.bit_width], + } + + # Decimal + if pa.types.is_decimal(arrow_type): + return { + "bitWidth": arrow_type.bit_width, + "name": "decimal", + "precision": arrow_type.precision, + "scale": arrow_type.scale, + } + + # Date + if pa.types.is_date(arrow_type): + if pa.types.is_date32(arrow_type): + return {"name": "date", "unit": "DAY"} + else: # date64 + return {"name": "date", "unit": "MILLISECOND"} + + # Time + if pa.types.is_time(arrow_type): + unit_map = { + "s": "SECOND", + "ms": "MILLISECOND", + "us": "MICROSECOND", + "ns": "NANOSECOND", + } + return { + "bitWidth": arrow_type.bit_width, + "name": "time", + "unit": unit_map[str(arrow_type.unit)], + } + + # Timestamp + if pa.types.is_timestamp(arrow_type): + unit_map = { + "s": "SECOND", + "ms": "MILLISECOND", + "us": "MICROSECOND", + "ns": "NANOSECOND", + } + return { + "name": "timestamp", + "timezone": arrow_type.tz, # None if no timezone + "unit": unit_map[str(arrow_type.unit)], + } + + # Duration + if pa.types.is_duration(arrow_type): + unit_map = { + "s": "SECOND", + "ms": "MILLISECOND", + "us": "MICROSECOND", + "ns": "NANOSECOND", + } + return { + "name": "duration", + "unit": unit_map[str(arrow_type.unit)], + } + + # Interval + if pa.types.is_interval(arrow_type): + if arrow_type == pa.month_day_nano_interval(): + unit = "MONTH_DAY_NANO" + elif arrow_type == pa.day_time_interval(): + unit = "DAY_TIME" + else: + unit = "YEAR_MONTH" + return {"name": "interval", "unit": unit} + + # Binary types + if pa.types.is_fixed_size_binary(arrow_type): + return { + "byteWidth": arrow_type.byte_width, + "name": "fixedsizebinary", + } + + if pa.types.is_large_binary(arrow_type): + return {"name": "largebinary"} + + if pa.types.is_binary(arrow_type): + return {"name": "binary"} + + # String types - check by comparing to type instances + if arrow_type == pa.utf8() or arrow_type == pa.string(): + return {"name": "utf8"} + + if arrow_type == pa.large_utf8() or arrow_type == pa.large_string(): + return {"name": "largeutf8"} + + # List types + if pa.types.is_list(arrow_type): + return { + "children": [serialize_field(arrow_type.value_field)], + "name": "list", + } + + if pa.types.is_large_list(arrow_type): + return { + "children": [serialize_field(arrow_type.value_field)], + "name": "largelist", + } + + if pa.types.is_fixed_size_list(arrow_type): + return { + "children": [serialize_field(arrow_type.value_field)], + "listSize": arrow_type.list_size, + "name": "fixedsizelist", + } + + # Struct + if pa.types.is_struct(arrow_type): + children = {} + for i in range(arrow_type.num_fields): + field = arrow_type.field(i) + children[field.name] = serialize_field(field) + return { + "children": children, + "name": "struct", + } + + # Map + if pa.types.is_map(arrow_type): + return { + "children": [ + serialize_field(arrow_type.key_field), + serialize_field(arrow_type.item_field), + ], + "keysSorted": arrow_type.keys_sorted, + "name": "map", + } + + # Union + if pa.types.is_union(arrow_type): + mode = "SPARSE" if arrow_type.mode == "sparse" else "DENSE" + children = [] + for i in range(arrow_type.num_fields): + children.append(serialize_field(arrow_type.field(i))) + return { + "children": children, + "mode": mode, + "name": "union", + "typeIds": list(arrow_type.type_codes), + } + + # Dictionary + if pa.types.is_dictionary(arrow_type): + return { + "indexType": serialize_type(arrow_type.index_type), + "name": "dictionary", + "valueType": serialize_type(arrow_type.value_type), + } + + raise ValueError(f"Unsupported Arrow type: {arrow_type}") + + +def serialize_field(field: pa.Field) -> dict: + """Convert an Arrow Field to a canonical field descriptor.""" + return { + "nullable": field.nullable, + "type": serialize_type(field.type), + } + + +def serialize_schema(schema: pa.Schema) -> bytes: + """ + Serialize an Arrow Schema to canonical BSON bytes. + + The output is deterministic: identical schemas always produce + identical byte sequences, regardless of field definition order. + """ + doc = {} + for i in range(len(schema)): + field = schema.field(i) + doc[field.name] = serialize_field(field) + + sorted_doc = sort_keys_recursive(doc) + return bson.encode(sorted_doc) + + +def serialize_schema_to_hex(schema: pa.Schema) -> str: + """Serialize schema and return hex string for debugging.""" + return serialize_schema(schema).hex() + + +# ----------------------------------------------------------------------------- +# Deserialization +# ----------------------------------------------------------------------------- + + +def deserialize_type(type_desc: dict) -> pa.DataType: + """Convert a type descriptor back to an Arrow DataType.""" + name = type_desc["name"] + + if name == "null": + return pa.null() + + if name == "bool": + return pa.bool_() + + if name == "int": + bit_width = type_desc["bitWidth"] + signed = type_desc["isSigned"] + type_map = { + (8, True): pa.int8(), + (8, False): pa.uint8(), + (16, True): pa.int16(), + (16, False): pa.uint16(), + (32, True): pa.int32(), + (32, False): pa.uint32(), + (64, True): pa.int64(), + (64, False): pa.uint64(), + } + return type_map[(bit_width, signed)] + + if name == "floatingpoint": + precision_map = { + "HALF": pa.float16(), + "SINGLE": pa.float32(), + "DOUBLE": pa.float64(), + } + return precision_map[type_desc["precision"]] + + if name == "decimal": + bit_width = type_desc["bitWidth"] + precision = type_desc["precision"] + scale = type_desc["scale"] + if bit_width == 128: + return pa.decimal128(precision, scale) + elif bit_width == 256: + return pa.decimal256(precision, scale) + else: + raise ValueError(f"Unsupported decimal bit width: {bit_width}") + + if name == "date": + if type_desc["unit"] == "DAY": + return pa.date32() + else: + return pa.date64() + + if name == "time": + unit_map = { + "SECOND": "s", + "MILLISECOND": "ms", + "MICROSECOND": "us", + "NANOSECOND": "ns", + } + unit = unit_map[type_desc["unit"]] + bit_width = type_desc["bitWidth"] + if bit_width == 32: + return pa.time32(unit) + else: + return pa.time64(unit) + + if name == "timestamp": + unit_map = { + "SECOND": "s", + "MILLISECOND": "ms", + "MICROSECOND": "us", + "NANOSECOND": "ns", + } + unit = unit_map[type_desc["unit"]] + tz = type_desc.get("timezone") + return pa.timestamp(unit, tz=tz) + + if name == "duration": + unit_map = { + "SECOND": "s", + "MILLISECOND": "ms", + "MICROSECOND": "us", + "NANOSECOND": "ns", + } + return pa.duration(unit_map[type_desc["unit"]]) + + if name == "interval": + unit = type_desc["unit"] + if unit == "YEAR_MONTH": + return pa.month_day_nano_interval() # PyArrow limitation + elif unit == "DAY_TIME": + return pa.day_time_interval() + else: + return pa.month_day_nano_interval() + + if name == "binary": + return pa.binary() + + if name == "largebinary": + return pa.large_binary() + + if name == "fixedsizebinary": + return pa.binary(type_desc["byteWidth"]) + + if name == "utf8": + return pa.utf8() + + if name == "largeutf8": + return pa.large_utf8() + + if name == "list": + child_field = deserialize_field("item", type_desc["children"][0]) + return pa.list_(child_field) + + if name == "largelist": + child_field = deserialize_field("item", type_desc["children"][0]) + return pa.large_list(child_field) + + if name == "fixedsizelist": + child_field = deserialize_field("item", type_desc["children"][0]) + return pa.list_(child_field, type_desc["listSize"]) + + if name == "struct": + fields = [] + children = type_desc["children"] + for field_name in sorted(children.keys()): + fields.append(deserialize_field(field_name, children[field_name])) + return pa.struct(fields) + + if name == "map": + key_field = deserialize_field("key", type_desc["children"][0]) + value_field = deserialize_field("value", type_desc["children"][1]) + return pa.map_( + key_field.type, value_field.type, keys_sorted=type_desc["keysSorted"] + ) + + if name == "union": + fields = [] + for i, child in enumerate(type_desc["children"]): + fields.append(deserialize_field(f"field_{i}", child)) + type_ids = type_desc["typeIds"] + mode = type_desc["mode"].lower() + return pa.union(fields, mode=mode, type_codes=type_ids) + + if name == "dictionary": + index_type = deserialize_type(type_desc["indexType"]) + value_type = deserialize_type(type_desc["valueType"]) + return pa.dictionary(index_type, value_type) + + raise ValueError(f"Unknown type name: {name}") + + +def deserialize_field(name: str, field_desc: dict) -> pa.Field: + """Convert a field descriptor back to an Arrow Field.""" + return pa.field( + name, + deserialize_type(field_desc["type"]), + nullable=field_desc["nullable"], + ) + + +def deserialize_schema(bson_bytes: bytes) -> pa.Schema: + """ + Deserialize BSON bytes back to an Arrow Schema. + + Fields are reconstructed in alphabetical order by name. + """ + doc = bson.decode(bson_bytes) + fields = [] + for field_name in sorted(doc.keys()): + fields.append(deserialize_field(field_name, doc[field_name])) + return pa.schema(fields) + + +# ----------------------------------------------------------------------------- +# Testing / Verification +# ----------------------------------------------------------------------------- + + +def verify_roundtrip(schema: pa.Schema) -> bool: + """Verify that a schema survives serialization roundtrip.""" + bson_bytes = serialize_schema(schema) + reconstructed = deserialize_schema(bson_bytes) + return schema.equals(reconstructed) + + +def print_debug(schema: pa.Schema) -> None: + """Print debug information about schema serialization.""" + import json + + print("Original Schema:") + print(schema) + print() + + # Build the document (before BSON encoding) + doc = {} + for i in range(len(schema)): + field = schema.field(i) + doc[field.name] = serialize_field(field) + sorted_doc = sort_keys_recursive(doc) + + print("Canonical JSON representation:") + print(json.dumps(sorted_doc, indent=2)) + print() + + bson_bytes = bson.encode(sorted_doc) + print(f"BSON bytes ({len(bson_bytes)} bytes):") + print(bson_bytes.hex()) + print() + + # Verify roundtrip + reconstructed = deserialize_schema(bson_bytes) + print("Reconstructed Schema:") + print(reconstructed) + print() + print("Roundtrip successful:", schema.equals(reconstructed)) + + +# ----------------------------------------------------------------------------- +# Example usage +# ----------------------------------------------------------------------------- + +if __name__ == "__main__": + # Example 1: Simple schema + schema1 = pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("name", pa.utf8(), nullable=True), + pa.field("score", pa.float64(), nullable=False), + ] + ) + print("=" * 60) + print("Example 1: Simple schema") + print("=" * 60) + print_debug(schema1) + + # Example 2: Schema with nested struct + schema2 = pa.schema( + [ + pa.field("user_id", pa.int64(), nullable=False), + pa.field( + "profile", + pa.struct( + [ + pa.field("email", pa.utf8(), nullable=False), + pa.field("age", pa.int32(), nullable=True), + ] + ), + nullable=True, + ), + ] + ) + print("\n" + "=" * 60) + print("Example 2: Nested struct") + print("=" * 60) + print_debug(schema2) + + # Example 3: Schema with list and timestamp + schema3 = pa.schema( + [ + pa.field("event_time", pa.timestamp("us", tz="UTC"), nullable=False), + pa.field( + "tags", + pa.list_(pa.field("item", pa.utf8(), nullable=True)), + nullable=True, + ), + ] + ) + print("\n" + "=" * 60) + print("Example 3: List and timestamp") + print("=" * 60) + print_debug(schema3) + + # Example 4: Demonstrate field order independence + schema4a = pa.schema( + [ + pa.field("b", pa.int32()), + pa.field("a", pa.int32()), + ] + ) + schema4b = pa.schema( + [ + pa.field("a", pa.int32()), + pa.field("b", pa.int32()), + ] + ) + print("\n" + "=" * 60) + print("Example 4: Field order independence") + print("=" * 60) + bytes_a = serialize_schema(schema4a) + bytes_b = serialize_schema(schema4b) + print(f"Schema [b, a] -> {bytes_a.hex()}") + print(f"Schema [a, b] -> {bytes_b.hex()}") + print(f"Identical bytes: {bytes_a == bytes_b}") diff --git a/src/orcapod/__init__.py b/src/orcapod/__init__.py index 226850e3..8afbf1ef 100644 --- a/src/orcapod/__init__.py +++ b/src/orcapod/__init__.py @@ -1,29 +1,27 @@ -from .config import DEFAULT_CONFIG, Config -from .core import DEFAULT_TRACKER_MANAGER -from .core.pods import function_pod, FunctionPod, CachedPod -from .core import streams -from .core import operators -from .core import sources -from .core.sources import DataFrameSource -from . import databases +from .core.function_pod import ( + FunctionPod, + function_pod, +) +from .core.sources.arrow_table_source import ArrowTableSource from .pipeline import Pipeline - - -no_tracking = DEFAULT_TRACKER_MANAGER.no_tracking +# Subpackage re-exports for clean public API +from . import databases # noqa: F401 +from . import nodes # noqa: F401 +from . import operators # noqa: F401 +from . import sources # noqa: F401 +from . import streams # noqa: F401 +from . import types # noqa: F401 __all__ = [ - "DEFAULT_CONFIG", - "Config", - "DEFAULT_TRACKER_MANAGER", - "no_tracking", - "function_pod", + "ArrowTableSource", "FunctionPod", - "CachedPod", - "streams", + "function_pod", + "Pipeline", "databases", - "sources", - "DataFrameSource", + "nodes", "operators", - "Pipeline", + "sources", + "streams", + "types", ] diff --git a/src/orcapod/channels.py b/src/orcapod/channels.py new file mode 100644 index 00000000..e9fa309f --- /dev/null +++ b/src/orcapod/channels.py @@ -0,0 +1,249 @@ +"""Async channel primitives for push-based pipeline execution. + +Provides bounded async channels with close/done signaling, backpressure, +and fan-out (broadcast) support. +""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterator +from dataclasses import dataclass, field +from typing import Generic, Protocol, TypeVar, runtime_checkable + +T = TypeVar("T") +T_co = TypeVar("T_co", covariant=True) +T_contra = TypeVar("T_contra", contravariant=True) + + +# --------------------------------------------------------------------------- +# Sentinel & exception +# --------------------------------------------------------------------------- + + +class _Sentinel: + """Internal marker signaling channel closure.""" + + __slots__ = () + + def __repr__(self) -> str: + return "" + + +_CLOSED = _Sentinel() + + +class ChannelClosed(Exception): + """Raised when ``receive()`` is called on a closed, drained channel.""" + + +# --------------------------------------------------------------------------- +# Protocol types +# --------------------------------------------------------------------------- + + +@runtime_checkable +class ReadableChannel(Protocol[T_co]): + """Consumer side of a channel.""" + + async def receive(self) -> T_co: + """Receive next item. Raises ``ChannelClosed`` when done.""" + ... + + def __aiter__(self) -> AsyncIterator[T_co]: ... + + async def __anext__(self) -> T_co: ... + + async def collect(self) -> list[T_co]: + """Drain all remaining items into a list.""" + ... + + +@runtime_checkable +class WritableChannel(Protocol[T_contra]): + """Producer side of a channel.""" + + async def send(self, item: T_contra) -> None: + """Send an item. Blocks if channel buffer is full (backpressure).""" + ... + + async def close(self) -> None: + """Signal that no more items will be sent.""" + ... + + +# --------------------------------------------------------------------------- +# Concrete reader / writer views +# --------------------------------------------------------------------------- + + +class _ChannelReader(Generic[T]): + """Concrete ReadableChannel backed by a Channel.""" + + __slots__ = ("_channel",) + + def __init__(self, channel: Channel[T]) -> None: + self._channel = channel + + async def receive(self) -> T: + item = await self._channel._queue.get() + if isinstance(item, _Sentinel): + # Put sentinel back so other readers (broadcast) also see it + await self._channel._queue.put(item) + raise ChannelClosed() + return item # type: ignore[return-value] + + def __aiter__(self) -> AsyncIterator[T]: + return self + + async def __anext__(self) -> T: + try: + return await self.receive() + except ChannelClosed: + raise StopAsyncIteration + + async def collect(self) -> list[T]: + items: list[T] = [] + async for item in self: + items.append(item) + return items + + +class _ChannelWriter(Generic[T]): + """Concrete WritableChannel backed by a Channel.""" + + __slots__ = ("_channel",) + + def __init__(self, channel: Channel[T]) -> None: + self._channel = channel + + async def send(self, item: T) -> None: + if self._channel._closed.is_set(): + raise ChannelClosed("Cannot send to a closed channel") + await self._channel._queue.put(item) + + async def close(self) -> None: + if not self._channel._closed.is_set(): + self._channel._closed.set() + await self._channel._queue.put(_CLOSED) + + +# --------------------------------------------------------------------------- +# Channel +# --------------------------------------------------------------------------- + + +@dataclass +class Channel(Generic[T]): + """Bounded async channel with close/done signaling. + + Args: + buffer_size: Maximum number of items that can be buffered. + Defaults to 64. + """ + + buffer_size: int = 64 + _queue: asyncio.Queue[T | _Sentinel] = field(init=False) + _closed: asyncio.Event = field(init=False, default_factory=asyncio.Event) + + def __post_init__(self) -> None: + self._queue = asyncio.Queue(maxsize=self.buffer_size) + + @property + def reader(self) -> _ChannelReader[T]: + """Return a readable view of this channel.""" + return _ChannelReader(self) + + @property + def writer(self) -> _ChannelWriter[T]: + """Return a writable view of this channel.""" + return _ChannelWriter(self) + + +# --------------------------------------------------------------------------- +# Broadcast channel (fan-out) +# --------------------------------------------------------------------------- + + +class _BroadcastReader(Generic[T]): + """A reader that receives items broadcast from a shared source. + + Each broadcast reader maintains its own independent queue so that + multiple downstream consumers can read at their own pace. + """ + + __slots__ = ("_queue",) + + def __init__(self, buffer_size: int) -> None: + self._queue: asyncio.Queue[T | _Sentinel] = asyncio.Queue(maxsize=buffer_size) + + async def receive(self) -> T: + item = await self._queue.get() + if isinstance(item, _Sentinel): + # Re-enqueue so repeated receive() calls also raise + await self._queue.put(item) + raise ChannelClosed() + return item # type: ignore[return-value] + + def __aiter__(self) -> AsyncIterator[T]: + return self + + async def __anext__(self) -> T: + try: + return await self.receive() + except ChannelClosed: + raise StopAsyncIteration + + async def collect(self) -> list[T]: + items: list[T] = [] + async for item in self: + items.append(item) + return items + + +class BroadcastChannel(Generic[T]): + """A channel whose output is broadcast to multiple readers. + + Each call to ``add_reader()`` creates an independent reader queue. + Items sent via the writer are copied to every reader's queue. + + Args: + buffer_size: Per-reader buffer size. Defaults to 64. + """ + + def __init__(self, buffer_size: int = 64) -> None: + self._buffer_size = buffer_size + self._readers: list[_BroadcastReader[T]] = [] + self._closed = False + + def add_reader(self) -> _BroadcastReader[T]: + """Create and return a new reader for this broadcast channel.""" + reader = _BroadcastReader[T](self._buffer_size) + self._readers.append(reader) + return reader + + @property + def writer(self) -> _BroadcastWriter[T]: + """Return a writable view of this broadcast channel.""" + return _BroadcastWriter(self) + + +class _BroadcastWriter(Generic[T]): + """Writer that fans out items to all broadcast readers.""" + + __slots__ = ("_broadcast",) + + def __init__(self, broadcast: BroadcastChannel[T]) -> None: + self._broadcast = broadcast + + async def send(self, item: T) -> None: + if self._broadcast._closed: + raise ChannelClosed("Cannot send to a closed channel") + for reader in self._broadcast._readers: + await reader._queue.put(item) + + async def close(self) -> None: + if not self._broadcast._closed: + self._broadcast._closed = True + for reader in self._broadcast._readers: + await reader._queue.put(_CLOSED) diff --git a/src/orcapod/contexts/__init__.py b/src/orcapod/contexts/__init__.py index 116dbbb2..1694df67 100644 --- a/src/orcapod/contexts/__init__.py +++ b/src/orcapod/contexts/__init__.py @@ -1,5 +1,5 @@ """ -OrcaPod Data Context System +Orcapod Data Context System This package manages versioned data contexts that define how data should be interpreted and processed throughout the OrcaPod system. @@ -7,7 +7,7 @@ A DataContext contains: - Semantic type registry for handling structured data types - Arrow hasher for hashing Arrow tables -- Object hasher for general object hashing +- Semantic hasher for general Python object hashing - Versioning information for reproducibility Example usage: @@ -25,10 +25,13 @@ versions = get_available_contexts() """ -from .core import DataContext, ContextValidationError, ContextResolutionError -from .registry import JSONDataContextRegistry from typing import Any -from orcapod.protocols import hashing_protocols as hp, semantic_types_protocols as sp + +from orcapod.protocols import hashing_protocols as hp +from orcapod.protocols import semantic_types_protocols as sp + +from .core import ContextResolutionError, ContextValidationError, DataContext +from .registry import JSONDataContextRegistry # Global registry instance (lazily initialized) _registry: JSONDataContextRegistry | None = None @@ -165,27 +168,27 @@ def get_default_context() -> DataContext: return resolve_context() -def get_default_object_hasher() -> hp.ObjectHasher: +def get_default_semantic_hasher() -> hp.SemanticHasherProtocol: """ - Get the default object hasher. + Get the default semantic hasher. Returns: - ObjectHasher instance for the default context + SemanticHasherProtocol instance for the default context """ - return get_default_context().object_hasher + return get_default_context().semantic_hasher -def get_default_arrow_hasher() -> hp.ArrowHasher: +def get_default_arrow_hasher() -> hp.ArrowHasherProtocol: """ Get the default arrow hasher. Returns: - ArrowHasher instance for the default context + ArrowHasherProtocol instance for the default context """ return get_default_context().arrow_hasher -def get_default_type_converter() -> "sp.TypeConverter": +def get_default_type_converter() -> "sp.TypeConverterProtocol": """ Get the default type converter. diff --git a/src/orcapod/contexts/core.py b/src/orcapod/contexts/core.py index f1b35d33..cd6b1cf5 100644 --- a/src/orcapod/contexts/core.py +++ b/src/orcapod/contexts/core.py @@ -7,7 +7,12 @@ from dataclasses import dataclass -from orcapod.protocols import hashing_protocols as hp, semantic_types_protocols as sp +from orcapod.hashing.semantic_hashing.type_handler_registry import TypeHandlerRegistry +from orcapod.protocols.hashing_protocols import ( + ArrowHasherProtocol, + SemanticHasherProtocol, +) +from orcapod.protocols.semantic_types_protocols import TypeConverterProtocol @dataclass @@ -25,15 +30,17 @@ class DataContext: description: Human-readable description of this context semantic_type_registry: Registry of semantic type converters arrow_hasher: Arrow table hasher for this context - object_hasher: General object hasher for this context + semantic_hasher: General semantic hasher for this context + type_handler_registry: Registry of TypeHandlerProtocol instances for SemanticHasherProtocol """ context_key: str version: str description: str - type_converter: sp.TypeConverter - arrow_hasher: hp.ArrowHasher - object_hasher: hp.ObjectHasher + type_converter: TypeConverterProtocol + arrow_hasher: ArrowHasherProtocol + semantic_hasher: SemanticHasherProtocol # this is the currently the JSON hasher + type_handler_registry: TypeHandlerRegistry class ContextValidationError(Exception): diff --git a/src/orcapod/contexts/data/schemas/context_schema.json b/src/orcapod/contexts/data/schemas/context_schema.json index 0485d51c..1e9f5468 100644 --- a/src/orcapod/contexts/data/schemas/context_schema.json +++ b/src/orcapod/contexts/data/schemas/context_schema.json @@ -11,7 +11,8 @@ "semantic_registry", "type_converter", "arrow_hasher", - "object_hasher" + "semantic_hasher", + "type_handler_registry" ], "properties": { "context_key": { @@ -54,9 +55,21 @@ "$ref": "#/$defs/objectspec", "description": "ObjectSpec for the Arrow hasher component" }, - "object_hasher": { + "semantic_hasher": { "$ref": "#/$defs/objectspec", - "description": "ObjectSpec for the object hasher component" + "description": "ObjectSpec for the semantic hasher component" + }, + "type_handler_registry": { + "$ref": "#/$defs/objectspec", + "description": "ObjectSpec for the TypeHandlerRegistry used by the semantic hasher" + }, + "file_hasher": { + "$ref": "#/$defs/objectspec", + "description": "ObjectSpec for the file content hasher (used by PathContentHandler)" + }, + "function_info_extractor": { + "$ref": "#/$defs/objectspec", + "description": "ObjectSpec for the function info extractor (used by FunctionHandler)" }, "metadata": { "type": "object", @@ -107,18 +120,12 @@ "oneOf": [ { "type": "object", - "required": [ - "_class" - ], + "required": ["_class"], "properties": { "_class": { "type": "string", "pattern": "^[a-zA-Z_][a-zA-Z0-9_.]*\\.[a-zA-Z_][a-zA-Z0-9_]*$", - "description": "Fully qualified class name", - "examples": [ - "orcapod.types.semantic_types.SemanticTypeRegistry", - "orcapod.hashing.arrow_hashers.SemanticArrowHasher" - ] + "description": "Fully qualified class name" }, "_config": { "type": "object", @@ -128,20 +135,31 @@ }, "additionalProperties": false }, + { + "type": "object", + "required": ["_ref"], + "properties": { + "_ref": {"type": "string", "description": "Reference to a named component"} + }, + "additionalProperties": false + }, + { + "type": "object", + "required": ["_type"], + "properties": { + "_type": {"type": "string", "description": "Dotted Python type string, e.g. 'pathlib.Path'"} + }, + "additionalProperties": false + }, { "type": "array", - "description": "Array of object specifications", + "description": "Array or tuple of object specifications", "items": { "$ref": "#/$defs/objectspec" } }, { - "type": [ - "string", - "number", - "boolean", - "null" - ], + "type": ["string", "number", "boolean", "null"], "description": "Primitive values" } ] @@ -184,17 +202,10 @@ } } }, - "object_hasher": { - "_class": "orcapod.hashing.object_hashers.BasicObjectHasher", + "semantic_hasher": { + "_class": "orcapod.hashing.semantic_hashing.semantic_hasher.BaseSemanticHasher", "_config": { - "hasher_id": "object_v0.1", - "function_info_extractor": { - "_class": "orcapod.hashing.function_info_extractors.FunctionSignatureExtractor", - "_config": { - "include_module": true, - "include_defaults": true - } - } + "hasher_id": "semantic_v0.1" } }, "metadata": { diff --git a/src/orcapod/contexts/data/v0.1.json b/src/orcapod/contexts/data/v0.1.json index 9f1708e3..cd16b5d5 100644 --- a/src/orcapod/contexts/data/v0.1.json +++ b/src/orcapod/contexts/data/v0.1.json @@ -2,13 +2,21 @@ "context_key": "std:v0.1:default", "version": "v0.1", "description": "Initial stable release with basic Path semantic type support", + "file_hasher": { + "_class": "orcapod.hashing.file_hashers.BasicFileHasher", + "_config": { + "algorithm": "sha256" + } + }, "semantic_registry": { "_class": "orcapod.semantic_types.semantic_registry.SemanticTypeRegistry", "_config": { "converters": { "path": { "_class": "orcapod.semantic_types.semantic_struct_converters.PathStructConverter", - "_config": {} + "_config": { + "file_hasher": {"_ref": "file_hasher"} + } } } } @@ -33,16 +41,39 @@ } } }, - "object_hasher": { - "_class": "orcapod.hashing.object_hashers.BasicObjectHasher", + "function_info_extractor": { + "_class": "orcapod.hashing.semantic_hashing.function_info_extractors.FunctionSignatureExtractor", "_config": { - "hasher_id": "object_v0.1", - "function_info_extractor": { - "_class": "orcapod.hashing.function_info_extractors.FunctionSignatureExtractor", - "_config": { - "include_module": true, - "include_defaults": true - } + "include_module": true, + "include_defaults": true + } + }, + "type_handler_registry": { + "_class": "orcapod.hashing.semantic_hashing.type_handler_registry.TypeHandlerRegistry", + "_config": { + "handlers": [ + [{"_type": "builtins.bytes"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.BytesHandler", "_config": {}}], + [{"_type": "builtins.bytearray"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.BytesHandler", "_config": {}}], + [{"_type": "pathlib.Path"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.PathContentHandler", "_config": {"file_hasher": {"_ref": "file_hasher"}}}], + [{"_type": "uuid.UUID"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.UUIDHandler", "_config": {}}], + [{"_type": "types.FunctionType"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.FunctionHandler", "_config": {"function_info_extractor": {"_ref": "function_info_extractor"}}}], + [{"_type": "types.BuiltinFunctionType"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.FunctionHandler", "_config": {"function_info_extractor": {"_ref": "function_info_extractor"}}}], + [{"_type": "types.MethodType"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.FunctionHandler", "_config": {"function_info_extractor": {"_ref": "function_info_extractor"}}}], + [{"_type": "builtins.type"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.TypeObjectHandler", "_config": {}}], + [{"_type": "types.GenericAlias"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.GenericAliasHandler", "_config": {}}], + [{"_type": "typing._GenericAlias"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.GenericAliasHandler", "_config": {}}], + [{"_type": "typing._SpecialForm"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.SpecialFormHandler", "_config": {}}], + [{"_type": "pyarrow.Table"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.ArrowTableHandler", "_config": {"arrow_hasher": {"_ref": "arrow_hasher"}}}], + [{"_type": "pyarrow.RecordBatch"}, {"_class": "orcapod.hashing.semantic_hashing.builtin_handlers.ArrowTableHandler", "_config": {"arrow_hasher": {"_ref": "arrow_hasher"}}}] + ] + } + }, + "semantic_hasher": { + "_class": "orcapod.hashing.semantic_hashing.semantic_hasher.BaseSemanticHasher", + "_config": { + "hasher_id": "semantic_v0.1", + "type_handler_registry": { + "_ref": "type_handler_registry" } } }, @@ -55,4 +86,4 @@ "Arrow logical serialization method" ] } -} \ No newline at end of file +} diff --git a/src/orcapod/contexts/registry.py b/src/orcapod/contexts/registry.py index e3f09891..0d09c0c9 100644 --- a/src/orcapod/contexts/registry.py +++ b/src/orcapod/contexts/registry.py @@ -6,13 +6,16 @@ """ import json - - +import logging from pathlib import Path from typing import Any -import logging + +from orcapod.contexts.core import ( + ContextResolutionError, + ContextValidationError, + DataContext, +) from orcapod.utils.object_spec import parse_objectspec -from .core import DataContext, ContextValidationError, ContextResolutionError logger = logging.getLogger(__name__) @@ -142,7 +145,8 @@ def _load_spec_file(self, json_file: Path) -> None: "version", "type_converter", "arrow_hasher", - "object_hasher", + "semantic_hasher", + "type_handler_registry", ] missing_fields = [field for field in required_fields if field not in spec] if missing_fields: @@ -257,35 +261,31 @@ def get_context(self, context_string: str | None = None) -> DataContext: f"Failed to resolve context '{context_string}': {e}" ) + # Top-level keys that are metadata, not instantiable components. + _METADATA_KEYS = frozenset({"context_key", "version", "description", "metadata"}) + def _create_context_from_spec(self, spec: dict[str, Any]) -> DataContext: - """Create DataContext instance from validated specification.""" + """Create DataContext instance from validated specification. + + All top-level keys whose value is a dict with a ``_class`` entry are + built in JSON order and added to a shared ``ref_lut``. This means + new versioned components (e.g. ``file_hasher``, ``function_info_extractor``) + can be added to the JSON without touching this method — they are + instantiated automatically and become available as ``_ref`` targets for + later components in the same file. + """ try: - # Parse each component using ObjectSpec context_key = spec["context_key"] version = spec["version"] description = spec.get("description", "") - ref_lut = {} + ref_lut: dict[str, Any] = {} - logger.debug(f"Creating type converter for {version}") - ref_lut["semantic_registry"] = parse_objectspec( - spec["semantic_registry"], - ref_lut=ref_lut, - ) - - logger.debug(f"Creating type converter for {version}") - ref_lut["type_converter"] = parse_objectspec( - spec["type_converter"], ref_lut=ref_lut - ) - - logger.debug(f"Creating arrow hasher for {version}") - ref_lut["arrow_hasher"] = parse_objectspec( - spec["arrow_hasher"], ref_lut=ref_lut - ) - - logger.debug(f"Creating object hasher for {version}") - ref_lut["object_hasher"] = parse_objectspec( - spec["object_hasher"], ref_lut=ref_lut - ) + for key, value in spec.items(): + if key in self._METADATA_KEYS: + continue + if isinstance(value, dict) and "_class" in value: + logger.debug(f"Creating {key} for context {version}") + ref_lut[key] = parse_objectspec(value, ref_lut=ref_lut) return DataContext( context_key=context_key, @@ -293,7 +293,8 @@ def _create_context_from_spec(self, spec: dict[str, Any]) -> DataContext: description=description, type_converter=ref_lut["type_converter"], arrow_hasher=ref_lut["arrow_hasher"], - object_hasher=ref_lut["object_hasher"], + semantic_hasher=ref_lut["semantic_hasher"], + type_handler_registry=ref_lut["type_handler_registry"], ) except Exception as e: diff --git a/src/orcapod/core/__init__.py b/src/orcapod/core/__init__.py index 24f5aabb..724c67c1 100644 --- a/src/orcapod/core/__init__.py +++ b/src/orcapod/core/__init__.py @@ -1,7 +1,5 @@ -from .trackers import DEFAULT_TRACKER_MANAGER -from .system_constants import constants +from orcapod.core.tracker import DEFAULT_TRACKER_MANAGER __all__ = [ "DEFAULT_TRACKER_MANAGER", - "constants", ] diff --git a/src/orcapod/core/arrow_data_utils.py b/src/orcapod/core/arrow_data_utils.py deleted file mode 100644 index 71942081..00000000 --- a/src/orcapod/core/arrow_data_utils.py +++ /dev/null @@ -1,115 +0,0 @@ -# Collection of functions to work with Arrow table data that underlies streams and/or datagrams -from orcapod.utils.lazy_module import LazyModule -from typing import TYPE_CHECKING -from orcapod.core.system_constants import constants -from collections.abc import Collection - -if TYPE_CHECKING: - import pyarrow as pa -else: - pa = LazyModule("pyarrow") - - -def drop_columns_with_prefix( - table: "pa.Table", - prefix: str | tuple[str, ...], - exclude_columns: Collection[str] = (), -) -> "pa.Table": - """Drop columns with a specific prefix from an Arrow table.""" - columns_to_drop = [ - col - for col in table.column_names - if col.startswith(prefix) and col not in exclude_columns - ] - return table.drop(columns=columns_to_drop) - - -def drop_system_columns( - table: "pa.Table", - system_column_prefix: tuple[str, ...] = ( - constants.META_PREFIX, - constants.DATAGRAM_PREFIX, - ), -) -> "pa.Table": - return drop_columns_with_prefix(table, system_column_prefix) - - -def get_system_columns(table: "pa.Table") -> "pa.Table": - """Get system columns from an Arrow table.""" - return table.select( - [ - col - for col in table.column_names - if col.startswith(constants.SYSTEM_TAG_PREFIX) - ] - ) - - -def add_system_tag_column( - table: "pa.Table", - system_tag_column_name: str, - system_tag_values: str | Collection[str], -) -> "pa.Table": - """Add a system tags column to an Arrow table.""" - if not table.column_names: - raise ValueError("Table is empty") - if isinstance(system_tag_values, str): - system_tag_values = [system_tag_values] * table.num_rows - else: - system_tag_values = list(system_tag_values) - if len(system_tag_values) != table.num_rows: - raise ValueError( - "Length of system_tag_values must match number of rows in the table." - ) - if not system_tag_column_name.startswith(constants.SYSTEM_TAG_PREFIX): - system_tag_column_name = ( - f"{constants.SYSTEM_TAG_PREFIX}{system_tag_column_name}" - ) - tags_column = pa.array(system_tag_values, type=pa.large_string()) - return table.append_column(system_tag_column_name, tags_column) - - -def append_to_system_tags(table: "pa.Table", value: str) -> "pa.Table": - """Append a value to the system tags column in an Arrow table.""" - if not table.column_names: - raise ValueError("Table is empty") - - column_name_map = { - c: f"{c}:{value}" if c.startswith(constants.SYSTEM_TAG_PREFIX) else c - for c in table.column_names - } - return table.rename_columns(column_name_map) - - -def add_source_info( - table: "pa.Table", - source_info: str | Collection[str] | None, - exclude_prefixes: Collection[str] = ( - constants.META_PREFIX, - constants.DATAGRAM_PREFIX, - ), - exclude_columns: Collection[str] = (), -) -> "pa.Table": - """Add source information to an Arrow table.""" - # Create a new column with the source information - if source_info is None or isinstance(source_info, str): - source_column = [source_info] * table.num_rows - elif isinstance(source_info, Collection): - if len(source_info) != table.num_rows: - raise ValueError( - "Length of source_info collection must match number of rows in the table." - ) - source_column = source_info - - # identify columns for which source columns should be created - - for col in table.column_names: - if col.startswith(tuple(exclude_prefixes)) or col in exclude_columns: - continue - source_column = pa.array( - [f"{source_val}::{col}" for source_val in source_column], - type=pa.large_string(), - ) - table = table.append_column(f"{constants.SOURCE_PREFIX}{col}", source_column) - - return table diff --git a/src/orcapod/core/base.py b/src/orcapod/core/base.py index 828c3718..e713e105 100644 --- a/src/orcapod/core/base.py +++ b/src/orcapod/core/base.py @@ -1,38 +1,49 @@ +from __future__ import annotations + import logging -from abc import ABC +from abc import ABC, abstractmethod +from datetime import datetime, timezone from typing import Any -from orcapod import DEFAULT_CONFIG, contexts -from orcapod.config import Config -from orcapod.protocols import hashing_protocols as hp +import orcapod.contexts as contexts +from orcapod.config import DEFAULT_CONFIG, Config +from orcapod.types import ContentHash logger = logging.getLogger(__name__) -class LablableBase: +# Base classes for Orcapod core components, providing common functionality. + + +class LabelableMixin: + """ + Mixin class for objects that can have a label. Provides a mechanism to compute a label based on the object's content. + By default, explicitly set label will always take precedence over computed label and inferred label. + """ + def __init__(self, label: str | None = None, **kwargs): self._label = label super().__init__(**kwargs) @property - def has_assigned_label(self) -> bool: + def label(self) -> str: """ - Check if the label is explicitly set for this object. + Get the label of this object. Returns: - bool: True if the label is explicitly set, False otherwise. + str | None: The label of the object, or None if not set. """ - return self._label is not None + return self._label or self.computed_label() or self.__class__.__name__ @property - def label(self) -> str: + def has_assigned_label(self) -> bool: """ - Get the label of this object. + Check if the label has been explicitly set for this object. Returns: - str | None: The label of the object, or None if not set. + bool: True if the label is explicitly set, False otherwise. """ - return self._label or self.computed_label() or self.__class__.__name__ + return self._label is not None @label.setter def label(self, label: str | None) -> None: @@ -52,18 +63,23 @@ def computed_label(self) -> str | None: return None -class ContextAwareConfigurableBase(ABC): +class DataContextMixin: + """ + Mixin to associate data context and an Orcapod config with an object. Deriving class allows data context and Orcapod config to be + explicitly specified and if not provided, use the default data context and Orcapod config. + """ + def __init__( self, data_context: str | contexts.DataContext | None = None, - orcapod_config: Config | None = None, + config: Config | None = None, **kwargs, ): super().__init__(**kwargs) - if orcapod_config is None: - orcapod_config = DEFAULT_CONFIG - self._orcapod_config = orcapod_config self._data_context = contexts.resolve_context(data_context) + if config is None: + config = DEFAULT_CONFIG # DEFAULT_CONFIG as defined in orcapod/config.py + self._orcapod_config = config @property def orcapod_config(self) -> Config: @@ -73,34 +89,44 @@ def orcapod_config(self) -> Config: def data_context(self) -> contexts.DataContext: return self._data_context + # TODO: re-evaluate whether changing data context should be allowed + @data_context.setter + def data_context(self, context: str | contexts.DataContext | None) -> None: + self._data_context = contexts.resolve_context(context) + @property def data_context_key(self) -> str: """Return the data context key.""" return self._data_context.context_key -class ContentIdentifiableBase(ContextAwareConfigurableBase): +class ContentIdentifiableBase(DataContextMixin, ABC): """ Base class for content-identifiable objects. This class provides a way to define objects that can be uniquely identified based on their content rather than their identity in memory. Specifically, the identity of the object is determined by the structure returned by the `identity_structure` method. The hash of the object is computed based on the `identity_structure` using the provided `ObjectHasher`, - which defaults to the one returned by `get_default_object_hasher`. + which defaults to the one returned by `get_default_semantic_hasher`. Two content-identifiable objects are considered equal if their `identity_structure` returns the same value. """ - def __init__(self, **kwargs) -> None: + def __init__( + self, + data_context: str | contexts.DataContext | None = None, + config: Config | None = None, + ) -> None: """ Initialize the ContentHashable with an optional ObjectHasher. Args: identity_structure_hasher (ObjectHasher | None): An instance of ObjectHasher to use for hashing. """ - super().__init__(**kwargs) - self._cached_content_hash: hp.ContentHash | None = None + super().__init__(data_context=data_context, config=config) + self._content_hash_cache: dict[str, ContentHash] = {} self._cached_int_hash: int | None = None + @abstractmethod def identity_structure(self) -> Any: """ Return a structure that represents the identity of this object. @@ -112,23 +138,38 @@ def identity_structure(self) -> Any: Returns: Any: A structure representing this object's content, or None to use default hash """ - raise NotImplementedError("Subclasses must implement identity_structure") + ... - def content_hash(self) -> hp.ContentHash: + def content_hash(self, hasher=None) -> ContentHash: """ Compute a hash based on the content of this object. + The hasher is used for the entire recursive computation — all nested + ContentIdentifiable objects are resolved using the same hasher, ensuring + one consistent context per hash computation. + + Args: + hasher: Optional semantic hasher to use. When omitted, the hasher + is resolved from this object's data_context and the result is + cached by hasher_id for reuse. When provided explicitly, the + result is also cached by hasher_id, so repeated calls with the + same hasher are free. + Returns: - bytes: A byte representation of the hash based on the content. - If no identity structure is provided, return None. - """ - if self._cached_content_hash is None: - structure = self.identity_structure() - # processed_structure = process_structure(structure) - self._cached_content_hash = self.data_context.object_hasher.hash_object( - structure + ContentHash: Stable, content-based hash of the object. + """ + if hasher is None: + hasher = self.data_context.semantic_hasher + cache_key = hasher.hasher_id + + def content_resolver(obj): + return obj.content_hash(hasher) + + if cache_key not in self._content_hash_cache: + self._content_hash_cache[cache_key] = hasher.hash_object( + self.identity_structure(), resolver=content_resolver ) - return self._cached_content_hash + return self._content_hash_cache[cache_key] def __hash__(self) -> int: """ @@ -157,5 +198,151 @@ def __eq__(self, other: object) -> bool: return self.identity_structure() == other.identity_structure() -class LabeledContentIdentifiableBase(ContentIdentifiableBase, LablableBase): - pass +class PipelineElementBase(DataContextMixin, ABC): + """ + Mixin providing pipeline-level identity for objects that participate in a + pipeline graph. + + This is a parallel identity chain to ContentIdentifiableBase. Content + identity (content_hash) captures the precise, data-inclusive identity of + an object. Pipeline identity (pipeline_hash) captures only what is + structurally meaningful for pipeline database path scoping: schemas and + the recursive topology of upstream computation, with no data content. + + Must be used alongside DataContextMixin (directly or via TraceableBase), + which provides self.data_context used by pipeline_hash(). + + The only class that needs to override pipeline_identity_structure() in a + non-trivial way is RootSource, which returns (tag_schema, packet_schema) + as the base case of the recursion. All other pipeline elements return + structures built from the pipeline_hash() values of their upstream + components — ContentHash objects are terminal in the semantic hasher, so + no special hashing mode is required. + """ + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._pipeline_hash_cache: dict[str, ContentHash] = {} + + @abstractmethod + def pipeline_identity_structure(self) -> Any: + """ + Return a structure representing this element's pipeline identity. + + Implementations may return raw ContentIdentifiable objects (such as + upstream stream or pod references) as leaves — the pipeline resolver + threaded through pipeline_hash() ensures that PipelineElementProtocol + objects are resolved via pipeline_hash() and other ContentIdentifiable + objects via content_hash(), both using the same hasher throughout. + """ + ... + + def pipeline_hash(self, hasher=None) -> ContentHash: + """ + Return the pipeline-level hash of this element, computed from + pipeline_identity_structure() and cached by hasher_id. + + The hasher is used for the entire recursive computation — all nested + objects are resolved using the same hasher, ensuring one consistent + context per hash computation. + + Args: + hasher: Optional semantic hasher to use. When omitted, resolved + from this object's data_context. + """ + if hasher is None: + hasher = self.data_context.semantic_hasher + cache_key = hasher.hasher_id + if cache_key not in self._pipeline_hash_cache: + from orcapod.protocols.hashing_protocols import PipelineElementProtocol + + def pipeline_resolver(obj: Any) -> ContentHash: + if isinstance(obj, PipelineElementProtocol): + return obj.pipeline_hash(hasher) + return obj.content_hash(hasher) + + self._pipeline_hash_cache[cache_key] = hasher.hash_object( + self.pipeline_identity_structure(), resolver=pipeline_resolver + ) + return self._pipeline_hash_cache[cache_key] + + +class TemporalMixin: + """ + Mixin class that adds temporal functionality to an Orcapod entity. + It provides methods to track and manage the last modified timestamp of the entity. + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._update_modified_time() + + @property + def last_modified(self) -> datetime | None: + """ + When this object's content was last modified. + + Returns: + datetime: Content last modified timestamp (timezone-aware) + None: Modification time unknown (assume always changed) + """ + return self._modified_time + + def _set_modified_time(self, modified_time: datetime | None) -> None: + """ + Set the modified time for this object. + + Args: + modified_time (datetime | None): The modified time to set. If None, clears the modified time. + """ + self._modified_time = modified_time + + def _update_modified_time(self) -> None: + """ + Update the modified time to the current time. + """ + self._modified_time = datetime.now(timezone.utc) + + def updated_since(self, timestamp: datetime) -> bool: + """ + Check if the object has been updated since the given timestamp. + + Args: + timestamp (datetime): The timestamp to compare against. + + Returns: + bool: True if the object has been updated since the given timestamp, False otherwise. + """ + # if _modified_time is None, consider it always updated + if self._modified_time is None: + return True + return self._modified_time > timestamp + + +class TraceableBase( + TemporalMixin, LabelableMixin, ContentIdentifiableBase, PipelineElementBase +): + """ + Base class for all default traceable entities, providing common functionality + including data context awareness, content-based identity, (semantic) labeling, + modification timestamp, and pipeline identity. + + Every computation-node class (streams, packet functions, pods) inherits from + TraceableBase, getting both content identity (content_hash) and pipeline + identity (pipeline_hash) automatically. + """ + + def __init__( + self, + label: str | None = None, + data_context: str | contexts.DataContext | None = None, + config: Config | None = None, + ): + # Init provided here for explicit listing of parmeters + super().__init__(label=label, data_context=data_context, config=config) + + def __repr__(self): + return self.__class__.__name__ + + def __str__(self): + return self.label diff --git a/src/orcapod/core/cached_function_pod.py b/src/orcapod/core/cached_function_pod.py new file mode 100644 index 00000000..771009dc --- /dev/null +++ b/src/orcapod/core/cached_function_pod.py @@ -0,0 +1,209 @@ +"""CachedFunctionPod — pod-level caching wrapper that intercepts process_packet().""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from orcapod.core.function_pod import WrappedFunctionPod +from orcapod.core.result_cache import ResultCache +from orcapod.protocols.core_protocols import ( + FunctionPodProtocol, + PacketProtocol, + StreamProtocol, + TagProtocol, +) +from orcapod.protocols.database_protocols import ArrowDatabaseProtocol + +if TYPE_CHECKING: + import pyarrow as pa + + from orcapod.pipeline.logging_capture import CapturedLogs + +logger = logging.getLogger(__name__) + + +class CachedFunctionPod(WrappedFunctionPod): + """Pod-level caching wrapper that intercepts ``process_packet()``. + + Caches at the ``process_packet(tag, packet)`` level using only the + **input packet content hash** as the cache key — the output of a + packet function depends solely on the packet, not the tag. + + Tag-level provenance tracking (tag + system tags + packet hash) is + handled separately by ``FunctionNode.add_pipeline_record``. + + Uses a shared ``ResultCache`` for lookup/store/conflict-resolution + logic (same mechanism as ``CachedPacketFunction``). + """ + + # Expose RESULT_COMPUTED_FLAG from the shared ResultCache + RESULT_COMPUTED_FLAG = ResultCache.RESULT_COMPUTED_FLAG + + def __init__( + self, + function_pod: FunctionPodProtocol, + result_database: ArrowDatabaseProtocol, + record_path_prefix: tuple[str, ...] = (), + auto_flush: bool = True, + **kwargs, + ) -> None: + super().__init__(function_pod, **kwargs) + self._record_path_prefix = record_path_prefix + self._cache = ResultCache( + result_database=result_database, + record_path=record_path_prefix + self.uri, + auto_flush=auto_flush, + ) + + @property + def _result_database(self) -> ArrowDatabaseProtocol: + """The underlying result database (for FunctionNode access).""" + return self._cache.result_database + + @property + def record_path(self) -> tuple[str, ...]: + """Return the path to the cached records in the result store.""" + return self._cache.record_path + + def process_packet( + self, tag: TagProtocol, packet: PacketProtocol + ) -> tuple[TagProtocol, PacketProtocol | None]: + """Process a packet with pod-level caching. + + The cache key is the input packet content hash only — the function + output depends solely on the packet, not the tag. The output + packet carries a ``RESULT_COMPUTED_FLAG`` meta value: ``True`` if + freshly computed, ``False`` if retrieved from cache. + + Args: + tag: The tag associated with the packet. + packet: The input packet to process. + + Returns: + A ``(tag, output_packet)`` tuple; output_packet is ``None`` + if the inner function filters the packet out. + """ + cached = self._cache.lookup(packet) + if cached is not None: + logger.info("Pod-level cache hit") + return tag, cached + + tag, output = self._function_pod.process_packet(tag, packet) + if output is not None: + pf = self._function_pod.packet_function + self._cache.store( + packet, + output, + variation_data=pf.get_function_variation_data(), + execution_data=pf.get_execution_data(), + ) + output = output.with_meta_columns(**{self.RESULT_COMPUTED_FLAG: True}) + return tag, output + + async def async_process_packet( + self, tag: TagProtocol, packet: PacketProtocol + ) -> tuple[TagProtocol, PacketProtocol | None]: + """Async counterpart of ``process_packet``. + + DB lookup and store are synchronous (DB protocol is sync), but the + actual computation uses the inner pod's ``async_process_packet`` + for true async execution. + """ + cached = self._cache.lookup(packet) + if cached is not None: + logger.info("Pod-level cache hit") + return tag, cached + + tag, output = await self._function_pod.async_process_packet(tag, packet) + if output is not None: + pf = self._function_pod.packet_function + self._cache.store( + packet, + output, + variation_data=pf.get_function_variation_data(), + execution_data=pf.get_execution_data(), + ) + output = output.with_meta_columns(**{self.RESULT_COMPUTED_FLAG: True}) + return tag, output + + def process_packet_with_capture( + self, tag: TagProtocol, packet: PacketProtocol + ) -> "tuple[TagProtocol, PacketProtocol | None, CapturedLogs]": + """Process with pod-level caching, returning CapturedLogs alongside. + + On cache hit, returns empty CapturedLogs (no function was executed). + """ + from orcapod.pipeline.logging_capture import CapturedLogs + + cached = self._cache.lookup(packet) + if cached is not None: + logger.info("Pod-level cache hit") + return tag, cached, CapturedLogs(success=True) + + tag, output, captured = self._function_pod.process_packet_with_capture( + tag, packet + ) + if output is not None and captured.success: + pf = self._function_pod.packet_function + self._cache.store( + packet, + output, + variation_data=pf.get_function_variation_data(), + execution_data=pf.get_execution_data(), + ) + output = output.with_meta_columns(**{self.RESULT_COMPUTED_FLAG: True}) + return tag, output, captured + + async def async_process_packet_with_capture( + self, tag: TagProtocol, packet: PacketProtocol + ) -> "tuple[TagProtocol, PacketProtocol | None, CapturedLogs]": + """Async counterpart of ``process_packet_with_capture``.""" + from orcapod.pipeline.logging_capture import CapturedLogs + + cached = self._cache.lookup(packet) + if cached is not None: + logger.info("Pod-level cache hit") + return tag, cached, CapturedLogs(success=True) + + tag, output, captured = await self._function_pod.async_process_packet_with_capture( + tag, packet + ) + if output is not None and captured.success: + pf = self._function_pod.packet_function + self._cache.store( + packet, + output, + variation_data=pf.get_function_variation_data(), + execution_data=pf.get_execution_data(), + ) + output = output.with_meta_columns(**{self.RESULT_COMPUTED_FLAG: True}) + return tag, output, captured + + def get_all_cached_outputs( + self, include_system_columns: bool = False + ) -> "pa.Table | None": + """Return all cached records from the result store for this pod.""" + return self._cache.get_all_records( + include_system_columns=include_system_columns + ) + + def process( + self, *streams: StreamProtocol, label: str | None = None + ) -> StreamProtocol: + """Invoke the inner pod but with pod-level caching on process_packet. + + The stream returned uses *this* pod's ``process_packet`` (which + includes caching) rather than the inner pod's. + """ + from orcapod.core.function_pod import FunctionPodStream + + # Validate and prepare the input stream + input_stream = self._function_pod.handle_input_streams(*streams) + self._function_pod.validate_inputs(*streams) + + return FunctionPodStream( + function_pod=self, + input_stream=input_stream, + label=label, + ) diff --git a/src/orcapod/core/datagrams/__init__.py b/src/orcapod/core/datagrams/__init__.py index 0c255e36..779ff8c5 100644 --- a/src/orcapod/core/datagrams/__init__.py +++ b/src/orcapod/core/datagrams/__init__.py @@ -1,13 +1,8 @@ -from .arrow_datagram import ArrowDatagram -from .arrow_tag_packet import ArrowTag, ArrowPacket -from .dict_datagram import DictDatagram -from .dict_tag_packet import DictTag, DictPacket +from .datagram import Datagram +from .tag_packet import Packet, Tag __all__ = [ - "ArrowDatagram", - "ArrowTag", - "ArrowPacket", - "DictDatagram", - "DictTag", - "DictPacket", + "Datagram", + "Tag", + "Packet", ] diff --git a/src/orcapod/core/datagrams/arrow_datagram.py b/src/orcapod/core/datagrams/arrow_datagram.py deleted file mode 100644 index 9e5a7a54..00000000 --- a/src/orcapod/core/datagrams/arrow_datagram.py +++ /dev/null @@ -1,842 +0,0 @@ -import logging -from collections.abc import Collection, Iterator, Mapping -from typing import Self, TYPE_CHECKING - - -from orcapod import contexts -from orcapod.core.datagrams.base import BaseDatagram -from orcapod.core.system_constants import constants -from orcapod.types import DataValue, PythonSchema -from orcapod.protocols.hashing_protocols import ContentHash -from orcapod.utils import arrow_utils -from orcapod.utils.lazy_module import LazyModule - -if TYPE_CHECKING: - import pyarrow as pa -else: - pa = LazyModule("pyarrow") - -logger = logging.getLogger(__name__) -DEBUG = False - - -class ArrowDatagram(BaseDatagram): - """ - Immutable datagram implementation using PyArrow Table as storage backend. - - This implementation provides high-performance columnar data operations while - maintaining the datagram interface. It efficiently handles type conversions, - semantic processing, and interoperability with Arrow-based tools. - - The underlying table is split into separate components: - - Data table: Primary business data columns - - Meta table: Internal system metadata with {orcapod.META_PREFIX} ('__') prefixes - - Context table: Data context information with {orcapod.CONTEXT_KEY} - - Future Packet subclass will also handle: - - Source info: Data provenance with {orcapod.SOURCE_PREFIX} ('_source_') prefixes - - When exposing to external tools, semantic types are encoded as - `_{semantic_type}_` prefixes (_path_config_file, _id_user_name). - - All operations return new instances, preserving immutability. - - Example: - >>> table = pa.Table.from_pydict({ - ... "user_id": [123], - ... "name": ["Alice"], - ... "__pipeline_version": ["v2.1.0"], - ... "{orcapod.CONTEXT_KEY}": ["financial_v1"] - ... }) - >>> datagram = ArrowDatagram(table) - >>> updated = datagram.update(name="Alice Smith") - """ - - def __init__( - self, - table: "pa.Table", - meta_info: Mapping[str, DataValue] | None = None, - data_context: str | contexts.DataContext | None = None, - ) -> None: - """ - Initialize ArrowDatagram from PyArrow Table. - - Args: - table: PyArrow Table containing the data. Must have exactly one row. - semantic_converter: Optional converter for semantic type handling. - If None, will be created based on the data context and table schema. - data_context: Context key string or DataContext object. - If None and table contains context column, will extract from table. - - Raises: - ValueError: If table doesn't contain exactly one row. - - Note: - The input table is automatically split into data, meta, and context - components based on column naming conventions. - """ - # Validate table has exactly one row for datagram - if len(table) != 1: - raise ValueError( - "Table must contain exactly one row to be a valid datagram." - ) - - # normalize the table to large data types (for Polars compatibility) - table = arrow_utils.normalize_table_to_large_types(table) - - # Split table into data, meta, and context components - context_columns = ( - [constants.CONTEXT_KEY] - if constants.CONTEXT_KEY in table.column_names - else [] - ) - - # Extract context table from passed in table if present - if constants.CONTEXT_KEY in table.column_names and data_context is None: - context_table = table.select([constants.CONTEXT_KEY]) - data_context = context_table[constants.CONTEXT_KEY].to_pylist()[0] - - # Initialize base class with data context - super().__init__(data_context) - - meta_columns = [ - col for col in table.column_names if col.startswith(constants.META_PREFIX) - ] - # Split table into components - self._data_table = table.drop_columns(context_columns + meta_columns) - self._meta_table = table.select(meta_columns) if meta_columns else None - - if len(self._data_table.column_names) == 0: - raise ValueError("Data table must contain at least one data column.") - - # process supplemented meta info if provided - if meta_info is not None: - # make sure it has the expected prefixes - meta_info = { - ( - f"{constants.META_PREFIX}{k}" - if not k.startswith(constants.META_PREFIX) - else k - ): v - for k, v in meta_info.items() - } - new_meta_table = ( - self._data_context.type_converter.python_dicts_to_arrow_table( - [meta_info], - ) - ) - - if self._meta_table is None: - self._meta_table = new_meta_table - else: - # drop any column that will be overwritten by the new meta table - keep_meta_columns = [ - c - for c in self._meta_table.column_names - if c not in new_meta_table.column_names - ] - self._meta_table = arrow_utils.hstack_tables( - self._meta_table.select(keep_meta_columns), new_meta_table - ) - - # Create data context table - data_context_schema = pa.schema({constants.CONTEXT_KEY: pa.large_string()}) - self._data_context_table = pa.Table.from_pylist( - [{constants.CONTEXT_KEY: self._data_context.context_key}], - schema=data_context_schema, - ) - - # Initialize caches - self._cached_python_schema: PythonSchema | None = None - self._cached_python_dict: dict[str, DataValue] | None = None - self._cached_meta_python_schema: PythonSchema | None = None - self._cached_content_hash: ContentHash | None = None - - # 1. Core Properties (Identity & Structure) - @property - def meta_columns(self) -> tuple[str, ...]: - """Return tuple of meta column names.""" - if self._meta_table is None: - return () - return tuple(self._meta_table.column_names) - - # 2. Dict-like Interface (Data Access) - def __getitem__(self, key: str) -> DataValue: - """Get data column value by key.""" - if key not in self._data_table.column_names: - raise KeyError(f"Data column '{key}' not found") - - return self.as_dict()[key] - - def __contains__(self, key: str) -> bool: - """Check if data column exists.""" - return key in self._data_table.column_names - - def __iter__(self) -> Iterator[str]: - """Iterate over data column names.""" - return iter(self._data_table.column_names) - - def get(self, key: str, default: DataValue = None) -> DataValue: - """Get data column value with default.""" - if key in self._data_table.column_names: - return self.as_dict()[key] - return default - - # 3. Structural Information - def keys( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - ) -> tuple[str, ...]: - """Return tuple of column names.""" - # Start with data columns - include_meta_columns = include_all_info or include_meta_columns - include_context = include_all_info or include_context - - result_keys = list(self._data_table.column_names) - - # Add context if requested - if include_context: - result_keys.append(constants.CONTEXT_KEY) - - # Add meta columns if requested - if include_meta_columns: - if include_meta_columns is True: - result_keys.extend(self.meta_columns) - elif isinstance(include_meta_columns, Collection): - # Filter meta columns by prefix matching - filtered_meta_cols = [ - col - for col in self.meta_columns - if any(col.startswith(prefix) for prefix in include_meta_columns) - ] - result_keys.extend(filtered_meta_cols) - - return tuple(result_keys) - - def types( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - ) -> PythonSchema: - """ - Return Python schema for the datagram. - - Args: - include_meta_columns: Whether to include meta column types. - - True: include all meta column types - - Collection[str]: include meta column types matching these prefixes - - False: exclude meta column types - include_context: Whether to include context type - - Returns: - Python schema - """ - include_meta_columns = include_all_info or include_meta_columns - include_context = include_all_info or include_context - - # Get data schema (cached) - if self._cached_python_schema is None: - self._cached_python_schema = ( - self._data_context.type_converter.arrow_schema_to_python_schema( - self._data_table.schema - ) - ) - - schema = dict(self._cached_python_schema) - - # Add context if requested - if include_context: - schema[constants.CONTEXT_KEY] = str - - # Add meta schema if requested - if include_meta_columns and self._meta_table is not None: - if self._cached_meta_python_schema is None: - self._cached_meta_python_schema = ( - self._data_context.type_converter.arrow_schema_to_python_schema( - self._meta_table.schema - ) - ) - meta_schema = dict(self._cached_meta_python_schema) - if include_meta_columns is True: - schema.update(meta_schema) - elif isinstance(include_meta_columns, Collection): - filtered_meta_schema = { - k: v - for k, v in meta_schema.items() - if any(k.startswith(prefix) for prefix in include_meta_columns) - } - schema.update(filtered_meta_schema) - - return schema - - def arrow_schema( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - ) -> "pa.Schema": - """ - Return the PyArrow schema for this datagram. - - Args: - include_meta_columns: Whether to include meta columns in the schema. - - True: include all meta columns - - Collection[str]: include meta columns matching these prefixes - - False: exclude meta columns - include_context: Whether to include context column in the schema - - Returns: - PyArrow schema representing the datagram's structure - """ - # order matters - include_meta_columns = include_all_info or include_meta_columns - include_context = include_all_info or include_context - - all_schemas = [self._data_table.schema] - - # Add context schema if requested - if include_context: - # TODO: reassess the efficiency of this approach - all_schemas.append(self._data_context_table.schema) - - # Add meta schema if requested - if include_meta_columns and self._meta_table is not None: - if include_meta_columns is True: - meta_schema = self._meta_table.schema - elif isinstance(include_meta_columns, Collection): - # Filter meta schema by prefix matching - matched_fields = [ - field - for field in self._meta_table.schema - if any( - field.name.startswith(prefix) for prefix in include_meta_columns - ) - ] - if matched_fields: - meta_schema = pa.schema(matched_fields) - else: - meta_schema = None - else: - meta_schema = None - - if meta_schema is not None: - all_schemas.append(meta_schema) - - return arrow_utils.join_arrow_schemas(*all_schemas) - - def content_hash(self) -> ContentHash: - """ - Calculate and return content hash of the datagram. - Only includes data columns, not meta columns or context. - - Returns: - Hash string of the datagram content - """ - if self._cached_content_hash is None: - self._cached_content_hash = self._data_context.arrow_hasher.hash_table( - self._data_table, - ) - return self._cached_content_hash - - # 4. Format Conversions (Export) - def as_dict( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - ) -> dict[str, DataValue]: - """ - Return dictionary representation of the datagram. - - Args: - include_meta_columns: Whether to include meta columns. - - True: include all meta columns - - Collection[str]: include meta columns matching these prefixes - - False: exclude meta columns - include_context: Whether to include context key - - Returns: - Dictionary representation - """ - include_meta_columns = include_all_info or include_meta_columns - include_context = include_all_info or include_context - - # Get data dict (cached) - if self._cached_python_dict is None: - self._cached_python_dict = ( - self._data_context.type_converter.arrow_table_to_python_dicts( - self._data_table - )[0] - ) - - result_dict = dict(self._cached_python_dict) - - # Add context if requested - if include_context: - result_dict[constants.CONTEXT_KEY] = self._data_context.context_key - - # Add meta data if requested - if include_meta_columns and self._meta_table is not None: - if include_meta_columns is True: - meta_dict = self._meta_table.to_pylist()[0] - elif isinstance(include_meta_columns, Collection): - meta_dict = self._meta_table.to_pylist()[0] - # Include only meta columns matching prefixes - meta_dict = { - k: v - for k, v in meta_dict.items() - if any(k.startswith(prefix) for prefix in include_meta_columns) - } - if meta_dict is not None: - result_dict.update(meta_dict) - - return result_dict - - def as_table( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - ) -> "pa.Table": - """ - Convert the datagram to an Arrow table. - - Args: - include_meta_columns: Whether to include meta columns. - - True: include all meta columns - - Collection[str]: include meta columns matching these prefixes - - False: exclude meta columns - include_context: Whether to include the context column - - Returns: - Arrow table representation - """ - include_meta_columns = include_all_info or include_meta_columns - include_context = include_all_info or include_context - - all_tables = [self._data_table] - - # Add context if requested - if include_context: - all_tables.append(self._data_context_table) - - # Add meta columns if requested - if include_meta_columns and self._meta_table is not None: - meta_table = None - if include_meta_columns is True: - meta_table = self._meta_table - elif isinstance(include_meta_columns, Collection): - # Filter meta columns by prefix matching - # ensure all given prefixes start with the meta prefix - prefixes = ( - f"{constants.META_PREFIX}{prefix}" - if not prefix.startswith(constants.META_PREFIX) - else prefix - for prefix in include_meta_columns - ) - - matched_cols = [ - col - for col in self._meta_table.column_names - if any(col.startswith(prefix) for prefix in prefixes) - ] - if matched_cols: - meta_table = self._meta_table.select(matched_cols) - else: - meta_table = None - - if meta_table is not None: - all_tables.append(meta_table) - - return arrow_utils.hstack_tables(*all_tables) - - def as_arrow_compatible_dict( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - ) -> dict[str, DataValue]: - """ - Return dictionary representation compatible with Arrow. - - Args: - include_meta_columns: Whether to include meta columns. - - True: include all meta columns - - Collection[str]: include meta columns matching these prefixes - - False: exclude meta columns - include_context: Whether to include context key - - Returns: - Dictionary representation compatible with Arrow - """ - return self.as_table( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_context=include_context, - ).to_pylist()[0] - - # 5. Meta Column Operations - def get_meta_value(self, key: str, default: DataValue = None) -> DataValue: - """ - Get a meta column value. - - Args: - key: Meta column key (with or without {orcapod.META_PREFIX} ('__') prefix) - default: Default value if not found - - Returns: - Meta column value - """ - if self._meta_table is None: - return default - - # Handle both prefixed and unprefixed keys - if not key.startswith(constants.META_PREFIX): - key = constants.META_PREFIX + key - - if key not in self._meta_table.column_names: - return default - - return self._meta_table[key].to_pylist()[0] - - def with_meta_columns(self, **meta_updates: DataValue) -> Self: - """ - Create a new ArrowDatagram with updated meta columns. - Maintains immutability by returning a new instance. - - Args: - **meta_updates: Meta column updates (keys will be prefixed with {orcapod.META_PREFIX} ('__') if needed) - - Returns: - New ArrowDatagram instance - """ - # Prefix the keys and prepare updates - prefixed_updates = {} - for k, v in meta_updates.items(): - if not k.startswith(constants.META_PREFIX): - k = constants.META_PREFIX + k - prefixed_updates[k] = v - - new_datagram = self.copy(include_cache=False) - - # Start with existing meta data - meta_dict = {} - if self._meta_table is not None: - meta_dict = self._meta_table.to_pylist()[0] - - # Apply updates - meta_dict.update(prefixed_updates) - - # TODO: properly handle case where meta data is None (it'll get inferred as NoneType) - - # Create new meta table - new_datagram._meta_table = ( - self._data_context.type_converter.python_dicts_to_arrow_table([meta_dict]) - if meta_dict - else None - ) - return new_datagram - - def drop_meta_columns(self, *keys: str, ignore_missing: bool = False) -> Self: - """ - Create a new ArrowDatagram with specified meta columns dropped. - Maintains immutability by returning a new instance. - - Args: - *keys: Meta column keys to drop (with or without {orcapod.META_PREFIX} ('__') prefix) - - Returns: - New ArrowDatagram instance without specified meta columns - """ - if self._meta_table is None: - return self # No meta columns to drop - - # Normalize keys to have prefixes - prefixed_keys = set() - for key in keys: - if not key.startswith(constants.META_PREFIX): - key = constants.META_PREFIX + key - prefixed_keys.add(key) - - missing_keys = prefixed_keys - set(self._meta_table.column_names) - if missing_keys and not ignore_missing: - raise KeyError( - f"Following meta columns do not exist and cannot be dropped: {sorted(missing_keys)}" - ) - - # Only drop columns that actually exist - existing_keys = prefixed_keys - missing_keys - - new_datagram = self.copy(include_cache=False) - if existing_keys: # Only drop if there are existing columns to drop - new_datagram._meta_table = self._meta_table.drop_columns( - list(existing_keys) - ) - - return new_datagram - - # 6. Data Column Operations - def select(self, *column_names: str) -> Self: - """ - Create a new ArrowDatagram with only specified data columns. - Maintains immutability by returning a new instance. - - Args: - *column_names: Data column names to keep - - Returns: - New ArrowDatagram instance with only specified data columns - """ - # Validate columns exist - missing_cols = set(column_names) - set(self._data_table.column_names) - if missing_cols: - raise ValueError(f"Columns not found: {missing_cols}") - - new_datagram = self.copy(include_cache=False) - new_datagram._data_table = new_datagram._data_table.select(column_names) - - return new_datagram - - def drop(self, *column_names: str, ignore_missing: bool = False) -> Self: - """ - Create a new ArrowDatagram with specified data columns dropped. - Maintains immutability by returning a new instance. - - Args: - *column_names: Data column names to drop - - Returns: - New ArrowDatagram instance without specified data columns - """ - - # Filter out specified data columns - missing = set(column_names) - set(self._data_table.column_names) - if missing and not ignore_missing: - raise KeyError( - f"Following columns do not exist and cannot be dropped: {sorted(missing)}" - ) - # Only keep columns that actually exist - existing_columns = tuple( - c for c in column_names if c in self._data_table.column_names - ) - - new_datagram = self.copy(include_cache=False) - if existing_columns: # Only drop if there are existing columns to drop - new_datagram._data_table = self._data_table.drop_columns( - list(existing_columns) - ) - # TODO: consider dropping extra semantic columns if they are no longer needed - return new_datagram - - def rename(self, column_mapping: Mapping[str, str]) -> Self: - """ - Create a new ArrowDatagram with data columns renamed. - Maintains immutability by returning a new instance. - - Args: - column_mapping: Mapping from old column names to new column names - - Returns: - New ArrowDatagram instance with renamed data columns - """ - # Create new schema with renamed fields, preserving original types - - if not column_mapping: - return self - - new_names = [column_mapping.get(k, k) for k in self._data_table.column_names] - - new_datagram = self.copy(include_cache=False) - new_datagram._data_table = new_datagram._data_table.rename_columns(new_names) - - return new_datagram - - def update(self, **updates: DataValue) -> Self: - """ - Create a new ArrowDatagram with specific column values updated. - - Args: - **updates: Column names and their new values - - Returns: - New ArrowDatagram instance with updated values - - Raises: - KeyError: If any specified column doesn't exist - - Example: - # Convert relative path to absolute path - updated = datagram.update(file_path="/absolute/path/to/file.txt") - - # Update multiple values - updated = datagram.update(status="processed", file_path="/new/path") - """ - # Only update if there are columns to update - if not updates: - return self - - # Validate all columns exist - missing_cols = set(updates.keys()) - set(self._data_table.column_names) - if missing_cols: - raise KeyError( - f"Only existing columns can be updated. Following columns were not found: {sorted(missing_cols)}" - ) - - new_datagram = self.copy(include_cache=False) - - # use existing schema - sub_schema = arrow_utils.schema_select( - new_datagram._data_table.schema, list(updates.keys()) - ) - - update_table = self._data_context.type_converter.python_dicts_to_arrow_table( - [updates], arrow_schema=sub_schema - ) - - new_datagram._data_table = arrow_utils.hstack_tables( - self._data_table.drop_columns(list(updates.keys())), update_table - ).select(self._data_table.column_names) # adjsut the order to match original - - return new_datagram - - def with_columns( - self, - column_types: Mapping[str, type] | None = None, - **updates: DataValue, - ) -> Self: - """ - Create a new ArrowDatagram with new data columns added. - Maintains immutability by returning a new instance. - - Args: - column_updates: New data columns as a mapping - column_types: Optional type specifications for new columns - **kwargs: New data columns as keyword arguments - - Returns: - New ArrowDatagram instance with new data columns added - - Raises: - ValueError: If any column already exists (use update() instead) - """ - # Combine explicit updates with kwargs - - if not updates: - return self - - # Error if any of the columns already exists - existing_overlaps = set(updates.keys()) & set(self._data_table.column_names) - if existing_overlaps: - raise ValueError( - f"Columns already exist: {sorted(existing_overlaps)}. " - f"Use update() to modify existing columns." - ) - - # create a copy and perform in-place updates - new_datagram = self.copy() - - # TODO: consider simplifying this conversion logic - - # TODO: cleanup the handling of typespec python schema and various conversion points - new_data_table = self._data_context.type_converter.python_dicts_to_arrow_table( - [updates], python_schema=dict(column_types) if column_types else None - ) - - # perform in-place update - new_datagram._data_table = arrow_utils.hstack_tables( - new_datagram._data_table, new_data_table - ) - - return new_datagram - - # 7. Context Operations - def with_context_key(self, new_context_key: str) -> Self: - """ - Create a new ArrowDatagram with a different data context key. - Maintains immutability by returning a new instance. - - Args: - new_context_key: New data context key string - - Returns: - New ArrowDatagram instance with new context - """ - # TODO: consider if there is a more efficient way to handle context - # Combine all tables for reconstruction - - new_datagram = self.copy(include_cache=False) - new_datagram._data_context = contexts.resolve_context(new_context_key) - return new_datagram - - # 8. Utility Operations - def copy(self, include_cache: bool = True) -> Self: - """Return a copy of the datagram.""" - new_datagram = super().copy() - - new_datagram._data_table = self._data_table - new_datagram._meta_table = self._meta_table - new_datagram._data_context = self._data_context - - if include_cache: - new_datagram._cached_python_schema = self._cached_python_schema - new_datagram._cached_python_dict = self._cached_python_dict - new_datagram._cached_content_hash = self._cached_content_hash - new_datagram._cached_meta_python_schema = self._cached_meta_python_schema - else: - new_datagram._cached_python_schema = None - new_datagram._cached_python_dict = None - new_datagram._cached_content_hash = None - new_datagram._cached_meta_python_schema = None - - return new_datagram - - # 9. String Representations - def __str__(self) -> str: - """ - Return user-friendly string representation. - - Shows the datagram as a simple dictionary for user-facing output, - messages, and logging. Only includes data columns for clean output. - - Returns: - Dictionary-style string representation of data columns only. - - Example: - >>> str(datagram) - "{'user_id': 123, 'name': 'Alice'}" - >>> print(datagram) - {'user_id': 123, 'name': 'Alice'} - """ - return str(self.as_dict()) - - def __repr__(self) -> str: - """ - Return detailed string representation for debugging. - - Shows the datagram type and comprehensive information including - data columns, meta columns count, and context for debugging purposes. - - Returns: - Detailed representation with type and metadata information. - - Example: - >>> repr(datagram) - "ArrowDatagram(data={'user_id': 123, 'name': 'Alice'}, meta_columns=2, context='std:v1.0.0:abc123')" - """ - if DEBUG: - data_dict = self.as_dict() - meta_count = len(self.meta_columns) - context_key = self.data_context_key - - return ( - f"{self.__class__.__name__}(" - f"data={data_dict}, " - f"meta_columns={meta_count}, " - f"context='{context_key}'" - f")" - ) - else: - return str(self.as_dict()) diff --git a/src/orcapod/core/datagrams/arrow_tag_packet.py b/src/orcapod/core/datagrams/arrow_tag_packet.py deleted file mode 100644 index 24d2185d..00000000 --- a/src/orcapod/core/datagrams/arrow_tag_packet.py +++ /dev/null @@ -1,554 +0,0 @@ -import logging -from collections.abc import Collection, Mapping -from typing import Self, TYPE_CHECKING - - -from orcapod.core.system_constants import constants -from orcapod import contexts -from orcapod.semantic_types import infer_python_schema_from_pylist_data - -from orcapod.types import DataValue, PythonSchema -from orcapod.utils import arrow_utils - -from orcapod.core.datagrams.arrow_datagram import ArrowDatagram -from orcapod.utils.lazy_module import LazyModule - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - import pyarrow as pa -else: - pa = LazyModule("pyarrow") - - -class ArrowTag(ArrowDatagram): - """ - A tag implementation using Arrow table backend. - - Represents a single-row Arrow table that can be converted to Python - dictionary representation while caching computed values for efficiency. - - Initialize with an Arrow table. - - Args: - table: Single-row Arrow table representing the tag - - Raises: - ValueError: If table doesn't contain exactly one row - """ - - def __init__( - self, - table: "pa.Table", - system_tags: Mapping[str, DataValue] | None = None, - data_context: str | contexts.DataContext | None = None, - ) -> None: - if len(table) != 1: - raise ValueError( - "ArrowTag should only contain a single row, " - "as it represents a single tag." - ) - super().__init__( - table=table, - data_context=data_context, - ) - extracted_system_tag_columns = [ - c - for c in self._data_table.column_names - if c.startswith(constants.SYSTEM_TAG_PREFIX) - ] - self._system_tags_dict: dict[str, DataValue] = ( - self._data_context.type_converter.arrow_table_to_python_dicts( - self._data_table.select(extracted_system_tag_columns) - )[0] - ) - self._system_tags_dict.update(system_tags or {}) - self._system_tags_python_schema = infer_python_schema_from_pylist_data( - [self._system_tags_dict] - ) - self._system_tags_table = ( - self._data_context.type_converter.python_dicts_to_arrow_table( - [self._system_tags_dict], python_schema=self._system_tags_python_schema - ) - ) - - self._data_table = self._data_table.drop_columns(extracted_system_tag_columns) - - def keys( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_system_tags: bool = False, - ) -> tuple[str, ...]: - keys = super().keys( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_context=include_context, - ) - if include_all_info or include_system_tags: - keys += tuple(self._system_tags_dict.keys()) - return keys - - def types( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_system_tags: bool = False, - ) -> PythonSchema: - """Return copy of the Python schema.""" - schema = super().types( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_context=include_context, - ) - if include_all_info or include_system_tags: - schema.update(self._system_tags_python_schema) - return schema - - def arrow_schema( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_system_tags: bool = False, - ) -> "pa.Schema": - """ - Return the PyArrow schema for this datagram. - - Args: - include_data_context: Whether to include data context column in the schema - include_source: Whether to include source info columns in the schema - - Returns: - PyArrow schema representing the datagram's structure - """ - schema = super().arrow_schema( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_context=include_context, - ) - if include_all_info or include_system_tags: - return arrow_utils.join_arrow_schemas( - schema, self._system_tags_table.schema - ) - return schema - - def as_dict( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_system_tags: bool = False, - ) -> dict[str, DataValue]: - """ - Convert to dictionary representation. - - Args: - include_source: Whether to include source info fields - - Returns: - Dictionary representation of the packet - """ - return_dict = super().as_dict( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_context=include_context, - ) - if include_all_info or include_system_tags: - return_dict.update(self._system_tags_dict) - return return_dict - - def as_table( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_system_tags: bool = False, - ) -> "pa.Table": - table = super().as_table( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_context=include_context, - ) - if ( - include_all_info or include_system_tags - ) and self._system_tags_table.num_columns > 0: - # add system_tags only if there are actual system tag columns - table = arrow_utils.hstack_tables(table, self._system_tags_table) - return table - - def as_datagram( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_system_tags: bool = False, - ) -> ArrowDatagram: - table = self.as_table( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_system_tags=include_system_tags, - ) - return ArrowDatagram( - table, - data_context=self._data_context, - ) - - def system_tags(self) -> dict[str, DataValue | None]: - """ - Return system tags for all keys. - - Returns: - Copy of the dictionary mapping field names to their source info - """ - return self._system_tags_dict.copy() - - # 8. Utility Operations - def copy(self, include_cache: bool = True) -> Self: - """Return a copy of the datagram.""" - new_tag = super().copy(include_cache=include_cache) - - new_tag._system_tags_dict = self._system_tags_dict.copy() - new_tag._system_tags_python_schema = self._system_tags_python_schema.copy() - new_tag._system_tags_table = self._system_tags_table - - return new_tag - - -class ArrowPacket(ArrowDatagram): - """ - Arrow table-based packet implementation with comprehensive features. - - A packet implementation that uses Arrow tables as the primary storage format, - providing efficient memory usage and columnar data operations while supporting - source information tracking and content hashing. - - - Initialize ArrowPacket with Arrow table and configuration. - - Args: - table: Single-row Arrow table representing the packet - source_info: Optional source information mapping - semantic_converter: Optional semantic converter - semantic_type_registry: Registry for semantic types - finger_print: Optional fingerprint for tracking - arrow_hasher: Optional Arrow hasher - post_hash_callback: Optional callback after hash calculation - skip_source_info_extraction: Whether to skip source info processing - - Raises: - ValueError: If table doesn't contain exactly one row - """ - - def __init__( - self, - table: "pa.Table | pa.RecordBatch", - meta_info: Mapping[str, DataValue] | None = None, - source_info: Mapping[str, str | None] | None = None, - data_context: str | contexts.DataContext | None = None, - ) -> None: - if len(table) != 1: - raise ValueError( - "ArrowPacket should only contain a single row, " - "as it represents a single packet." - ) - if source_info is None: - source_info = {} - else: - # normalize by removing any existing prefixes - source_info = { - ( - k.removeprefix(constants.SOURCE_PREFIX) - if k.startswith(constants.SOURCE_PREFIX) - else k - ): v - for k, v in source_info.items() - } - - # normalize the table to ensure it has the expected source_info columns - # TODO: use simpler function to ensure source_info columns - data_table, prefixed_tables = arrow_utils.prepare_prefixed_columns( - table, - {constants.SOURCE_PREFIX: source_info}, - exclude_columns=[constants.CONTEXT_KEY], - exclude_prefixes=[constants.META_PREFIX], - ) - - super().__init__( - data_table, - meta_info=meta_info, - data_context=data_context, - ) - self._source_info_table = prefixed_tables[constants.SOURCE_PREFIX] - - self._cached_source_info: dict[str, str | None] | None = None - self._cached_python_schema: PythonSchema | None = None - - def keys( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_source: bool = False, - ) -> tuple[str, ...]: - keys = super().keys( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_context=include_context, - ) - if include_all_info or include_source: - keys += tuple(f"{constants.SOURCE_PREFIX}{k}" for k in self.keys()) - return keys - - def types( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_source: bool = False, - ) -> PythonSchema: - """Return copy of the Python schema.""" - schema = super().types( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_context=include_context, - ) - if include_all_info or include_source: - for key in self.keys(): - schema[f"{constants.SOURCE_PREFIX}{key}"] = str - return schema - - def arrow_schema( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_source: bool = False, - ) -> "pa.Schema": - """ - Return the PyArrow schema for this datagram. - - Args: - include_data_context: Whether to include data context column in the schema - include_source: Whether to include source info columns in the schema - - Returns: - PyArrow schema representing the datagram's structure - """ - schema = super().arrow_schema( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_context=include_context, - ) - if include_all_info or include_source: - return arrow_utils.join_arrow_schemas( - schema, self._source_info_table.schema - ) - return schema - - def as_dict( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_source: bool = False, - ) -> dict[str, DataValue]: - """ - Convert to dictionary representation. - - Args: - include_source: Whether to include source info fields - - Returns: - Dictionary representation of the packet - """ - return_dict = super().as_dict( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_context=include_context, - ) - if include_all_info or include_source: - return_dict.update( - { - f"{constants.SOURCE_PREFIX}{k}": v - for k, v in self.source_info().items() - } - ) - return return_dict - - def as_table( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_source: bool = False, - ) -> "pa.Table": - table = super().as_table( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_context=include_context, - ) - if include_all_info or include_source: - # add source_info only if there are columns and the table has meaningful data - if ( - self._source_info_table.num_columns > 0 - and self._source_info_table.num_rows > 0 - ): - table = arrow_utils.hstack_tables(table, self._source_info_table) - return table - - def as_datagram( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_source: bool = False, - ) -> ArrowDatagram: - table = self.as_table( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_source=include_source, - ) - return ArrowDatagram( - table, - data_context=self._data_context, - ) - - def source_info(self) -> dict[str, str | None]: - """ - Return source information for all keys. - - Returns: - Copy of the dictionary mapping field names to their source info - """ - if self._cached_source_info is None: - self._cached_source_info = { - k.removeprefix(constants.SOURCE_PREFIX): v - for k, v in self._source_info_table.to_pylist()[0].items() - } - return self._cached_source_info.copy() - - def with_source_info(self, **source_info: str | None) -> Self: - """ - Create a copy of the packet with updated source information. - - Args: - source_info: New source information mapping - - Returns: - New ArrowPacket instance with updated source info - """ - new_packet = self.copy(include_cache=False) - - existing_source_info_with_prefix = self._source_info_table.to_pylist()[0] - for key, value in source_info.items(): - if not key.startswith(constants.SOURCE_PREFIX): - # Ensure the key is prefixed correctly - key = f"{constants.SOURCE_PREFIX}{key}" - if key in existing_source_info_with_prefix: - existing_source_info_with_prefix[key] = value - - new_packet._source_info_table = pa.Table.from_pylist( - [existing_source_info_with_prefix] - ) - return new_packet - - def rename(self, column_mapping: Mapping[str, str]) -> Self: - """ - Create a new ArrowDatagram with data columns renamed. - Maintains immutability by returning a new instance. - - Args: - column_mapping: Mapping from old column names to new column names - - Returns: - New ArrowDatagram instance with renamed data columns - """ - # Create new schema with renamed fields, preserving original types - - if not column_mapping: - return self - - new_names = [column_mapping.get(k, k) for k in self._data_table.column_names] - - new_source_info_names = [ - f"{constants.SOURCE_PREFIX}{column_mapping.get(k.removeprefix(constants.SOURCE_PREFIX), k.removeprefix(constants.SOURCE_PREFIX))}" - for k in self._source_info_table.column_names - ] - - new_datagram = self.copy(include_cache=False) - new_datagram._data_table = new_datagram._data_table.rename_columns(new_names) - new_datagram._source_info_table = ( - new_datagram._source_info_table.rename_columns(new_source_info_names) - ) - - return new_datagram - - def with_columns( - self, - column_types: Mapping[str, type] | None = None, - **updates: DataValue, - ) -> Self: - """ - Create a new ArrowPacket with new data columns added. - Maintains immutability by returning a new instance. - Also adds corresponding empty source info columns for new columns. - - Args: - column_types: Optional type specifications for new columns - **updates: New data columns as keyword arguments - - Returns: - New ArrowPacket instance with new data columns and corresponding source info columns - - Raises: - ValueError: If any column already exists (use update() instead) - """ - if not updates: - return self - - # First call parent method to add the data columns - new_packet = super().with_columns(column_types=column_types, **updates) - - # Now add corresponding empty source info columns for the new columns - source_info_updates = {} - for column_name in updates.keys(): - source_key = f"{constants.SOURCE_PREFIX}{column_name}" - source_info_updates[source_key] = None # Empty source info - - # Add new source info columns to the source info table - if source_info_updates: - # Get existing source info - schema = new_packet._source_info_table.schema - existing_source_info = new_packet._source_info_table.to_pylist()[0] - - # Add the new empty source info columns - existing_source_info.update(source_info_updates) - schema_columns = list(schema) - schema_columns.extend( - [ - pa.field(name, pa.large_string()) - for name in source_info_updates.keys() - ] - ) - new_schema = pa.schema(schema_columns) - - # Update the source info table - new_packet._source_info_table = pa.Table.from_pylist( - [existing_source_info], new_schema - ) - - return new_packet - - # 8. Utility Operations - def copy(self, include_cache: bool = True) -> Self: - """Return a copy of the datagram.""" - new_packet = super().copy(include_cache=include_cache) - new_packet._source_info_table = self._source_info_table - - if include_cache: - new_packet._cached_source_info = self._cached_source_info - else: - new_packet._cached_source_info = None - - return new_packet diff --git a/src/orcapod/core/datagrams/base.py b/src/orcapod/core/datagrams/base.py deleted file mode 100644 index ec688604..00000000 --- a/src/orcapod/core/datagrams/base.py +++ /dev/null @@ -1,283 +0,0 @@ -""" -Data structures and utilities for working with datagrams in OrcaPod. - -This module provides classes and functions for handling packet-like data structures -that can represent data in various formats (Python dicts, Arrow tables, etc.) while -maintaining type information, source metadata, and semantic type conversion capability. - -Key classes: -- SemanticConverter: Converts between different data representations. Intended for internal use. -- DictDatagram: Immutable dict-based data structure -- PythonDictPacket: Python dict-based packet with source info -- ArrowPacket: Arrow table-based packet implementation -- PythonDictTag/ArrowTag: Tag implementations for data identification - -The module also provides utilities for schema validation, table operations, -and type conversions between semantic stores, Python stores, and Arrow tables. -""" - -import logging -from abc import abstractmethod -from collections.abc import Collection, Iterator, Mapping -from typing import Self, TypeAlias, TYPE_CHECKING -from orcapod import contexts -from orcapod.core.base import ContentIdentifiableBase -from orcapod.protocols.hashing_protocols import ContentHash - -from orcapod.utils.lazy_module import LazyModule -from orcapod.types import DataValue, PythonSchema - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - import pyarrow as pa -else: - pa = LazyModule("pyarrow") - -# A conveniece packet-like type that defines a value that can be -# converted to a packet. It's broader than Packet and a simple mapping -# from string keys to DataValue (e.g., int, float, str) can be regarded -# as PacketLike, allowing for more flexible interfaces. -# Anything that requires Packet-like data but without the strict features -# of a Packet should accept PacketLike. -# One should be careful when using PacketLike as a return type as it does not -# enforce the typespec or source_info, which are important for packet integrity. -PacketLike: TypeAlias = Mapping[str, DataValue] - -PythonStore: TypeAlias = Mapping[str, DataValue] - - -class ImmutableDict(Mapping[str, DataValue]): - """ - An immutable dictionary-like container for DataValues. - - Provides a read-only view of a dictionary mapping strings to DataValues, - implementing the Mapping protocol for compatibility with dict-like operations. - - Initialize with data from a mapping. - Args: - data: Source mapping to copy data from - """ - - def __init__(self, data: Mapping[str, DataValue]): - self._data = dict(data) - - def __getitem__(self, key: str) -> DataValue: - return self._data[key] - - def __iter__(self): - return iter(self._data) - - def __len__(self) -> int: - return len(self._data) - - def __repr__(self) -> str: - return self._data.__repr__() - - def __str__(self) -> str: - return self._data.__str__() - - def __or__(self, other: Mapping[str, DataValue]) -> Self: - """ - Create a new ImmutableDict by merging with another mapping. - - Args: - other: Another mapping to merge with - - Returns: - A new ImmutableDict containing the combined data - """ - return self.__class__(self._data | dict(other)) - - -def contains_prefix_from(column: str, prefixes: Collection[str]) -> bool: - """ - Check if a column name matches any of the given prefixes. - - Args: - column: Column name to check - prefixes: Collection of prefixes to match against - - Returns: - True if the column starts with any of the prefixes, False otherwise - """ - for prefix in prefixes: - if column.startswith(prefix): - return True - return False - - -class BaseDatagram(ContentIdentifiableBase): - """ - Abstract base class for immutable datagram implementations. - - Provides shared functionality and enforces consistent interface across - different storage backends (dict, Arrow table, etc.). Concrete subclasses - must implement the abstract methods to handle their specific storage format. - - The base class only manages the data context key string - how that key - is interpreted and used is left to concrete implementations. - """ - - def __init__(self, data_context: contexts.DataContext | str | None = None) -> None: - """ - Initialize base datagram with data context. - - Args: - data_context: Context for semantic interpretation. Can be a string key - or a DataContext object, or None for default. - """ - self._data_context = contexts.resolve_context(data_context) - self._converter = self._data_context.type_converter - - # 1. Core Properties (Identity & Structure) - @property - def data_context_key(self) -> str: - """Return the data context key.""" - return self._data_context.context_key - - @property - @abstractmethod - def meta_columns(self) -> tuple[str, ...]: - """Return tuple of meta column names.""" - ... - - # TODO: add meta info - - # 2. Dict-like Interface (Data Access) - @abstractmethod - def __getitem__(self, key: str) -> DataValue: - """Get data column value by key.""" - ... - - @abstractmethod - def __contains__(self, key: str) -> bool: - """Check if data column exists.""" - ... - - @abstractmethod - def __iter__(self) -> Iterator[str]: - """Iterate over data column names.""" - ... - - @abstractmethod - def get(self, key: str, default: DataValue = None) -> DataValue: - """Get data column value with default.""" - ... - - # 3. Structural Information - @abstractmethod - def keys( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - ) -> tuple[str, ...]: - """Return tuple of column names.""" - ... - - @abstractmethod - def types( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - ) -> PythonSchema: - """Return type specification for the datagram.""" - ... - - @abstractmethod - def arrow_schema( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - ) -> "pa.Schema": - """Return the PyArrow schema for this datagram.""" - ... - - @abstractmethod - def content_hash(self) -> ContentHash: - """Calculate and return content hash of the datagram.""" - ... - - # 4. Format Conversions (Export) - @abstractmethod - def as_dict( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - ) -> dict[str, DataValue]: - """Return dictionary representation of the datagram.""" - ... - - @abstractmethod - def as_table( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - ) -> "pa.Table": - """Convert the datagram to an Arrow table.""" - ... - - # 5. Meta Column Operations - @abstractmethod - def get_meta_value(self, key: str, default: DataValue = None) -> DataValue: - """Get a meta column value.""" - ... - - @abstractmethod - def with_meta_columns(self, **updates: DataValue) -> Self: - """Create new datagram with updated meta columns.""" - ... - - @abstractmethod - def drop_meta_columns(self, *keys: str) -> Self: - """Create new datagram with specified meta columns removed.""" - ... - - # 6. Data Column Operations - @abstractmethod - def select(self, *column_names: str) -> Self: - """Create new datagram with only specified data columns.""" - ... - - @abstractmethod - def drop(self, *column_names: str) -> Self: - """Create new datagram with specified data columns removed.""" - ... - - @abstractmethod - def rename(self, column_mapping: Mapping[str, str]) -> Self: - """Create new datagram with data columns renamed.""" - ... - - @abstractmethod - def update(self, **updates: DataValue) -> Self: - """Create new datagram with existing column values updated.""" - ... - - @abstractmethod - def with_columns( - self, - column_types: Mapping[str, type] | None = None, - **updates: DataValue, - ) -> Self: - """Create new datagram with additional data columns.""" - ... - - # 7. Context Operations - def with_context_key(self, new_context_key: str) -> Self: - """Create new datagram with different data context.""" - new_datagram = self.copy(include_cache=False) - new_datagram._data_context = contexts.resolve_context(new_context_key) - return new_datagram - - # 8. Utility Operations - def copy(self, include_cache: bool = True) -> Self: - """Create a shallow copy of the datagram.""" - new_datagram = object.__new__(self.__class__) - new_datagram._data_context = self._data_context - return new_datagram diff --git a/src/orcapod/core/datagrams/datagram.py b/src/orcapod/core/datagrams/datagram.py new file mode 100644 index 00000000..728feb7f --- /dev/null +++ b/src/orcapod/core/datagrams/datagram.py @@ -0,0 +1,781 @@ +""" +Unified datagram implementation. + +A single ``Datagram`` class that internally holds either an Arrow table or a Python +dict — whichever was provided at construction — and lazily converts to the other +representation only when required. + +Principles +---------- +- **Minimal conversion**: structural operations (select, drop, rename) stay Arrow-native + when the Arrow representation is already loaded. +- **Dict for value access**: ``__getitem__``, ``get``, ``as_dict()`` always operate through + the Python dict (loaded lazily from Arrow when needed). +- **Arrow for hashing**: ``content_hash()`` always uses the Arrow table (loaded lazily from + dict when needed) via the data context's ``ArrowTableHandler``. +- **Meta is always dict**: meta columns are stored as a Python dict regardless of how the + primary data was provided; the Arrow meta table is built lazily. +""" + +from __future__ import annotations + +import logging +from collections.abc import Collection, Iterator, Mapping +from typing import TYPE_CHECKING, Any, Self, cast + +from uuid_utils import uuid7 + +from orcapod import contexts +from orcapod.config import Config +from orcapod.core.base import ContentIdentifiableBase +from orcapod.protocols.semantic_types_protocols import TypeConverterProtocol +from orcapod.semantic_types import infer_python_schema_from_pylist_data +from orcapod.system_constants import constants +from orcapod.types import ColumnConfig, DataValue, Schema, SchemaLike +from orcapod.utils import arrow_utils +from orcapod.utils.lazy_module import LazyModule + +if TYPE_CHECKING: + import pyarrow as pa +else: + pa = LazyModule("pyarrow") + +logger = logging.getLogger(__name__) + + +class Datagram(ContentIdentifiableBase): + """ + Immutable datagram backed by either an Arrow table or a Python dict. + + Accepts either a ``Mapping[str, DataValue]`` (dict-path) or a + ``pa.Table | pa.RecordBatch`` (Arrow-path) as primary data. The alternative + representation is computed lazily and cached. + + Column conventions (same as the legacy implementations): + - Keys starting with ``constants.META_PREFIX`` (``__``) → meta columns + - The special key ``constants.CONTEXT_KEY`` → data-context column (extracted, not stored) + - Everything else → primary data columns + """ + + # ------------------------------------------------------------------ + # Construction + # ------------------------------------------------------------------ + + def __init__( + self, + data: Mapping[str, DataValue] | pa.Table | pa.RecordBatch, + python_schema: SchemaLike | None = None, + meta_info: Mapping[str, DataValue] | None = None, + record_id: str | None = None, + data_context: str | contexts.DataContext | None = None, + config: Config | None = None, + ) -> None: + if isinstance(data, pa.RecordBatch): + data = pa.Table.from_batches([data]) + + if isinstance(data, pa.Table): + self._init_from_table(data, meta_info, data_context, record_id) + else: + self._init_from_dict( + data, python_schema, meta_info, data_context, record_id + ) + + def _init_from_dict( + self, + data: Mapping[str, DataValue], + python_schema: SchemaLike | None, + meta_info: Mapping[str, DataValue] | None, + data_context: str | contexts.DataContext | None, + record_id: str | None, + ) -> None: + data_columns: dict[str, DataValue] = {} + meta_columns: dict[str, DataValue] = {} + extracted_context = None + + for k, v in data.items(): + if k == constants.CONTEXT_KEY: + if data_context is None: + extracted_context = cast(str, v) + elif k.startswith(constants.META_PREFIX): + meta_columns[k] = v + else: + data_columns[k] = v + + super().__init__(data_context=data_context or extracted_context) + self._datagram_id = record_id + + self._data_dict: dict[str, DataValue] | None = data_columns + self._data_table: pa.Table | None = None + + inferred = infer_python_schema_from_pylist_data( + [data_columns], default_type=str + ) + inferred = infer_python_schema_from_pylist_data( + [data_columns], default_type=str + ) + self._data_python_schema: Schema | None = ( + Schema({k: python_schema.get(k, v) for k, v in inferred.items()}) + if python_schema + else inferred + ) + self._data_arrow_schema: pa.Schema | None = None + + if meta_info is not None: + meta_columns.update(meta_info) + self._meta: dict[str, DataValue] = meta_columns + self._meta_python_schema: Schema = infer_python_schema_from_pylist_data( + [meta_columns], default_type=str + ) + self._meta_table: pa.Table | None = None + self._context_table: pa.Table | None = None + + def _init_from_table( + self, + table: pa.Table, + meta_info: Mapping[str, DataValue] | None, + data_context: str | contexts.DataContext | None, + record_id: str | None, + ) -> None: + if len(table) != 1: + raise ValueError( + "Table must contain exactly one row to be a valid datagram." + ) + + table = arrow_utils.normalize_table_to_large_types(table) + + # Extract context from table if not provided externally + if constants.CONTEXT_KEY in table.column_names and data_context is None: + data_context = table[constants.CONTEXT_KEY].to_pylist()[0] + + context_cols = [c for c in table.column_names if c == constants.CONTEXT_KEY] + + super().__init__(data_context=data_context) + self._datagram_id = record_id + + meta_col_names = [ + c for c in table.column_names if c.startswith(constants.META_PREFIX) + ] + self._data_table = table.drop_columns(context_cols + meta_col_names) + self._data_dict = None + self._data_python_schema = None # computed lazily + self._data_arrow_schema = None # computed lazily + + if len(self._data_table.column_names) == 0: + raise ValueError("Data table must contain at least one data column.") + + # Build meta table + meta_table: "pa.Table | None" = ( + table.select(meta_col_names) if meta_col_names else None + ) + if meta_info is not None: + normalized_meta = { + k + if k.startswith(constants.META_PREFIX) + else f"{constants.META_PREFIX}{k}": v + for k, v in meta_info.items() + } + new_meta = self.converter.python_dicts_to_arrow_table([normalized_meta]) + if meta_table is None: + meta_table = new_meta + else: + keep = [ + c for c in meta_table.column_names if c not in new_meta.column_names + ] + meta_table = arrow_utils.hstack_tables( + meta_table.select(keep), new_meta + ) + + # Store meta as dict (always); Arrow table is lazy. + # Derive schema via infer_python_schema_from_pylist_data (same as DictDatagram) + # to avoid typing.Any values that arrow_schema_to_python_schema may emit. + if meta_table is not None and meta_table.num_columns > 0: + self._meta = meta_table.to_pylist()[0] + self._meta_python_schema = infer_python_schema_from_pylist_data( + [self._meta], default_type=str + ) + else: + self._meta = {} + self._meta_python_schema = Schema.empty() + + self._meta_table = None # built lazily + self._context_table = None + + # ------------------------------------------------------------------ + # Internal helpers (lazy loading) + # ------------------------------------------------------------------ + + def _ensure_data_dict(self) -> dict[str, DataValue]: + """ + Ensure that dictionary representation is materialized and then returned + """ + if self._data_dict is None: + assert self._data_table is not None + self._data_dict = self.converter.arrow_table_to_python_dicts( + self._data_table + )[0] + return self._data_dict + + def _ensure_data_table(self) -> pa.Table: + """ + Ensure that Arrow table representation is materialized and then returned + """ + if self._data_table is None: + assert self._data_dict is not None + self._data_table = self.converter.python_dicts_to_arrow_table( + [self._data_dict], + self._data_python_schema, + ) + return self._data_table + + def _ensure_python_schema(self) -> Schema: + """ + Ensure that Python schema is materialized and then returned + """ + if self._data_python_schema is None: + assert self._data_table is not None + self._data_python_schema = self.converter.arrow_schema_to_python_schema( + self._data_table.schema + ) + return self._data_python_schema + + def _ensure_arrow_schema(self) -> pa.Schema: + """ + Ensure that Arrow schema is materialized and then returned + """ + if self._data_arrow_schema is None: + if self._data_table is not None: + self._data_arrow_schema = self._data_table.schema + else: + self._data_arrow_schema = self.converter.python_schema_to_arrow_schema( + self._ensure_python_schema() + ) + return self._data_arrow_schema + + def _ensure_context_table(self) -> pa.Table: + """ + Ensure context table is materialized and then returned (relevant for Arrow representation) + """ + if self._context_table is None: + import pyarrow as _pa + + schema = _pa.schema({constants.CONTEXT_KEY: _pa.large_string()}) + self._context_table = _pa.Table.from_pylist( + [{constants.CONTEXT_KEY: self._data_context.context_key}], + schema=schema, + ) + return self._context_table + + def _ensure_meta_table(self) -> pa.Table | None: + """ + Ensure meta table is materialized and returned (relevant for Arrow representation) + """ + if not self._meta: + return None + if self._meta_table is None: + self._meta_table = self.converter.python_dicts_to_arrow_table( + [self._meta], python_schema=self._meta_python_schema + ) + return self._meta_table + + # ------------------------------------------------------------------ + # 1. Core Properties + # ------------------------------------------------------------------ + + @property + def meta_columns(self) -> tuple[str, ...]: + return tuple(self._meta.keys()) + + # ------------------------------------------------------------------ + # 2. Dict-like Interface + # ------------------------------------------------------------------ + + def __getitem__(self, key: str) -> DataValue: + return self._ensure_data_dict()[key] + + def __contains__(self, key: str) -> bool: + if self._data_table is not None: + return key in self._data_table.column_names + assert self._data_dict is not None + return key in self._data_dict + + def __iter__(self) -> Iterator[str]: + if self._data_table is not None: + return iter(self._data_table.column_names) + assert self._data_dict is not None + return iter(self._data_dict) + + def get(self, key: str, default: DataValue = None) -> DataValue: + if key not in self: + return default + return self._ensure_data_dict()[key] + + # ------------------------------------------------------------------ + # 3. Structural Information + # ------------------------------------------------------------------ + + def keys( + self, + *, + columns: "ColumnConfig | dict[str, Any] | None" = None, + all_info: bool = False, + ) -> tuple[str, ...]: + column_config = ColumnConfig.handle_config(columns, all_info=all_info) + + if self._data_table is not None: + data_keys: list[str] = list(self._data_table.column_names) + else: + assert self._data_dict is not None + data_keys = list(self._data_dict.keys()) + + if column_config.context: + data_keys.append(constants.CONTEXT_KEY) + + if column_config.meta: + include_meta = column_config.meta + if include_meta is True: + data_keys.extend(self.meta_columns) + elif isinstance(include_meta, Collection): + data_keys.extend( + c + for c in self.meta_columns + if any(c.startswith(p) for p in include_meta) + ) + + return tuple(data_keys) + + def schema( + self, + *, + columns: "ColumnConfig | dict[str, Any] | None" = None, + all_info: bool = False, + ) -> Schema: + column_config = ColumnConfig.handle_config(columns, all_info=all_info) + result = dict(self._ensure_python_schema()) + + if column_config.context: + result[constants.CONTEXT_KEY] = str + + if column_config.meta and self._meta: + include_meta = column_config.meta + if include_meta is True: + result.update(self._meta_python_schema) + elif isinstance(include_meta, Collection): + result.update( + { + k: v + for k, v in self._meta_python_schema.items() + if any(k.startswith(p) for p in include_meta) + } + ) + + return Schema(result) + + def arrow_schema( + self, + *, + columns: "ColumnConfig | dict[str, Any] | None" = None, + all_info: bool = False, + ) -> "pa.Schema": + column_config = ColumnConfig.handle_config(columns, all_info=all_info) + all_schemas = [self._ensure_arrow_schema()] + + if column_config.context: + all_schemas.append(self._ensure_context_table().schema) + + if column_config.meta and self._meta: + meta_table = self._ensure_meta_table() + if meta_table is not None: + include_meta = column_config.meta + if include_meta is True: + all_schemas.append(meta_table.schema) + elif isinstance(include_meta, Collection): + import pyarrow as _pa + + matched = [ + f + for f in meta_table.schema + if any(f.name.startswith(p) for p in include_meta) + ] + if matched: + all_schemas.append(_pa.schema(matched)) + + return arrow_utils.join_arrow_schemas(*all_schemas) + + def identity_structure(self) -> Any: + """Return the primary data table as this datagram's identity. + + The semantic hasher dispatches ``pa.Table`` to ``ArrowTableHandler``, + which delegates to the data context's ``arrow_hasher``. This means + ``content_hash()`` (inherited from ``ContentIdentifiableBase``) produces + a stable, content-addressed hash of the data columns without any + special-casing in ``Datagram`` itself. + """ + return self._ensure_data_table() + + @property + def datagram_id(self) -> str: + """Return (or lazily generate) the datagram's unique ID.""" + if self._datagram_id is None: + self._datagram_id = str(uuid7()) + return self._datagram_id + + @property + def converter(self) -> TypeConverterProtocol: + """Semantic type converter for this datagram's data context.""" + return self.data_context.type_converter + + def with_context_key(self, new_context_key: str) -> Self: + """Create a new datagram with a different data-context key.""" + new_datagram = self.copy(include_cache=False) + new_datagram._data_context = contexts.resolve_context(new_context_key) + return new_datagram + + # ------------------------------------------------------------------ + # 4. Format Conversions + # ------------------------------------------------------------------ + + def as_dict( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> dict[str, DataValue]: + column_config = ColumnConfig.handle_config(columns, all_info=all_info) + result = dict(self._ensure_data_dict()) + + if column_config.context: + result[constants.CONTEXT_KEY] = self._data_context.context_key + + if column_config.meta and self._meta: + include_meta = column_config.meta + if include_meta is True: + result.update(self._meta) + elif isinstance(include_meta, Collection): + result.update( + { + k: v + for k, v in self._meta.items() + if any(k.startswith(p) for p in include_meta) + } + ) + + return result + + def as_table( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> pa.Table: + column_config = ColumnConfig.handle_config(columns, all_info=all_info) + all_tables = [self._ensure_data_table()] + + if column_config.context: + all_tables.append(self._ensure_context_table()) + + if column_config.meta and self._meta: + meta_table = self._ensure_meta_table() + if meta_table is not None: + include_meta = column_config.meta + if include_meta is True: + all_tables.append(meta_table) + elif isinstance(include_meta, Collection): + # Normalize: ensure all given prefixes start with META_PREFIX + prefixes = [ + p + if p.startswith(constants.META_PREFIX) + else f"{constants.META_PREFIX}{p}" + for p in include_meta + ] + matched_cols = [ + c + for c in meta_table.column_names + if any(c.startswith(p) for p in prefixes) + ] + if matched_cols: + all_tables.append(meta_table.select(matched_cols)) + + return arrow_utils.hstack_tables(*all_tables) + + def as_arrow_compatible_dict( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> dict[str, DataValue]: + return self.as_table(columns=columns, all_info=all_info).to_pylist()[0] + + # ------------------------------------------------------------------ + # 5. Meta Column Operations + # ------------------------------------------------------------------ + + def get_meta_value(self, key: str, default: DataValue = None) -> DataValue: + if not key.startswith(constants.META_PREFIX): + key = constants.META_PREFIX + key + return self._meta.get(key, default) + + def get_meta_info(self) -> dict[str, DataValue]: + return dict(self._meta) + + def with_meta_columns(self, **meta_updates: DataValue) -> Self: + prefixed = { + k + if k.startswith(constants.META_PREFIX) + else f"{constants.META_PREFIX}{k}": v + for k, v in meta_updates.items() + } + new_d = self.copy(include_cache=False) + new_d._meta = {**self._meta, **prefixed} + new_d._meta_python_schema = infer_python_schema_from_pylist_data( + [new_d._meta], default_type=str + ) + return new_d + + def drop_meta_columns(self, *keys: str, ignore_missing: bool = False) -> Self: + prefixed = { + k if k.startswith(constants.META_PREFIX) else f"{constants.META_PREFIX}{k}" + for k in keys + } + missing = prefixed - set(self._meta.keys()) + if missing and not ignore_missing: + raise KeyError( + f"Following meta columns do not exist and cannot be dropped: {sorted(missing)}" + ) + new_d = self.copy(include_cache=False) + new_d._meta = {k: v for k, v in self._meta.items() if k not in prefixed} + new_d._meta_python_schema = infer_python_schema_from_pylist_data( + [new_d._meta], default_type=str + ) + return new_d + + # ------------------------------------------------------------------ + # 6. Data Column Operations (prefer Arrow when loaded) + # ------------------------------------------------------------------ + + def select(self, *column_names: str) -> Self: + if self._data_table is not None: + missing = set(column_names) - set(self._data_table.column_names) + if missing: + raise ValueError(f"Columns not found: {missing}") + new_d = self.copy(include_cache=False) + new_d._data_table = self._data_table.select(list(column_names)) + new_d._data_dict = None + new_d._data_python_schema = None + new_d._data_arrow_schema = None + return new_d + else: + assert self._data_dict is not None + missing = set(column_names) - set(self._data_dict.keys()) + if missing: + raise ValueError(f"Columns not found: {missing}") + schema = self._ensure_python_schema() + new_d = self.copy(include_cache=False) + new_d._data_dict = { + k: v for k, v in self._data_dict.items() if k in column_names + } + new_d._data_python_schema = Schema( + {k: v for k, v in schema.items() if k in column_names} + ) + return new_d + + def drop(self, *column_names: str, ignore_missing: bool = False) -> Self: + if self._data_table is not None: + missing = set(column_names) - set(self._data_table.column_names) + if missing and not ignore_missing: + raise KeyError( + f"Following columns do not exist and cannot be dropped: {sorted(missing)}" + ) + existing = [c for c in column_names if c in self._data_table.column_names] + new_d = self.copy(include_cache=False) + if existing: + new_d._data_table = self._data_table.drop_columns(existing) + new_d._data_dict = None + new_d._data_python_schema = None + new_d._data_arrow_schema = None + return new_d + else: + assert self._data_dict is not None + missing = set(column_names) - set(self._data_dict.keys()) + if missing and not ignore_missing: + raise KeyError( + f"Following columns do not exist and cannot be dropped: {sorted(missing)}" + ) + new_data = { + k: v for k, v in self._data_dict.items() if k not in column_names + } + if not new_data: + raise ValueError("Cannot drop all data columns") + schema = self._ensure_python_schema() + new_d = self.copy(include_cache=False) + new_d._data_dict = new_data + new_d._data_python_schema = Schema( + {k: v for k, v in schema.items() if k in new_data} + ) + return new_d + + def rename(self, column_mapping: Mapping[str, str]) -> Self: + if not column_mapping: + return self + if self._data_table is not None: + new_names = [ + column_mapping.get(k, k) for k in self._data_table.column_names + ] + new_d = self.copy(include_cache=False) + new_d._data_table = self._data_table.rename_columns(new_names) + new_d._data_dict = None + new_d._data_python_schema = None + new_d._data_arrow_schema = None + return new_d + else: + assert self._data_dict is not None + schema = self._ensure_python_schema() + new_d = self.copy(include_cache=False) + new_d._data_dict = { + column_mapping.get(k, k): v for k, v in self._data_dict.items() + } + new_d._data_python_schema = Schema( + {column_mapping.get(k, k): v for k, v in schema.items()} + ) + return new_d + + def update(self, **updates: DataValue) -> Self: + if not updates: + return self + + data_keys = ( + set(self._data_table.column_names) + if self._data_table is not None + else set(self._data_dict.keys()) # type: ignore[union-attr] + ) + missing = set(updates.keys()) - data_keys + if missing: + raise KeyError( + f"Only existing columns can be updated. " + f"Following columns were not found: {sorted(missing)}" + ) + + if self._data_table is not None and self._data_dict is None: + # Arrow-native update: preserves type precision without loading full dict + sub_schema = arrow_utils.schema_select( + self._data_table.schema, list(updates.keys()) + ) + update_table = self.converter.python_dicts_to_arrow_table( + [updates], arrow_schema=sub_schema + ) + new_d = self.copy(include_cache=False) + new_d._data_table = arrow_utils.hstack_tables( + self._data_table.drop_columns(list(updates.keys())), update_table + ).select(self._data_table.column_names) + new_d._data_dict = None + new_d._data_python_schema = None + new_d._data_arrow_schema = None + return new_d + else: + assert self._data_dict is not None + new_d = self.copy(include_cache=False) + new_d._data_dict = {**self._data_dict, **updates} + new_d._data_table = None + return new_d + + def with_columns( + self, + column_types: "Mapping[str, type] | None" = None, + **updates: DataValue, + ) -> Self: + if not updates: + return self + + data_keys = ( + set(self._data_table.column_names) + if self._data_table is not None + else set(self._data_dict.keys()) # type: ignore[union-attr] + ) + existing_overlaps = set(updates.keys()) & data_keys + if existing_overlaps: + raise ValueError( + f"Columns already exist: {sorted(existing_overlaps)}. " + f"Use update() to modify existing columns." + ) + + if self._data_table is not None and self._data_dict is None: + new_data_table = self.converter.python_dicts_to_arrow_table( + [updates], + python_schema=dict(column_types) if column_types else None, + ) + new_d = self.copy(include_cache=False) + new_d._data_table = arrow_utils.hstack_tables( + self._data_table, new_data_table + ) + new_d._data_python_schema = None + new_d._data_arrow_schema = None + return new_d + else: + assert self._data_dict is not None + new_data = {**self._data_dict, **updates} + schema = dict(self._ensure_python_schema()) + if column_types: + schema.update(column_types) + inferred = infer_python_schema_from_pylist_data([new_data]) + new_schema = Schema( + {k: schema.get(k, inferred.get(k, str)) for k in new_data} + ) + new_d = self.copy(include_cache=False) + new_d._data_dict = new_data + new_d._data_python_schema = new_schema + new_d._data_table = None + return new_d + + # ------------------------------------------------------------------ + # 8. Utility Operations + # ------------------------------------------------------------------ + + def copy(self, include_cache: bool = True, preserve_id: bool = True) -> Self: + new_d = object.__new__(self.__class__) + + # Fields from ContentIdentifiableBase / DataContextMixin + new_d._data_context = self._data_context + new_d._orcapod_config = self._orcapod_config + new_d._content_hash_cache = ( + dict(self._content_hash_cache) if include_cache else {} + ) + new_d._cached_int_hash = None + + # Datagram identity + new_d._datagram_id = self._datagram_id if preserve_id else None + + # Data representations — Arrow table is immutable so a ref copy is fine + new_d._data_table = self._data_table + new_d._data_dict = ( + dict(self._data_dict) if self._data_dict is not None else None + ) + new_d._data_python_schema = Schema( + dict(self._data_python_schema) + if self._data_python_schema is not None + else None + ) + new_d._data_arrow_schema = self._data_arrow_schema + + # Meta — always dict + new_d._meta = dict(self._meta) + new_d._meta_python_schema = Schema(self._meta_python_schema) + + if include_cache: + new_d._meta_table = self._meta_table + new_d._context_table = self._context_table + else: + new_d._meta_table = None + new_d._context_table = None + + return new_d + + # ------------------------------------------------------------------ + # 9. String Representations + # ------------------------------------------------------------------ + + def __str__(self) -> str: + if self._data_dict is not None: + return str(self._data_dict) + return str(self.as_dict()) + + def __repr__(self) -> str: + return self.__str__() diff --git a/src/orcapod/core/datagrams/dict_datagram.py b/src/orcapod/core/datagrams/dict_datagram.py deleted file mode 100644 index 642a5b26..00000000 --- a/src/orcapod/core/datagrams/dict_datagram.py +++ /dev/null @@ -1,836 +0,0 @@ -import logging -from collections.abc import Collection, Iterator, Mapping -from typing import Self, cast, TYPE_CHECKING - -from orcapod.utils.lazy_module import LazyModule -from orcapod.core.system_constants import constants -from orcapod import contexts -from orcapod.core.datagrams.base import BaseDatagram -from orcapod.semantic_types import infer_python_schema_from_pylist_data -from orcapod.types import DataValue, PythonSchema, PythonSchemaLike -from orcapod.utils import arrow_utils -from orcapod.protocols.hashing_protocols import ContentHash - -logger = logging.getLogger(__name__) - -# FIXME: make this configurable! -DEBUG = False - -if TYPE_CHECKING: - import pyarrow as pa -else: - pa = LazyModule("pyarrow") - - -class DictDatagram(BaseDatagram): - """ - Immutable datagram implementation using dictionary as storage backend. - - This implementation uses composition (not inheritance from Mapping) to maintain - control over the interface while leveraging dictionary efficiency for data access. - Provides clean separation between data, meta, and context components. - - The underlying data is split into separate components: - - Data dict: Primary business data columns - - Meta dict: Internal system metadata with {orcapod.META_PREFIX} ('__') prefixes - - Context: Data context information with {orcapod.CONTEXT_KEY} - - Future Packet subclass will also handle: - - Source info: Data provenance with {orcapod.SOURCE_PREFIX} ('_source_') prefixes - - When exposing to external tools, semantic types are encoded as - `_{semantic_type}_` prefixes (_path_config_file, _id_user_name). - - All operations return new instances, preserving immutability. - - Example: - >>> data = {{ - ... "user_id": 123, - ... "name": "Alice", - ... "__pipeline_version": "v2.1.0", - ... "{orcapod.CONTEXT_KEY}": "financial_v1" - ... }} - >>> datagram = DictDatagram(data) - >>> updated = datagram.update(name="Alice Smith") - """ - - def __init__( - self, - data: Mapping[str, DataValue], - python_schema: PythonSchemaLike | None = None, - meta_info: Mapping[str, DataValue] | None = None, - data_context: str | contexts.DataContext | None = None, - ) -> None: - """ - Initialize DictDatagram from dictionary data. - - Args: - data: Source data mapping containing all column data. - typespec: Optional type specification for fields. - semantic_converter: Optional converter for semantic type handling. - If None, will be created based on data context and inferred types. - data_context: Data context for semantic type resolution. - If None and data contains context column, will extract from data. - - Note: - The input data is automatically split into data, meta, and context - components based on column naming conventions. - """ - # Parse through data and extract different column types - data_columns = {} - meta_columns = {} - extracted_context = None - - for k, v in data.items(): - if k == constants.CONTEXT_KEY: - # Extract data context but keep it separate from meta data - if data_context is None: - extracted_context = v - # Don't store context in meta_data - it's managed separately - elif k.startswith(constants.META_PREFIX): - # Double underscore = meta metadata - meta_columns[k] = v - else: - # Everything else = user data (including _source_ and semantic types) - data_columns[k] = v - - # Initialize base class with data context - final_context = data_context or cast(str, extracted_context) - super().__init__(final_context) - - # Store data and meta components separately (immutable) - self._data = dict(data_columns) - if meta_info is not None: - meta_columns.update(meta_info) - self._meta_data = meta_columns - - # Combine provided typespec info with inferred typespec from content - # If the column value is None and no type spec is provided, defaults to str. - inferred_schema = infer_python_schema_from_pylist_data( - [self._data], default_type=str - ) - - self._data_python_schema = ( - {k: python_schema.get(k, v) for k, v in inferred_schema.items()} - if python_schema - else inferred_schema - ) - - # Create schema for meta data - inferred_meta_schema = infer_python_schema_from_pylist_data( - [self._meta_data], default_type=str - ) - self._meta_python_schema = ( - {k: python_schema.get(k, v) for k, v in inferred_meta_schema.items()} - if python_schema - else inferred_meta_schema - ) - - # Initialize caches - self._cached_data_table: pa.Table | None = None - self._cached_meta_table: pa.Table | None = None - self._cached_content_hash: ContentHash | None = None - self._cached_data_arrow_schema: pa.Schema | None = None - self._cached_meta_arrow_schema: pa.Schema | None = None - - def _get_total_dict(self) -> dict[str, DataValue]: - """ - Return the total dictionary representation including meta and context. - - This is used for content hashing and exporting to Arrow. - """ - total_dict = dict(self._data) - total_dict.update(self._meta_data) - total_dict[constants.CONTEXT_KEY] = self._data_context - return total_dict - - # 1. Core Properties (Identity & Structure) - @property - def meta_columns(self) -> tuple[str, ...]: - """Return tuple of meta column names.""" - return tuple(self._meta_data.keys()) - - def get_meta_info(self) -> dict[str, DataValue]: - """ - Get meta column information. - - Returns: - Dictionary of meta column names and their values. - """ - return dict(self._meta_data) - - # 2. Dict-like Interface (Data Access) - def __getitem__(self, key: str) -> DataValue: - """Get data column value by key.""" - if key not in self._data: - raise KeyError(f"Data column '{key}' not found") - return self._data[key] - - def __contains__(self, key: str) -> bool: - """Check if data column exists.""" - return key in self._data - - def __iter__(self) -> Iterator[str]: - """Iterate over data column names.""" - return iter(self._data) - - def get(self, key: str, default: DataValue = None) -> DataValue: - """Get data column value with default.""" - return self._data.get(key, default) - - # 3. Structural Information - def keys( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - ) -> tuple[str, ...]: - """Return tuple of column names.""" - include_meta_columns = include_all_info or include_meta_columns - include_context = include_all_info or include_context - # Start with data columns - result_keys = list(self._data.keys()) - - # Add context if requested - if include_context: - result_keys.append(constants.CONTEXT_KEY) - - # Add meta columns if requested - if include_meta_columns: - if include_meta_columns is True: - result_keys.extend(self.meta_columns) - elif isinstance(include_meta_columns, Collection): - # Filter meta columns by prefix matching - filtered_meta_cols = [ - col - for col in self.meta_columns - if any(col.startswith(prefix) for prefix in include_meta_columns) - ] - result_keys.extend(filtered_meta_cols) - - return tuple(result_keys) - - def types( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - ) -> PythonSchema: - """ - Return Python schema for the datagram. - - Args: - include_meta_columns: Whether to include meta column types. - - True: include all meta column types - - Collection[str]: include meta column types matching these prefixes - - False: exclude meta column types - include_context: Whether to include context type - - Returns: - Python schema - """ - include_meta_columns = include_all_info or include_meta_columns - include_context = include_all_info or include_context - - # Start with data schema - schema = dict(self._data_python_schema) - - # Add context if requested - if include_context: - schema[constants.CONTEXT_KEY] = str - - # Add meta schema if requested - if include_meta_columns and self._meta_data: - if include_meta_columns is True: - schema.update(self._meta_python_schema) - elif isinstance(include_meta_columns, Collection): - filtered_meta_schema = { - k: v - for k, v in self._meta_python_schema.items() - if any(k.startswith(prefix) for prefix in include_meta_columns) - } - schema.update(filtered_meta_schema) - - return schema - - def arrow_schema( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - ) -> "pa.Schema": - """ - Return the PyArrow schema for this datagram. - - Args: - include_meta_columns: Whether to include meta columns in the schema. - - True: include all meta columns - - Collection[str]: include meta columns matching these prefixes - - False: exclude meta columns - include_context: Whether to include context column in the schema - - Returns: - PyArrow schema representing the datagram's structure - """ - include_meta_columns = include_all_info or include_meta_columns - include_context = include_all_info or include_context - - # Build data schema (cached) - if self._cached_data_arrow_schema is None: - self._cached_data_arrow_schema = ( - self._data_context.type_converter.python_schema_to_arrow_schema( - self._data_python_schema - ) - ) - - all_schemas = [self._cached_data_arrow_schema] - - # Add context schema if requested - if include_context: - context_schema = self._converter.python_schema_to_arrow_schema( - {constants.CONTEXT_KEY: str} - ) - all_schemas.append(context_schema) - - # Add meta schema if requested - if include_meta_columns and self._meta_data: - if include_meta_columns is True: - meta_schema = self._get_meta_arrow_schema() - elif isinstance(include_meta_columns, Collection): - # Filter meta schema by prefix matching - meta_schema = ( - arrow_utils.select_schema_columns_with_prefixes( - self._get_meta_arrow_schema(), - include_meta_columns, - ) - or None - ) - else: - meta_schema = None - - if meta_schema is not None: - all_schemas.append(meta_schema) - - return arrow_utils.join_arrow_schemas(*all_schemas) - - def content_hash(self) -> ContentHash: - """ - Calculate and return content hash of the datagram. - Only includes data columns, not meta columns or context. - - Returns: - Hash string of the datagram content - """ - if self._cached_content_hash is None: - self._cached_content_hash = self._data_context.arrow_hasher.hash_table( - self.as_table(include_meta_columns=False, include_context=False), - ) - return self._cached_content_hash - - # 4. Format Conversions (Export) - def as_dict( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - ) -> dict[str, DataValue]: - """ - Return dictionary representation of the datagram. - - Args: - include_meta_columns: Whether to include meta columns. - - True: include all meta columns - - Collection[str]: include meta columns matching these prefixes - - False: exclude meta columns - include_context: Whether to include context key - - Returns: - Dictionary representation - """ - include_context = include_all_info or include_context - include_meta_columns = include_all_info or include_meta_columns - - result_dict = dict(self._data) # Start with user data - - # Add context if requested - if include_context: - result_dict[constants.CONTEXT_KEY] = self._data_context.context_key - - # Add meta columns if requested - if include_meta_columns and self._meta_data: - if include_meta_columns is True: - # Include all meta columns - result_dict.update(self._meta_data) - elif isinstance(include_meta_columns, Collection): - # Include only meta columns matching prefixes - filtered_meta_data = { - k: v - for k, v in self._meta_data.items() - if any(k.startswith(prefix) for prefix in include_meta_columns) - } - result_dict.update(filtered_meta_data) - - return result_dict - - def as_arrow_compatible_dict( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - ) -> dict[str, DataValue]: - """ - Return dictionary representation compatible with Arrow. - - Args: - include_meta_columns: Whether to include meta columns. - - True: include all meta columns - - Collection[str]: include meta columns matching these prefixes - - False: exclude meta columns - include_context: Whether to include context key - - Returns: - Dictionary representation compatible with Arrow - """ - # FIXME: this is a super inefficient implementation! - python_dict = self.as_dict( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_context=include_context, - ) - python_schema = self.types( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_context=include_context, - ) - - return self._data_context.type_converter.python_dicts_to_struct_dicts( - [python_dict], python_schema=python_schema - )[0] - - def _get_meta_arrow_table(self) -> "pa.Table": - if self._cached_meta_table is None: - arrow_schema = self._get_meta_arrow_schema() - self._cached_meta_table = pa.Table.from_pylist( - [self._meta_data], - schema=arrow_schema, - ) - assert self._cached_meta_table is not None, ( - "Meta Arrow table should be initialized by now" - ) - return self._cached_meta_table - - def _get_meta_arrow_schema(self) -> "pa.Schema": - if self._cached_meta_arrow_schema is None: - self._cached_meta_arrow_schema = ( - self._data_context.type_converter.python_schema_to_arrow_schema( - self._meta_python_schema - ) - ) - - assert self._cached_meta_arrow_schema is not None, ( - "Meta Arrow schema should be initialized by now" - ) - return self._cached_meta_arrow_schema - - def as_table( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - ) -> "pa.Table": - """ - Convert the datagram to an Arrow table. - - Args: - include_meta_columns: Whether to include meta columns. - - True: include all meta columns - - Collection[str]: include meta columns matching these prefixes - - False: exclude meta columns - include_context: Whether to include the context column - - Returns: - Arrow table representation - """ - include_context = include_all_info or include_context - include_meta_columns = include_all_info or include_meta_columns - - # Build data table (cached) - if self._cached_data_table is None: - self._cached_data_table = ( - self._data_context.type_converter.python_dicts_to_arrow_table( - [self._data], - self._data_python_schema, - ) - ) - assert self._cached_data_table is not None, ( - "Data Arrow table should be initialized by now" - ) - result_table = self._cached_data_table - - # Add context if requested - # TODO: consider using type converter for consistency - if include_context: - result_table = result_table.append_column( - constants.CONTEXT_KEY, - pa.array([self._data_context.context_key], type=pa.large_string()), - ) - - # Add meta columns if requested - meta_table = None - if include_meta_columns and self._meta_data: - meta_table = self._get_meta_arrow_table() - # Select appropriate meta columns - if isinstance(include_meta_columns, Collection): - # Filter meta columns by prefix matching - meta_table = arrow_utils.select_table_columns_with_prefixes( - meta_table, include_meta_columns - ) - - # Combine tables if we have meta columns to add - if meta_table: - result_table = arrow_utils.hstack_tables(result_table, meta_table) - - return result_table - - # 5. Meta Column Operations - def get_meta_value(self, key: str, default: DataValue = None) -> DataValue: - """ - Get meta column value with optional default. - - Args: - key: Meta column key (with or without {orcapod.META_PREFIX} ('__') prefix). - default: Value to return if meta column doesn't exist. - - Returns: - Meta column value if exists, otherwise the default value. - """ - # Handle both prefixed and unprefixed keys - if not key.startswith(constants.META_PREFIX): - key = constants.META_PREFIX + key - - return self._meta_data.get(key, default) - - def with_meta_columns(self, **meta_updates: DataValue) -> Self: - """ - Create a new DictDatagram with updated meta columns. - Maintains immutability by returning a new instance. - - Args: - **meta_updates: Meta column updates (keys will be prefixed with {orcapod.META_PREFIX} ('__') if needed) - - Returns: - New DictDatagram instance - """ - # Prefix the keys and prepare updates - prefixed_updates = {} - for k, v in meta_updates.items(): - if not k.startswith(constants.META_PREFIX): - k = constants.META_PREFIX + k - prefixed_updates[k] = v - - # Start with existing meta data - new_meta_data = dict(self._meta_data) - new_meta_data.update(prefixed_updates) - - # Reconstruct full data dict for new instance - full_data = dict(self._data) # User data - full_data.update(new_meta_data) # Meta data - - return self.__class__( - data=full_data, - data_context=self._data_context, - ) - - def drop_meta_columns(self, *keys: str, ignore_missing: bool = False) -> Self: - """ - Create a new DictDatagram with specified meta columns dropped. - Maintains immutability by returning a new instance. - - Args: - *keys: Meta column keys to drop (with or without {orcapod.META_PREFIX} ('__') prefix) - ignore_missing: If True, ignore missing meta columns without raising an error. - - Raises: - KeyError: If any specified meta column to drop doesn't exist and ignore_missing=False. - - Returns: - New DictDatagram instance without specified meta columns - """ - # Normalize keys to have prefixes - prefixed_keys = set() - for key in keys: - if not key.startswith(constants.META_PREFIX): - key = constants.META_PREFIX + key - prefixed_keys.add(key) - - missing_keys = prefixed_keys - set(self._meta_data.keys()) - if missing_keys and not ignore_missing: - raise KeyError( - f"Following meta columns do not exist and cannot be dropped: {sorted(missing_keys)}" - ) - - # Filter out specified meta columns - new_meta_data = { - k: v for k, v in self._meta_data.items() if k not in prefixed_keys - } - - # Reconstruct full data dict for new instance - full_data = dict(self._data) # User data - full_data.update(new_meta_data) # Filtered meta data - - return self.__class__( - data=full_data, - data_context=self._data_context, - ) - - # 6. Data Column Operations - def select(self, *column_names: str) -> Self: - """ - Create a new DictDatagram with only specified data columns. - Maintains immutability by returning a new instance. - - Args: - *column_names: Data column names to keep - - Returns: - New DictDatagram instance with only specified data columns - """ - # Validate columns exist - missing_cols = set(column_names) - set(self._data.keys()) - if missing_cols: - raise KeyError(f"Columns not found: {missing_cols}") - - # Keep only specified data columns - new_data = {k: v for k, v in self._data.items() if k in column_names} - - # Reconstruct full data dict for new instance - full_data = new_data # Selected user data - full_data.update(self._meta_data) # Keep existing meta data - - return self.__class__( - data=full_data, - data_context=self._data_context, - ) - - def drop(self, *column_names: str, ignore_missing: bool = False) -> Self: - """ - Create a new DictDatagram with specified data columns dropped. - Maintains immutability by returning a new instance. - - Args: - *column_names: Data column names to drop - - Returns: - New DictDatagram instance without specified data columns - """ - # Filter out specified data columns - missing = set(column_names) - set(self._data.keys()) - if missing and not ignore_missing: - raise KeyError( - f"Following columns do not exist and cannot be dropped: {sorted(missing)}" - ) - - new_data = {k: v for k, v in self._data.items() if k not in column_names} - - if not new_data: - raise ValueError("Cannot drop all data columns") - - new_datagram = self.copy(include_cache=False) - new_datagram._data = new_data - return new_datagram - - def rename(self, column_mapping: Mapping[str, str]) -> Self: - """ - Create a new DictDatagram with data columns renamed. - Maintains immutability by returning a new instance. - - Args: - column_mapping: Mapping from old column names to new column names - - Returns: - New DictDatagram instance with renamed data columns - """ - # Rename data columns according to mapping, preserving original types - new_data = {} - for old_name, value in self._data.items(): - new_name = column_mapping.get(old_name, old_name) - new_data[new_name] = value - - # Handle python_schema updates for renamed columns - new_python_schema = None - if self._data_python_schema: - existing_python_schema = dict(self._data_python_schema) - - # Rename types according to column mapping - renamed_python_schema = {} - for old_name, old_type in existing_python_schema.items(): - new_name = column_mapping.get(old_name, old_name) - renamed_python_schema[new_name] = old_type - - new_python_schema = renamed_python_schema - - # Reconstruct full data dict for new instance - full_data = new_data # Renamed user data - full_data.update(self._meta_data) # Keep existing meta data - - return self.__class__( - data=full_data, - python_schema=new_python_schema, - data_context=self._data_context, - ) - - def update(self, **updates: DataValue) -> Self: - """ - Create a new DictDatagram with existing column values updated. - Maintains immutability by returning a new instance if any values are changed. - - Args: - **updates: Column names and their new values (columns must exist) - - Returns: - New DictDatagram instance with updated values - - Raises: - KeyError: If any column doesn't exist (use with_columns() to add new columns) - """ - if not updates: - return self - - # Error if any column doesn't exist - missing_columns = set(updates.keys()) - set(self._data.keys()) - if missing_columns: - raise KeyError( - f"Columns not found: {sorted(missing_columns)}. " - f"Use with_columns() to add new columns." - ) - - # Update existing columns - new_data = dict(self._data) - new_data.update(updates) - - new_datagram = self.copy(include_cache=False) - new_datagram._data = new_data - return new_datagram - - def with_columns( - self, - column_types: Mapping[str, type] | None = None, - **updates: DataValue, - ) -> Self: - """ - Create a new DictDatagram with new data columns added. - Maintains immutability by returning a new instance. - - Args: - column_updates: New data columns as a mapping - column_types: Optional type specifications for new columns - **kwargs: New data columns as keyword arguments - - Returns: - New DictDatagram instance with new data columns added - - Raises: - ValueError: If any column already exists (use update() instead) - """ - # Combine explicit updates with kwargs - - if not updates: - return self - - # Error if any column already exists - existing_overlaps = set(updates.keys()) & set(self._data.keys()) - if existing_overlaps: - raise ValueError( - f"Columns already exist: {sorted(existing_overlaps)}. " - f"Use update() to modify existing columns." - ) - - # Update user data with new columns - new_data = dict(self._data) - new_data.update(updates) - - # Create updated python schema - handle None values by defaulting to str - python_schema = self.types() - if column_types is not None: - python_schema.update(column_types) - - new_python_schema = infer_python_schema_from_pylist_data([new_data]) - new_python_schema = { - k: python_schema.get(k, v) for k, v in new_python_schema.items() - } - - new_datagram = self.copy(include_cache=False) - new_datagram._data = new_data - new_datagram._data_python_schema = new_python_schema - - return new_datagram - - # 8. Utility Operations - def copy(self, include_cache: bool = True) -> Self: - """ - Create a shallow copy of the datagram. - - Returns a new datagram instance with the same data and cached values. - This is more efficient than reconstructing from scratch when you need - an identical datagram instance. - - Returns: - New DictDatagram instance with copied data and caches. - """ - new_datagram = super().copy() - new_datagram._data = self._data.copy() - new_datagram._meta_data = self._meta_data.copy() - new_datagram._data_python_schema = self._data_python_schema.copy() - new_datagram._meta_python_schema = self._meta_python_schema.copy() - - if include_cache: - new_datagram._cached_data_table = self._cached_data_table - new_datagram._cached_meta_table = self._cached_meta_table - new_datagram._cached_content_hash = self._cached_content_hash - new_datagram._cached_data_arrow_schema = self._cached_data_arrow_schema - new_datagram._cached_meta_arrow_schema = self._cached_meta_arrow_schema - else: - new_datagram._cached_data_table = None - new_datagram._cached_meta_table = None - new_datagram._cached_content_hash = None - new_datagram._cached_data_arrow_schema = None - new_datagram._cached_meta_arrow_schema = None - - return new_datagram - - # 9. String Representations - def __str__(self) -> str: - """ - Return user-friendly string representation. - - Shows the datagram as a simple dictionary for user-facing output, - messages, and logging. Only includes data columns for clean output. - - Returns: - Dictionary-style string representation of data columns only. - """ - return str(self._data) - - def __repr__(self) -> str: - """ - Return detailed string representation for debugging. - - Shows the datagram type and comprehensive information including - data columns, meta columns count, and context for debugging purposes. - - Returns: - Detailed representation with type and metadata information. - """ - if DEBUG: - meta_count = len(self.meta_columns) - context_key = self.data_context_key - - return ( - f"{self.__class__.__name__}(" - f"data={self._data}, " - f"meta_columns={meta_count}, " - f"context='{context_key}'" - f")" - ) - else: - return str(self._data) diff --git a/src/orcapod/core/datagrams/dict_tag_packet.py b/src/orcapod/core/datagrams/dict_tag_packet.py deleted file mode 100644 index 11e6d66e..00000000 --- a/src/orcapod/core/datagrams/dict_tag_packet.py +++ /dev/null @@ -1,547 +0,0 @@ -import logging -from collections.abc import Collection, Mapping -from typing import Self, TYPE_CHECKING - - -from orcapod.core.system_constants import constants -from orcapod import contexts -from orcapod.core.datagrams.dict_datagram import DictDatagram -from orcapod.utils import arrow_utils -from orcapod.semantic_types import infer_python_schema_from_pylist_data -from orcapod.types import DataValue, PythonSchema, PythonSchemaLike -from orcapod.utils.lazy_module import LazyModule - -if TYPE_CHECKING: - import pyarrow as pa -else: - pa = LazyModule("pyarrow") - -logger = logging.getLogger(__name__) - - -class DictTag(DictDatagram): - """ - A simple tag implementation using Python dictionary. - - Represents a tag (metadata) as a dictionary that can be converted - to different representations like Arrow tables. - """ - - def __init__( - self, - data: Mapping[str, DataValue], - system_tags: Mapping[str, DataValue] | None = None, - meta_info: Mapping[str, DataValue] | None = None, - python_schema: dict[str, type] | None = None, - data_context: str | contexts.DataContext | None = None, - ) -> None: - """ - Initialize the tag with data. - - Args: - data: Dictionary containing tag data - """ - # normalize the data content and remove any source info keys - data_only = { - k: v - for k, v in data.items() - if not k.startswith(constants.SYSTEM_TAG_PREFIX) - } - extracted_system_tags = { - k: v for k, v in data.items() if k.startswith(constants.SYSTEM_TAG_PREFIX) - } - - super().__init__( - data_only, - python_schema=python_schema, - meta_info=meta_info, - data_context=data_context, - ) - - self._system_tags = {**extracted_system_tags, **(system_tags or {})} - self._system_tags_python_schema: PythonSchema = ( - infer_python_schema_from_pylist_data([self._system_tags]) - ) - self._cached_system_tags_table: pa.Table | None = None - self._cached_system_tags_schema: pa.Schema | None = None - - def _get_total_dict(self) -> dict[str, DataValue]: - """Return the total dictionary representation including system tags.""" - total_dict = super()._get_total_dict() - total_dict.update(self._system_tags) - return total_dict - - def as_table( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_system_tags: bool = False, - ) -> "pa.Table": - """Convert the packet to an Arrow table.""" - table = super().as_table( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_context=include_context, - ) - - if include_all_info or include_system_tags: - # Only create and stack system tags table if there are actually system tags - if self._system_tags: # Check if system tags dict is not empty - if self._cached_system_tags_table is None: - self._cached_system_tags_table = ( - self._data_context.type_converter.python_dicts_to_arrow_table( - [self._system_tags], - python_schema=self._system_tags_python_schema, - ) - ) - table = arrow_utils.hstack_tables(table, self._cached_system_tags_table) - return table - - def as_dict( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_system_tags: bool = False, - ) -> dict[str, DataValue]: - """ - Return dictionary representation. - - Args: - include_source: Whether to include source info fields - - Returns: - Dictionary representation of the packet - """ - dict_copy = super().as_dict( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_context=include_context, - ) - if include_all_info or include_system_tags: - dict_copy.update(self._system_tags) - return dict_copy - - def keys( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_system_tags: bool = False, - ) -> tuple[str, ...]: - """Return keys of the Python schema.""" - keys = super().keys( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_context=include_context, - ) - if include_all_info or include_system_tags: - keys += tuple(self._system_tags.keys()) - return keys - - def types( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_system_tags: bool = False, - ) -> PythonSchema: - """Return copy of the Python schema.""" - schema = super().types( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_context=include_context, - ) - if include_all_info or include_system_tags: - schema.update(self._system_tags_python_schema) - return schema - - def arrow_schema( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_system_tags: bool = False, - ) -> "pa.Schema": - """ - Return the PyArrow schema for this datagram. - - Args: - include_data_context: Whether to include data context column in the schema - include_source: Whether to include source info columns in the schema - - Returns: - PyArrow schema representing the datagram's structure - """ - schema = super().arrow_schema( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_context=include_context, - ) - if include_all_info or include_system_tags: - if self._cached_system_tags_schema is None: - self._cached_system_tags_schema = ( - self._data_context.type_converter.python_schema_to_arrow_schema( - self._system_tags_python_schema - ) - ) - return arrow_utils.join_arrow_schemas( - schema, self._cached_system_tags_schema - ) - return schema - - def as_datagram( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_system_tags: bool = False, - ) -> DictDatagram: - """ - Convert the packet to a DictDatagram. - - Args: - include_source: Whether to include source info fields - - Returns: - DictDatagram representation of the packet - """ - - data = self.as_dict( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_system_tags=include_system_tags, - ) - python_schema = self.types( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_system_tags=include_system_tags, - ) - return DictDatagram( - data, - python_schema=python_schema, - data_context=self._data_context, - ) - - def system_tags(self) -> dict[str, DataValue]: - """ - Return source information for all keys. - - Returns: - Dictionary mapping field names to their source info - """ - return dict(self._system_tags) - - def copy(self, include_cache: bool = True) -> Self: - """Return a shallow copy of the packet.""" - instance = super().copy(include_cache=include_cache) - instance._system_tags = self._system_tags.copy() - if include_cache: - instance._cached_system_tags_table = self._cached_system_tags_table - instance._cached_system_tags_schema = self._cached_system_tags_schema - - else: - instance._cached_system_tags_table = None - instance._cached_system_tags_schema = None - - return instance - - -class DictPacket(DictDatagram): - """ - Enhanced packet implementation with source information support. - - Extends DictDatagram to include source information tracking and - enhanced table conversion capabilities that can include or exclude - source metadata. - - Initialize packet with data and optional source information. - - Args: - data: Primary data content - source_info: Optional mapping of field names to source information - typespec: Optional type specification - semantic_converter: Optional semantic converter - semantic_type_registry: Registry for semantic types. Defaults to system default registry. - arrow_hasher: Optional Arrow hasher. Defaults to system default arrow hasher. - """ - - def __init__( - self, - data: Mapping[str, DataValue], - meta_info: Mapping[str, DataValue] | None = None, - source_info: Mapping[str, str | None] | None = None, - python_schema: PythonSchemaLike | None = None, - data_context: str | contexts.DataContext | None = None, - ) -> None: - # normalize the data content and remove any source info keys - data_only = { - k: v for k, v in data.items() if not k.startswith(constants.SOURCE_PREFIX) - } - contained_source_info = { - k.removeprefix(constants.SOURCE_PREFIX): v - for k, v in data.items() - if k.startswith(constants.SOURCE_PREFIX) - } - - super().__init__( - data_only, - python_schema=python_schema, - meta_info=meta_info, - data_context=data_context, - ) - - self._source_info = {**contained_source_info, **(source_info or {})} - self._cached_source_info_table: pa.Table | None = None - self._cached_source_info_schema: pa.Schema | None = None - - @property - def _source_info_arrow_schema(self) -> "pa.Schema": - if self._cached_source_info_schema is None: - self._cached_source_info_schema = ( - self._converter.python_schema_to_arrow_schema( - self._source_info_python_schema - ) - ) - - return self._cached_source_info_schema - - @property - def _source_info_python_schema(self) -> dict[str, type]: - """Return the Python schema for source info.""" - return {f"{constants.SOURCE_PREFIX}{k}": str for k in self.keys()} - - def as_table( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_source: bool = False, - ) -> "pa.Table": - """Convert the packet to an Arrow table.""" - table = super().as_table( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_context=include_context, - ) - if include_all_info or include_source: - if self._cached_source_info_table is None: - source_info_data = { - f"{constants.SOURCE_PREFIX}{k}": v - for k, v in self.source_info().items() - } - self._cached_source_info_table = pa.Table.from_pylist( - [source_info_data], schema=self._source_info_arrow_schema - ) - assert self._cached_source_info_table is not None, ( - "Cached source info table should not be None" - ) - # subselect the corresponding _source_info as the columns present in the data table - source_info_table = self._cached_source_info_table.select( - [ - f"{constants.SOURCE_PREFIX}{k}" - for k in table.column_names - if k in self.keys() - ] - ) - table = arrow_utils.hstack_tables(table, source_info_table) - return table - - def as_dict( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_source: bool = False, - ) -> dict[str, DataValue]: - """ - Return dictionary representation. - - Args: - include_source: Whether to include source info fields - - Returns: - Dictionary representation of the packet - """ - dict_copy = super().as_dict( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_context=include_context, - ) - if include_all_info or include_source: - for key, value in self.source_info().items(): - dict_copy[f"{constants.SOURCE_PREFIX}{key}"] = value - return dict_copy - - def keys( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_source: bool = False, - ) -> tuple[str, ...]: - """Return keys of the Python schema.""" - keys = super().keys( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_context=include_context, - ) - if include_all_info or include_source: - keys += tuple(f"{constants.SOURCE_PREFIX}{key}" for key in super().keys()) - return keys - - def types( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_source: bool = False, - ) -> PythonSchema: - """Return copy of the Python schema.""" - schema = super().types( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_context=include_context, - ) - if include_all_info or include_source: - for key in self.keys(): - schema[f"{constants.SOURCE_PREFIX}{key}"] = str - return schema - - def rename(self, column_mapping: Mapping[str, str]) -> Self: - """ - Create a new DictDatagram with data columns renamed. - Maintains immutability by returning a new instance. - - Args: - column_mapping: Mapping from old column names to new column names - - Returns: - New DictDatagram instance with renamed data columns - """ - # Rename data columns according to mapping, preserving original types - - new_data = {column_mapping.get(k, k): v for k, v in self._data.items()} - - new_source_info = { - column_mapping.get(k, k): v for k, v in self._source_info.items() - } - - # Handle python_schema updates for renamed columns - new_python_schema = { - column_mapping.get(k, k): v for k, v in self._data_python_schema.items() - } - - return self.__class__( - data=new_data, - meta_info=self._meta_data, - source_info=new_source_info, - python_schema=new_python_schema, - data_context=self._data_context, - ) - - def arrow_schema( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_source: bool = False, - ) -> "pa.Schema": - """ - Return the PyArrow schema for this datagram. - - Args: - include_data_context: Whether to include data context column in the schema - include_source: Whether to include source info columns in the schema - - Returns: - PyArrow schema representing the datagram's structure - """ - schema = super().arrow_schema( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_context=include_context, - ) - if include_all_info or include_source: - return arrow_utils.join_arrow_schemas( - schema, self._source_info_arrow_schema - ) - return schema - - def as_datagram( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_source: bool = False, - ) -> DictDatagram: - """ - Convert the packet to a DictDatagram. - - Args: - include_source: Whether to include source info fields - - Returns: - DictDatagram representation of the packet - """ - - data = self.as_dict( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_source=include_source, - ) - python_schema = self.types( - include_all_info=include_all_info, - include_meta_columns=include_meta_columns, - include_source=include_source, - ) - return DictDatagram( - data, - python_schema=python_schema, - data_context=self._data_context, - ) - - def source_info(self) -> dict[str, str | None]: - """ - Return source information for all keys. - - Returns: - Dictionary mapping field names to their source info - """ - return {key: self._source_info.get(key, None) for key in self.keys()} - - def with_source_info(self, **source_info: str | None) -> Self: - """ - Create a new packet with updated source information. - - Args: - **kwargs: Key-value pairs to update source information - - Returns: - New DictPacket instance with updated source info - """ - current_source_info = self._source_info.copy() - - for key, value in source_info.items(): - # Remove prefix if it exists, since _source_info stores unprefixed keys - if key.startswith(constants.SOURCE_PREFIX): - key = key.removeprefix(constants.SOURCE_PREFIX) - current_source_info[key] = value - - new_packet = self.copy(include_cache=False) - new_packet._source_info = current_source_info - - return new_packet - - def copy(self, include_cache: bool = True) -> Self: - """Return a shallow copy of the packet.""" - instance = super().copy(include_cache=include_cache) - instance._source_info = self._source_info.copy() - if include_cache: - instance._cached_source_info_table = self._cached_source_info_table - instance._cached_source_info_schema = self._cached_source_info_schema - - else: - instance._cached_source_info_table = None - instance._cached_source_info_schema = None - - return instance diff --git a/src/orcapod/core/datagrams/tag_packet.py b/src/orcapod/core/datagrams/tag_packet.py new file mode 100644 index 00000000..eade8c4d --- /dev/null +++ b/src/orcapod/core/datagrams/tag_packet.py @@ -0,0 +1,474 @@ +""" +Tag and Packet — datagram subclasses with system-tags and source-info support. + +``Tag`` + Extends ``Datagram`` with *system tags*: metadata fields whose names start with + ``constants.SYSTEM_TAG_PREFIX``. System tags travel alongside the primary data + but are excluded from content hashing and structural operations unless explicitly + requested via ``ColumnConfig(system_tags=True)``. + +``Packet`` + Extends ``Datagram`` with *source information*: provenance tokens (strings or None) + keyed by data-column name. Source-info keys are stored without the + ``constants.SOURCE_PREFIX`` internally and added back when serialising via + ``as_dict()`` / ``as_table()``. +""" + +from __future__ import annotations + +import logging +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, Self + +from orcapod import contexts +from orcapod.core.datagrams.datagram import Datagram +from orcapod.semantic_types import infer_python_schema_from_pylist_data +from orcapod.system_constants import constants +from orcapod.types import ColumnConfig, DataValue, Schema, SchemaLike +from orcapod.utils import arrow_utils +from orcapod.utils.lazy_module import LazyModule + +if TYPE_CHECKING: + import pyarrow as pa +else: + pa = LazyModule("pyarrow") + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Tag +# --------------------------------------------------------------------------- + + +class Tag(Datagram): + """ + Datagram with system-tags support. + + System tags are metadata fields whose names begin with + ``constants.SYSTEM_TAG_PREFIX``. They are excluded from the primary data + representation (and therefore from content hashing) unless the caller requests + them via ``ColumnConfig(system_tags=True)``. + + Accepts the same inputs as ``Datagram`` (dict or Arrow table/batch). + System-tag fields found in the input are automatically extracted. + """ + + def __init__( + self, + data: "Mapping[str, DataValue] | pa.Table | pa.RecordBatch", + system_tags: "Mapping[str, DataValue] | None" = None, + meta_info: "Mapping[str, DataValue] | None" = None, + python_schema: "SchemaLike | None" = None, + data_context: "str | contexts.DataContext | None" = None, + record_id: "str | None" = None, + **kwargs, + ) -> None: + import pyarrow as _pa + + if isinstance(data, _pa.RecordBatch): + data = _pa.Table.from_batches([data]) + + extracted_sys_tags: dict[str, DataValue] + + if isinstance(data, _pa.Table): + # Arrow path: call super() first, then extract system-tag columns from + # self._data_table (same pattern as the legacy ArrowTag). + super().__init__( + data, + meta_info=meta_info, + data_context=data_context, + record_id=record_id, + **kwargs, + ) + sys_tag_cols = [ + c + for c in self._data_table.column_names # type: ignore[union-attr] + if c.startswith(constants.SYSTEM_TAG_PREFIX) + ] + if sys_tag_cols: + extracted_sys_tags = ( + self._data_context.type_converter.arrow_table_to_python_dicts( + self._data_table.select(sys_tag_cols) # type: ignore[union-attr] + )[0] + ) + self._data_table = self._data_table.drop_columns(sys_tag_cols) # type: ignore[union-attr] + # Invalidate derived caches + self._data_arrow_schema = None + else: + extracted_sys_tags = {} + else: + # Dict path: extract system-tag keys before calling super() + data_only = { + k: v + for k, v in data.items() + if not k.startswith(constants.SYSTEM_TAG_PREFIX) + } + extracted_sys_tags = { + k: v + for k, v in data.items() + if k.startswith(constants.SYSTEM_TAG_PREFIX) + } + super().__init__( + data_only, + python_schema=python_schema, + meta_info=meta_info, + data_context=data_context, + record_id=record_id, + **kwargs, + ) + + self._system_tags: dict[str, DataValue] = { + **extracted_sys_tags, + **(system_tags or {}), + } + self._system_tags_python_schema: Schema = infer_python_schema_from_pylist_data( + [self._system_tags], default_type=str + ) + self._system_tags_table: "pa.Table | None" = None + + # ------------------------------------------------------------------ + # Internal helper + # ------------------------------------------------------------------ + + def _ensure_system_tags_table(self) -> "pa.Table": + if self._system_tags_table is None: + self._system_tags_table = ( + self._data_context.type_converter.python_dicts_to_arrow_table( + [self._system_tags], + python_schema=self._system_tags_python_schema, + ) + ) + return self._system_tags_table + + # ------------------------------------------------------------------ + # Overrides + # ------------------------------------------------------------------ + + def keys( + self, + *, + columns: "ColumnConfig | dict[str, Any] | None" = None, + all_info: bool = False, + ) -> tuple[str, ...]: + keys = super().keys(columns=columns, all_info=all_info) + column_config = ColumnConfig.handle_config(columns, all_info=all_info) + if column_config.system_tags: + keys += tuple(self._system_tags.keys()) + return keys + + def schema( + self, + *, + columns: "ColumnConfig | dict[str, Any] | None" = None, + all_info: bool = False, + ) -> Schema: + schema = super().schema(columns=columns, all_info=all_info) + column_config = ColumnConfig.handle_config(columns, all_info=all_info) + if column_config.system_tags: + return Schema({**schema, **self._system_tags_python_schema}) + return schema + + def arrow_schema( + self, + *, + columns: "ColumnConfig | dict[str, Any] | None" = None, + all_info: bool = False, + ) -> "pa.Schema": + schema = super().arrow_schema(columns=columns, all_info=all_info) + column_config = ColumnConfig.handle_config(columns, all_info=all_info) + if column_config.system_tags and self._system_tags: + return arrow_utils.join_arrow_schemas( + schema, self._ensure_system_tags_table().schema + ) + return schema + + def as_dict( + self, + *, + columns: "ColumnConfig | dict[str, Any] | None" = None, + all_info: bool = False, + ) -> "dict[str, DataValue]": + result = super().as_dict(columns=columns, all_info=all_info) + column_config = ColumnConfig.handle_config(columns, all_info=all_info) + if column_config.system_tags: + result.update(self._system_tags) + return result + + def as_table( + self, + *, + columns: "ColumnConfig | dict[str, Any] | None" = None, + all_info: bool = False, + ) -> "pa.Table": + table = super().as_table(columns=columns, all_info=all_info) + column_config = ColumnConfig.handle_config(columns, all_info=all_info) + if column_config.system_tags and self._system_tags: + table = arrow_utils.hstack_tables(table, self._ensure_system_tags_table()) + return table + + def system_tags(self) -> "dict[str, DataValue]": + """Return a copy of the system-tags dict.""" + return dict(self._system_tags) + + def as_datagram( + self, + *, + columns: "ColumnConfig | dict[str, Any] | None" = None, + all_info: bool = False, + ) -> Datagram: + data = self.as_dict(columns=columns, all_info=all_info) + python_schema = self.schema(columns=columns, all_info=all_info) + return Datagram( + data, python_schema=python_schema, data_context=self._data_context + ) + + def copy(self, include_cache: bool = True, preserve_id: bool = False) -> Self: + new_tag = super().copy(include_cache=include_cache, preserve_id=preserve_id) + new_tag._system_tags = dict(self._system_tags) + new_tag._system_tags_python_schema = self._system_tags_python_schema + new_tag._system_tags_table = self._system_tags_table if include_cache else None + return new_tag + + +# --------------------------------------------------------------------------- +# Packet +# --------------------------------------------------------------------------- + + +class Packet(Datagram): + """ + Datagram with source-information tracking. + + Source info maps each data-column name to a provenance token (``str | None``). + Keys in ``_source_info`` are stored **without** the ``SOURCE_PREFIX``; the + prefix is added transparently when serialising to dict or Arrow table. + + Accepts the same inputs as ``Datagram`` (dict or Arrow table/batch). + Source-info fields (columns beginning with ``SOURCE_PREFIX``) found in the + input are automatically extracted. + """ + + def __init__( + self, + data: "Mapping[str, DataValue] | pa.Table | pa.RecordBatch", + meta_info: "Mapping[str, DataValue] | None" = None, + source_info: "Mapping[str, str | None] | None" = None, + python_schema: "SchemaLike | None" = None, + data_context: "str | contexts.DataContext | None" = None, + record_id: "str | None" = None, + **kwargs, + ) -> None: + import pyarrow as _pa + + if isinstance(data, _pa.RecordBatch): + data = _pa.Table.from_batches([data]) + + if isinstance(data, _pa.Table): + # Arrow path: use prepare_prefixed_columns to split source-info from data + if source_info is None: + source_info = {} + else: + # Normalise: remove existing prefix from provided keys + source_info = { + k.removeprefix(constants.SOURCE_PREFIX) + if k.startswith(constants.SOURCE_PREFIX) + else k: v + for k, v in source_info.items() + } + + data_table, prefixed_tables = arrow_utils.prepare_prefixed_columns( + data, + {constants.SOURCE_PREFIX: source_info}, + exclude_columns=[constants.CONTEXT_KEY], + exclude_prefixes=[constants.META_PREFIX], + ) + super().__init__( + data_table, + meta_info=meta_info, + data_context=data_context, + record_id=record_id, + **kwargs, + ) + si_table = prefixed_tables[constants.SOURCE_PREFIX] + if si_table.num_columns > 0 and si_table.num_rows > 0: + self._source_info: dict[str, str | None] = { + k.removeprefix(constants.SOURCE_PREFIX): v + for k, v in si_table.to_pylist()[0].items() + } + else: + self._source_info = {} + else: + # Dict path: extract source-info keys before calling super() + data_only = { + k: v + for k, v in data.items() + if not k.startswith(constants.SOURCE_PREFIX) + } + contained_source_info: dict[str, str | None] = { + k.removeprefix(constants.SOURCE_PREFIX): v # type: ignore[misc] + for k, v in data.items() + if k.startswith(constants.SOURCE_PREFIX) + } + super().__init__( + data_only, + python_schema=python_schema, + meta_info=meta_info, + data_context=data_context, + record_id=record_id, + **kwargs, + ) + self._source_info = {**contained_source_info, **(source_info or {})} + + self._source_info_table: "pa.Table | None" = None + + # ------------------------------------------------------------------ + # Internal helper + # ------------------------------------------------------------------ + + def _ensure_source_info_table(self) -> "pa.Table": + if self._source_info_table is None: + import pyarrow as _pa + + if self._source_info: + prefixed = { + f"{constants.SOURCE_PREFIX}{k}": v + for k, v in self._source_info.items() + } + schema = _pa.schema( + [_pa.field(k, _pa.large_string()) for k in prefixed] + ) + self._source_info_table = _pa.Table.from_pylist( + [prefixed], schema=schema + ) + else: + self._source_info_table = _pa.table({}) + return self._source_info_table + + # ------------------------------------------------------------------ + # Source-info API + # ------------------------------------------------------------------ + + def source_info(self) -> "dict[str, str | None]": + """Return source info for all data-column keys (None for unknown).""" + return {k: self._source_info.get(k) for k in self.keys()} + + def with_source_info(self, **source_info: "str | None") -> Self: + """Create a copy with updated source-information entries.""" + current = dict(self._source_info) + for key, value in source_info.items(): + if key.startswith(constants.SOURCE_PREFIX): + key = key.removeprefix(constants.SOURCE_PREFIX) + current[key] = value + new_p = self.copy(include_cache=False) + new_p._source_info = current + return new_p + + # ------------------------------------------------------------------ + # Overrides + # ------------------------------------------------------------------ + + def keys( + self, + *, + columns: "ColumnConfig | dict[str, Any] | None" = None, + all_info: bool = False, + ) -> tuple[str, ...]: + keys = super().keys(columns=columns, all_info=all_info) + column_config = ColumnConfig.handle_config(columns, all_info=all_info) + if column_config.source: + keys += tuple(f"{constants.SOURCE_PREFIX}{k}" for k in super().keys()) + return keys + + def schema( + self, + *, + columns: "ColumnConfig | dict[str, Any] | None" = None, + all_info: bool = False, + ) -> Schema: + schema = dict(super().schema(columns=columns, all_info=all_info)) + column_config = ColumnConfig.handle_config(columns, all_info=all_info) + if column_config.source: + for key in super().keys(): + schema[f"{constants.SOURCE_PREFIX}{key}"] = str + return Schema(schema) + + def arrow_schema( + self, + *, + columns: "ColumnConfig | dict[str, Any] | None" = None, + all_info: bool = False, + ) -> "pa.Schema": + schema = super().arrow_schema(columns=columns, all_info=all_info) + column_config = ColumnConfig.handle_config(columns, all_info=all_info) + if column_config.source: + si_table = self._ensure_source_info_table() + if si_table.num_columns > 0: + return arrow_utils.join_arrow_schemas(schema, si_table.schema) + return schema + + def as_dict( + self, + *, + columns: "ColumnConfig | dict[str, Any] | None" = None, + all_info: bool = False, + ) -> "dict[str, DataValue]": + result = super().as_dict(columns=columns, all_info=all_info) + column_config = ColumnConfig.handle_config(columns, all_info=all_info) + if column_config.source: + for key, value in self.source_info().items(): + result[f"{constants.SOURCE_PREFIX}{key}"] = value + return result + + def as_table( + self, + *, + columns: "ColumnConfig | dict[str, Any] | None" = None, + all_info: bool = False, + ) -> "pa.Table": + table = super().as_table(columns=columns, all_info=all_info) + column_config = ColumnConfig.handle_config(columns, all_info=all_info) + if column_config.source: + si_table = self._ensure_source_info_table() + if si_table.num_columns > 0 and si_table.num_rows > 0: + table = arrow_utils.hstack_tables(table, si_table) + return table + + def rename(self, column_mapping: "Mapping[str, str]") -> Self: + new_p = super().rename(column_mapping) + new_p._source_info = { + column_mapping.get(k, k): v for k, v in self._source_info.items() + } + new_p._source_info_table = None + return new_p + + def with_columns( + self, + column_types: "Mapping[str, type] | None" = None, + **updates: DataValue, + ) -> Self: + new_p = super().with_columns(column_types=column_types, **updates) + new_source_info = dict(self._source_info) + for col in updates: + new_source_info[col] = None # new columns get empty source info + new_p._source_info = new_source_info + new_p._source_info_table = None + return new_p + + def as_datagram( + self, + *, + columns: "ColumnConfig | dict[str, Any] | None" = None, + all_info: bool = False, + ) -> Datagram: + data = self.as_dict(columns=columns, all_info=all_info) + python_schema = self.schema(columns=columns, all_info=all_info) + return Datagram( + data=data, python_schema=python_schema, data_context=self._data_context + ) + + def copy(self, include_cache: bool = True, preserve_id: bool = True) -> Self: + new_p = super().copy(include_cache=include_cache, preserve_id=preserve_id) + new_p._source_info = dict(self._source_info) + new_p._source_info_table = self._source_info_table if include_cache else None + return new_p diff --git a/src/orcapod/core/executors/__init__.py b/src/orcapod/core/executors/__init__.py new file mode 100644 index 00000000..179fb260 --- /dev/null +++ b/src/orcapod/core/executors/__init__.py @@ -0,0 +1,7 @@ +from orcapod.core.executors.base import PacketFunctionExecutorBase +from orcapod.core.executors.local import LocalExecutor + +__all__ = [ + "PacketFunctionExecutorBase", + "LocalExecutor", +] diff --git a/src/orcapod/core/executors/base.py b/src/orcapod/core/executors/base.py new file mode 100644 index 00000000..1a7d0d6b --- /dev/null +++ b/src/orcapod/core/executors/base.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import copy +from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from orcapod.pipeline.logging_capture import CapturedLogs + from orcapod.protocols.core_protocols import PacketFunctionProtocol, PacketProtocol + + +class PacketFunctionExecutorBase(ABC): + """Abstract base class for packet function executors. + + An executor defines *where* and *how* a packet function's computation + runs (e.g. in-process, on a Ray cluster, in a container). Executors + are type-specific: each declares the ``packet_function_type_id`` values + it supports. + + Subclasses must implement ``execute`` and optionally ``async_execute``. + """ + + @property + @abstractmethod + def executor_type_id(self) -> str: + """Unique identifier for this executor type, e.g. ``'local'``, ``'ray.v0'``.""" + ... + + @abstractmethod + def supported_function_type_ids(self) -> frozenset[str]: + """Return the set of ``packet_function_type_id`` values this executor can run. + + Return an empty ``frozenset`` to indicate support for *all* types. + """ + ... + + def supports(self, packet_function_type_id: str) -> bool: + """Return ``True`` if this executor can handle the given function type. + + Default implementation checks membership in + ``supported_function_type_ids()``; an empty set means "supports all". + """ + ids = self.supported_function_type_ids() + return len(ids) == 0 or packet_function_type_id in ids + + @abstractmethod + def execute( + self, + packet_function: PacketFunctionProtocol, + packet: PacketProtocol, + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + """Synchronously execute *packet_function* on *packet*. + + Implementations should call ``packet_function.direct_call(packet)`` + to invoke the function's native computation, bypassing executor + routing, and pass through the ``(result, CapturedLogs)`` tuple. + """ + ... + + async def async_execute( + self, + packet_function: PacketFunctionProtocol, + packet: PacketProtocol, + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + """Asynchronous counterpart of ``execute``. + + The default implementation delegates to ``execute`` synchronously. + Subclasses should override for truly async execution. + """ + return self.execute(packet_function, packet) + + @property + def supports_concurrent_execution(self) -> bool: + """Whether this executor can run multiple packets concurrently. + + Default is ``False``. Subclasses that support truly concurrent + execution (e.g. via a remote cluster) should override to ``True``. + """ + return False + + def with_options(self, **opts: Any) -> "PacketFunctionExecutorBase": + """Return a **new** executor instance configured with the given per-node options. + + The default implementation returns a shallow copy of *self*. + Subclasses that carry mutable state (e.g. ``RayExecutor``) should + override to produce a properly configured new instance. + """ + return copy.copy(self) + + # ------------------------------------------------------------------ + # Callable-level execution (PythonFunctionExecutorProtocol) + # ------------------------------------------------------------------ + + def execute_callable( + self, + fn: Callable[..., Any], + kwargs: dict[str, Any], + executor_options: dict[str, Any] | None = None, + ) -> "tuple[Any, CapturedLogs]": + """Synchronously execute *fn* with *kwargs*, returning captured I/O. + + Default implementation calls ``fn(**kwargs)`` with no capture and + returns empty :class:`~orcapod.pipeline.logging_capture.CapturedLogs`. + Exceptions propagate to the caller. Subclasses (e.g. + ``LocalExecutor``, ``RayExecutor``) override to add I/O capture and + exception swallowing. + + Args: + fn: The Python callable to execute. + kwargs: Keyword arguments to pass to *fn*. + executor_options: Optional per-call options. + + Returns: + ``(raw_result, CapturedLogs)`` + """ + from orcapod.pipeline.logging_capture import CapturedLogs + + return fn(**kwargs), CapturedLogs() + + async def async_execute_callable( + self, + fn: Callable[..., Any], + kwargs: dict[str, Any], + executor_options: dict[str, Any] | None = None, + ) -> "tuple[Any, CapturedLogs]": + """Asynchronously execute *fn* with *kwargs*, returning captured I/O. + + Default implementation delegates to ``execute_callable`` + synchronously. Subclasses should override for truly async execution. + """ + return self.execute_callable(fn, kwargs, executor_options) + + def get_execution_data(self) -> dict[str, Any]: + """Return metadata describing the execution environment. + + Recorded alongside results for observability but does not affect + content or pipeline hashes. The default returns the executor type id. + """ + return {"executor_type": self.executor_type_id} diff --git a/src/orcapod/core/executors/local.py b/src/orcapod/core/executors/local.py new file mode 100644 index 00000000..ec05a086 --- /dev/null +++ b/src/orcapod/core/executors/local.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +import asyncio +import inspect +import traceback as _traceback_module +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +from orcapod.core.executors.base import PacketFunctionExecutorBase + +if TYPE_CHECKING: + from orcapod.pipeline.logging_capture import CapturedLogs + from orcapod.protocols.core_protocols import PacketFunctionProtocol, PacketProtocol + + +class LocalExecutor(PacketFunctionExecutorBase): + """Default executor -- runs the packet function directly in the current process. + + Supports all packet function types (``supported_function_type_ids`` + returns an empty set). + """ + + @property + def executor_type_id(self) -> str: + return "local" + + def supported_function_type_ids(self) -> frozenset[str]: + return frozenset() + + def execute( + self, + packet_function: PacketFunctionProtocol, + packet: PacketProtocol, + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + return packet_function.direct_call(packet) + + async def async_execute( + self, + packet_function: PacketFunctionProtocol, + packet: PacketProtocol, + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + return await packet_function.direct_async_call(packet) + + # -- PythonFunctionExecutorProtocol -- + + def execute_callable( + self, + fn: Callable[..., Any], + kwargs: dict[str, Any], + executor_options: dict[str, Any] | None = None, + ) -> "tuple[Any, CapturedLogs]": + from orcapod.pipeline.logging_capture import CapturedLogs, LocalCaptureContext + + ctx = LocalCaptureContext() + raw_result = None + success = True + tb: str | None = None + with ctx: + try: + if inspect.iscoroutinefunction(fn): + raw_result = self._run_async_sync(fn, kwargs) + else: + raw_result = fn(**kwargs) + except Exception: + success = False + tb = _traceback_module.format_exc() + return raw_result, ctx.get_captured(success=success, tb=tb) + + @staticmethod + def _run_async_sync(fn: Callable[..., Any], kwargs: dict[str, Any]) -> Any: + """Run an async function synchronously, handling nested event loops.""" + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(fn(**kwargs)) + else: + from concurrent.futures import ThreadPoolExecutor + + with ThreadPoolExecutor(1) as pool: + return pool.submit(lambda: asyncio.run(fn(**kwargs))).result() + + async def async_execute_callable( + self, + fn: Callable[..., Any], + kwargs: dict[str, Any], + executor_options: dict[str, Any] | None = None, + ) -> "tuple[Any, CapturedLogs]": + from orcapod.pipeline.logging_capture import CapturedLogs, LocalCaptureContext + + ctx = LocalCaptureContext() + raw_result = None + success = True + tb: str | None = None + with ctx: + try: + if inspect.iscoroutinefunction(fn): + raw_result = await fn(**kwargs) + else: + loop = asyncio.get_running_loop() + raw_result = await loop.run_in_executor(None, lambda: fn(**kwargs)) + except Exception: + success = False + tb = _traceback_module.format_exc() + return raw_result, ctx.get_captured(success=success, tb=tb) + + def with_options(self, **opts: Any) -> "LocalExecutor": + """Return a new ``LocalExecutor``. + + ``LocalExecutor`` carries no state, so options are ignored. + """ + return LocalExecutor() diff --git a/src/orcapod/core/executors/ray.py b/src/orcapod/core/executors/ray.py new file mode 100644 index 00000000..b1236fbc --- /dev/null +++ b/src/orcapod/core/executors/ray.py @@ -0,0 +1,342 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +from orcapod.core.executors.base import PacketFunctionExecutorBase + +if TYPE_CHECKING: + from orcapod.core.packet_function import PythonPacketFunction + from orcapod.pipeline.logging_capture import CapturedLogs + from orcapod.protocols.core_protocols import PacketFunctionProtocol, PacketProtocol + + +class RayExecutor(PacketFunctionExecutorBase): + """Executor that dispatches Python packet functions to a Ray cluster. + + Only supports ``packet_function_type_id == "python.function.v0"``. + + The caller is responsible for calling ``ray.init(...)`` before using + this executor. If ``ray_address`` is provided and Ray has not been + initialized yet, this executor will call ``ray.init(address=...)`` + lazily on first use. + + Note: + ``ray`` is an optional dependency. Import errors surface at + construction time so callers get a clear message. + """ + + SUPPORTED_TYPES: frozenset[str] = frozenset({"python.function.v0"}) + + def __init__( + self, + *, + ray_address: str | None = None, + num_cpus: int | None = None, + num_gpus: int | None = None, + resources: dict[str, float] | None = None, + **ray_remote_opts: Any, + ) -> None: + """Create a RayExecutor. + + Args: + ray_address: Address of the Ray cluster to connect to (e.g. + ``"ray://host:10001"``). If ``None`` and Ray is not yet + initialised, ``ray.init()`` is called without an address, + which starts a local cluster. + num_cpus: Number of CPUs to request per remote task. Passed + directly to ``ray.remote(num_cpus=...)``. + num_gpus: Number of GPUs to request per remote task. Passed + directly to ``ray.remote(num_gpus=...)``. + resources: Custom resource requirements dict forwarded to + ``ray.remote(resources=...)``. + **ray_remote_opts: Any additional keyword arguments accepted by + ``ray.remote()`` (e.g. ``memory``, ``max_calls``, + ``runtime_env``, ``accelerator_type``). + """ + try: + import ray # noqa: F401 + except ImportError as exc: + raise ImportError( + "RayExecutor requires the 'ray' package. " + "Install it with: pip install ray" + ) from exc + + self._ray_address = ray_address + + # Collect all remote opts into a single dict so that arbitrary Ray + # options (memory, max_calls, accelerator_type, label_selector, …) + # can be passed through without hardcoding each one. + self._remote_opts: dict[str, Any] = {} + if num_cpus is not None: + self._remote_opts["num_cpus"] = num_cpus + if num_gpus is not None: + self._remote_opts["num_gpus"] = num_gpus + if resources is not None: + self._remote_opts["resources"] = resources + self._remote_opts.update(ray_remote_opts) + + def _ensure_ray_initialized(self) -> None: + """Initialize Ray if it has not been initialized yet. + + Also registers a cloudpickle dispatch for ``logging.Logger`` so that + user functions referencing loggers can be sent to Ray workers that + do not have orcapod installed. + + By default cloudpickle serializes Logger instances by value, which + traverses the parent chain to the root logger. After + ``install_capture_streams()`` the root logger has a + ``ContextVarLoggingHandler`` from ``orcapod``. Workers without + orcapod cannot deserialize that class. + + Registering loggers as ``(logging.getLogger, (name,))`` is the + correct semantic — loggers are name-keyed singletons — and produces + no orcapod dependency in the pickled bytes. + """ + import logging + import ray + + try: + import cloudpickle + + def _pickle_logger(l: logging.Logger) -> tuple: + # Root logger has name "root" but must be fetched as "" + name = "" if isinstance(l, logging.RootLogger) else l.name + return logging.getLogger, (name,) + + cloudpickle.CloudPickler.dispatch[logging.Logger] = _pickle_logger + cloudpickle.CloudPickler.dispatch[logging.RootLogger] = _pickle_logger + except Exception: + pass # cloudpickle not available or API changed — best effort + + if not ray.is_initialized(): + if self._ray_address is not None: + ray.init(address=self._ray_address) + else: + ray.init() + + @property + def executor_type_id(self) -> str: + return "ray.v0" + + def supported_function_type_ids(self) -> frozenset[str]: + return self.SUPPORTED_TYPES + + @property + def supports_concurrent_execution(self) -> bool: + return True + + def _build_remote_opts(self) -> dict[str, Any]: + """Return a copy of the Ray remote options dict.""" + return dict(self._remote_opts) + + def _as_python_packet_function( + self, packet_function: PacketFunctionProtocol + ) -> "PythonPacketFunction": + """Return *packet_function* cast to ``PythonPacketFunction``, or raise. + + Raises: + TypeError: If *packet_function* is not a ``PythonPacketFunction`` + instance and therefore does not expose the attributes required + for remote execution. + """ + from orcapod.core.packet_function import PythonPacketFunction + + if not isinstance(packet_function, PythonPacketFunction): + raise TypeError( + f"RayExecutor only supports PythonPacketFunction, " + f"got {type(packet_function).__name__}" + ) + return packet_function + + def execute( + self, + packet_function: PacketFunctionProtocol, + packet: PacketProtocol, + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + from orcapod.pipeline.logging_capture import CapturedLogs + + pf = self._as_python_packet_function(packet_function) + if not pf.is_active(): + return None, CapturedLogs(success=True) + + raw, captured = self.execute_callable(pf._function, packet.as_dict()) + if not captured.success: + return None, captured + return pf._build_output_packet(raw), captured + + async def async_execute( + self, + packet_function: PacketFunctionProtocol, + packet: PacketProtocol, + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + from orcapod.pipeline.logging_capture import CapturedLogs + + pf = self._as_python_packet_function(packet_function) + if not pf.is_active(): + return None, CapturedLogs(success=True) + + raw, captured = await self.async_execute_callable(pf._function, packet.as_dict()) + if not captured.success: + return None, captured + return pf._build_output_packet(raw), captured + + # -- PythonFunctionExecutorProtocol -- + + @staticmethod + def _make_capture_wrapper() -> Callable[..., Any]: + """Return an inline capture wrapper suitable for Ray remote execution. + + The wrapper is defined as a closure (not a module-level import) so that + cloudpickle serializes it by bytecode rather than by module reference. + This means the Ray cluster workers do **not** need ``orcapod`` installed + — only the standard library is required on the worker side. + + The wrapper returns a plain 6-tuple ``(raw_result, stdout, stderr, + python_logs, traceback_str, success)`` so no orcapod types cross the + Ray object store; the driver reconstructs :class:`CapturedLogs` from + the tuple. + """ + def _capture(fn: Any, kwargs: dict) -> tuple: + import io + import logging + import os + import sys + import tempfile + import traceback as _tb + + stdout_tmp = tempfile.TemporaryFile() + stderr_tmp = tempfile.TemporaryFile() + orig_stdout_fd = os.dup(1) + orig_stderr_fd = os.dup(2) + orig_sys_stdout = sys.stdout + orig_sys_stderr = sys.stderr + sys_stdout_buf = io.StringIO() + sys_stderr_buf = io.StringIO() + log_records: list = [] + + fmt = logging.Formatter("%(levelname)s:%(name)s:%(message)s") + + class _H(logging.Handler): + def emit(self, record: logging.LogRecord) -> None: + log_records.append(fmt.format(record)) + + handler = _H() + root_logger = logging.getLogger() + orig_level = root_logger.level + root_logger.setLevel(logging.DEBUG) + root_logger.addHandler(handler) + + raw_result = None + success = True + tb_str = None + try: + sys.stdout.flush() + sys.stderr.flush() + os.dup2(stdout_tmp.fileno(), 1) + os.dup2(stderr_tmp.fileno(), 2) + sys.stdout = sys_stdout_buf + sys.stderr = sys_stderr_buf + try: + raw_result = fn(**kwargs) + except Exception: + success = False + tb_str = _tb.format_exc() + finally: + sys.stdout = orig_sys_stdout + sys.stderr = orig_sys_stderr + os.dup2(orig_stdout_fd, 1) + os.dup2(orig_stderr_fd, 2) + os.close(orig_stdout_fd) + os.close(orig_stderr_fd) + root_logger.removeHandler(handler) + root_logger.setLevel(orig_level) + stdout_tmp.seek(0) + stderr_tmp.seek(0) + cap_stdout = ( + stdout_tmp.read().decode("utf-8", errors="replace") + + sys_stdout_buf.getvalue() + ) + cap_stderr = ( + stderr_tmp.read().decode("utf-8", errors="replace") + + sys_stderr_buf.getvalue() + ) + stdout_tmp.close() + stderr_tmp.close() + + return raw_result, cap_stdout, cap_stderr, "\n".join(log_records), tb_str, success + + return _capture + + def execute_callable( + self, + fn: Callable[..., Any], + kwargs: dict[str, Any], + executor_options: dict[str, Any] | None = None, + ) -> "tuple[Any, CapturedLogs]": + """Execute *fn* on the Ray cluster with fd-level I/O capture. + + The capture wrapper is serialized by bytecode (not module reference) so + the Ray cluster workers do not need ``orcapod`` installed. + """ + import ray + + from orcapod.pipeline.logging_capture import CapturedLogs + + self._ensure_ray_initialized() + wrapper = self._make_capture_wrapper() + wrapper.__name__ = fn.__name__ + wrapper.__qualname__ = fn.__qualname__ + remote_fn = ray.remote(**self._build_remote_opts())(wrapper) + ref = remote_fn.remote(fn, kwargs) + raw, stdout, stderr, python_logs, tb, success = ray.get(ref) + return raw, CapturedLogs( + stdout=stdout, stderr=stderr, python_logs=python_logs, + traceback=tb, success=success, + ) + + async def async_execute_callable( + self, + fn: Callable[..., Any], + kwargs: dict[str, Any], + executor_options: dict[str, Any] | None = None, + ) -> "tuple[Any, CapturedLogs]": + """Async counterpart of :meth:`execute_callable`.""" + import ray + + from orcapod.pipeline.logging_capture import CapturedLogs + + self._ensure_ray_initialized() + wrapper = self._make_capture_wrapper() + wrapper.__name__ = fn.__name__ + wrapper.__qualname__ = fn.__qualname__ + remote_fn = ray.remote(**self._build_remote_opts())(wrapper) + ref = remote_fn.remote(fn, kwargs) + raw, stdout, stderr, python_logs, tb, success = await asyncio.wrap_future( + ref.future() + ) + return raw, CapturedLogs( + stdout=stdout, stderr=stderr, python_logs=python_logs, + traceback=tb, success=success, + ) + + def with_options(self, **opts: Any) -> "RayExecutor": + """Return a new ``RayExecutor`` with the given options merged in. + + The returned executor shares the same ``ray_address``. All opts are + passed through to ``ray.remote()``/``.options()`` as-is — no keys are + hardcoded, so any option Ray supports (``num_cpus``, ``num_gpus``, + ``memory``, ``max_calls``, ``accelerator_type``, ``label_selector``, + ``runtime_env``, …) can be used. Node-level opts override + pipeline-level defaults. + """ + merged = {**self._remote_opts, **opts} + return RayExecutor(ray_address=self._ray_address, **merged) + + def get_execution_data(self) -> dict[str, Any]: + return { + "executor_type": self.executor_type_id, + "ray_address": self._ray_address or "auto", + **self._remote_opts, + } diff --git a/src/orcapod/core/function_pod.py b/src/orcapod/core/function_pod.py new file mode 100644 index 00000000..7c525188 --- /dev/null +++ b/src/orcapod/core/function_pod.py @@ -0,0 +1,793 @@ +from __future__ import annotations + +import asyncio +import logging +from abc import abstractmethod +from collections.abc import Callable, Collection, Iterator, Sequence +from functools import wraps +from typing import TYPE_CHECKING, Any, Protocol, cast + +from orcapod import contexts +from orcapod.channels import ReadableChannel, WritableChannel +from orcapod.config import Config +from orcapod.core.base import TraceableBase +from orcapod.core.packet_function import CachedPacketFunction, PythonPacketFunction +from orcapod.core.streams.base import StreamBase +from orcapod.core.tracker import DEFAULT_TRACKER_MANAGER +from orcapod.protocols.core_protocols import ( + ArgumentGroup, + FunctionPodProtocol, + PacketFunctionExecutorProtocol, + PacketFunctionProtocol, + PacketProtocol, + PodProtocol, + StreamProtocol, + TagProtocol, + TrackerManagerProtocol, +) +from orcapod.protocols.database_protocols import ArrowDatabaseProtocol +from orcapod.system_constants import constants +from orcapod.types import ( + ColumnConfig, + NodeConfig, + PipelineConfig, + Schema, + resolve_concurrency, +) +from orcapod.utils import arrow_utils, schema_utils +from orcapod.utils.lazy_module import LazyModule + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + import polars as pl + import pyarrow as pa +else: + pa = LazyModule("pyarrow") + pl = LazyModule("polars") + + +def _executor_supports_concurrent( + packet_function: PacketFunctionProtocol, +) -> bool: + """Return True if the packet function's executor supports concurrent execution.""" + executor = packet_function.executor + return executor is not None and executor.supports_concurrent_execution + + +class _FunctionPodBase(TraceableBase): + """Base pod that applies a packet function to each input packet.""" + + def __init__( + self, + packet_function: PacketFunctionProtocol, + tracker_manager: TrackerManagerProtocol | None = None, + label: str | None = None, + data_context: str | contexts.DataContext | None = None, + config: Config | None = None, + ) -> None: + super().__init__( + label=label, + data_context=data_context, + config=config, + ) + self.tracker_manager = tracker_manager or DEFAULT_TRACKER_MANAGER + self._packet_function = packet_function + self._output_schema_hash = None + + def computed_label(self) -> str | None: + """Use the packet function's canonical name as the default label.""" + return self._packet_function.canonical_function_name + + @property + def packet_function(self) -> PacketFunctionProtocol: + return self._packet_function + + @property + def executor(self) -> PacketFunctionExecutorProtocol | None: + """The executor set on the underlying packet function, or ``None``.""" + return self._packet_function.executor + + @executor.setter + def executor(self, executor: PacketFunctionExecutorProtocol | None) -> None: + """Set or clear the executor on the underlying packet function.""" + self._packet_function.executor = executor + + def identity_structure(self) -> Any: + return self.packet_function.identity_structure() + + def pipeline_identity_structure(self) -> Any: + return self.packet_function + + @property + def uri(self) -> tuple[str, ...]: + if self._output_schema_hash is None: + self._output_schema_hash = self.data_context.semantic_hasher.hash_object( + # hash the vanilla output schema with no extra columns + self.packet_function.output_packet_schema + ).to_string() + return ( + self.packet_function.canonical_function_name, + self._output_schema_hash, + f"v{self.packet_function.major_version}", + self.packet_function.packet_function_type_id, + ) + + def multi_stream_handler(self) -> PodProtocol: + from orcapod.core.operators import Join + + return Join() + + def validate_inputs(self, *streams: StreamProtocol) -> None: + """Validate input streams, raising exceptions if invalid. + + Args: + *streams: Input streams to validate. + + Raises: + ValueError: If inputs are incompatible with the packet function schema. + """ + input_stream = self.handle_input_streams(*streams) + _, incoming_packet_schema = input_stream.output_schema() + self._validate_input_schema(incoming_packet_schema) + + def _validate_input_schema(self, input_schema: Schema) -> None: + expected_packet_schema = self.packet_function.input_packet_schema + if not schema_utils.check_schema_compatibility( + input_schema, expected_packet_schema + ): + # TODO: use custom exception type for better error handling + raise ValueError( + f"Incoming packet data type {input_schema} is not compatible with expected input schema {expected_packet_schema}" + ) + + def process_packet( + self, tag: TagProtocol, packet: PacketProtocol + ) -> tuple[TagProtocol, PacketProtocol | None]: + """Process a single packet using the pod's packet function. + + Args: + tag: The tag associated with the packet. + packet: The input packet to process. + + Returns: + A ``(tag, output_packet)`` tuple; output_packet is ``None`` if + the function filters the packet out. CapturedLogs are discarded + (only relevant for node-level execution with observers). + """ + result, _captured = self.packet_function.call(packet) + return tag, result + + async def async_process_packet( + self, tag: TagProtocol, packet: PacketProtocol + ) -> tuple[TagProtocol, PacketProtocol | None]: + """Async counterpart of ``process_packet``.""" + result, _captured = await self.packet_function.async_call(packet) + return tag, result + + def process_packet_with_capture( + self, tag: TagProtocol, packet: PacketProtocol + ) -> tuple[TagProtocol, PacketProtocol | None, "CapturedLogs"]: + """Process a single packet and return CapturedLogs alongside the result. + + Used by FunctionNode to get logs without a ContextVar side-channel. + """ + from orcapod.pipeline.logging_capture import CapturedLogs + + result, captured = self.packet_function.call(packet) + return tag, result, captured + + async def async_process_packet_with_capture( + self, tag: TagProtocol, packet: PacketProtocol + ) -> tuple[TagProtocol, PacketProtocol | None, "CapturedLogs"]: + """Async counterpart of ``process_packet_with_capture``.""" + from orcapod.pipeline.logging_capture import CapturedLogs + + result, captured = await self.packet_function.async_call(packet) + return tag, result, captured + + def handle_input_streams(self, *streams: StreamProtocol) -> StreamProtocol: + """Handle multiple input streams by joining them if necessary. + + Args: + *streams: Input streams to handle. + """ + # handle multiple input streams + if len(streams) == 0: + raise ValueError("At least one input stream is required") + elif len(streams) > 1: + # TODO: simplify the multi-stream handling logic + multi_stream_handler = self.multi_stream_handler() + joined_stream = multi_stream_handler.process(*streams) + return joined_stream + return streams[0] + + @abstractmethod + def process( + self, *streams: StreamProtocol, label: str | None = None + ) -> StreamProtocol: + """Invoke the packet processor on the input stream(s). + + If multiple streams are passed in, they are joined before processing. + + Args: + *streams: Input streams to process. + label: Optional label for tracking. + + Returns: + The resulting output stream. + """ + ... + + def __call__( + self, *streams: StreamProtocol, label: str | None = None + ) -> StreamProtocol: + """Convenience alias for ``process``.""" + logger.debug(f"Invoking pod {self} on streams through __call__: {streams}") + # perform input stream validation + return self.process(*streams, label=label) + + def argument_symmetry(self, streams: Collection[StreamProtocol]) -> ArgumentGroup: + return self.multi_stream_handler().argument_symmetry(streams) + + def output_schema( + self, + *streams: StreamProtocol, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: + tag_schema, incoming_packet_schema = self.multi_stream_handler().output_schema( + *streams, columns=columns, all_info=all_info + ) + # validate that incoming_packet_schema is valid + self._validate_input_schema(incoming_packet_schema) + # The output schema of the FunctionPodProtocol is determined by the packet function + # TODO: handle and extend to include additional columns + # Namely, the source columns + return tag_schema, self.packet_function.output_packet_schema + + +class FunctionPod(_FunctionPodBase): + def __init__( + self, + packet_function: PacketFunctionProtocol, + node_config: NodeConfig | None = None, + **kwargs, + ) -> None: + super().__init__(packet_function, **kwargs) + self._node_config = node_config or NodeConfig() + + @property + def node_config(self) -> NodeConfig: + return self._node_config + + def process( + self, *streams: StreamProtocol, label: str | None = None + ) -> FunctionPodStream: + """Invoke the packet processor on the input stream(s). + + Args: + *streams: Input streams to process. + label: Optional label for tracking. + + Returns: + A ``FunctionPodStream`` wrapping the computation. + """ + logger.debug(f"Invoking kernel {self} on streams: {streams}") + + input_stream = self.handle_input_streams(*streams) + + # perform input stream schema validation + self._validate_input_schema(input_stream.output_schema()[1]) + self.tracker_manager.record_function_pod_invocation( + self, input_stream, label=label + ) + output_stream = FunctionPodStream( + function_pod=self, + input_stream=input_stream, + label=label, + ) + return output_stream + + def __call__( + self, *streams: StreamProtocol, label: str | None = None + ) -> FunctionPodStream: + """Convenience alias for ``process``.""" + logger.debug(f"Invoking pod {self} on streams through __call__: {streams}") + # perform input stream validation + return self.process(*streams, label=label) + + def to_config(self) -> dict[str, Any]: + """Serialize this function pod to a JSON-compatible config dict. + + Returns: + A JSON-serializable dict containing the URI, packet function config, + and node config for this function pod. + """ + config: dict[str, Any] = { + "uri": list(self.uri), + "packet_function": self.packet_function.to_config(), + "node_config": None, + } + if ( + self._node_config is not None + and self._node_config.max_concurrency is not None + ): + config["node_config"] = { + "max_concurrency": self._node_config.max_concurrency, + } + return config + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "FunctionPod": + """Reconstruct a FunctionPod from a config dict. + + Args: + config: A dict as produced by :meth:`to_config`. + + Returns: + A new ``FunctionPod`` instance. + """ + from orcapod.pipeline.serialization import resolve_packet_function_from_config + + pf_config = config["packet_function"] + packet_function = resolve_packet_function_from_config(pf_config) + + node_config = None + if config.get("node_config") is not None: + node_config = NodeConfig(**config["node_config"]) + + return cls(packet_function=packet_function, node_config=node_config) + + # ------------------------------------------------------------------ + # Async channel execution (streaming mode) + # ------------------------------------------------------------------ + + async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + pipeline_config: PipelineConfig | None = None, + ) -> None: + """Streaming async execution with per-packet concurrency control. + + Each input (tag, packet) is processed independently. A semaphore + controls how many packets are in-flight concurrently. + """ + try: + pipeline_config = pipeline_config or PipelineConfig() + max_concurrency = resolve_concurrency(self._node_config, pipeline_config) + + sem = ( + asyncio.Semaphore(max_concurrency) + if max_concurrency is not None + else None + ) + + async def process_one(tag: TagProtocol, packet: PacketProtocol) -> None: + try: + tag, result_packet = await self.async_process_packet(tag, packet) + if result_packet is not None: + await output.send((tag, result_packet)) + finally: + if sem is not None: + sem.release() + + async with asyncio.TaskGroup() as tg: + async for tag, packet in inputs[0]: + if sem is not None: + await sem.acquire() + tg.create_task(process_one(tag, packet)) + finally: + await output.close() + + +class FunctionPodStream(StreamBase): + """Recomputable stream wrapping a packet function.""" + + def __init__( + self, function_pod: FunctionPodProtocol, input_stream: StreamProtocol, **kwargs + ) -> None: + self._function_pod = function_pod + self._input_stream = input_stream + super().__init__(**kwargs) + + # Iterator acquired lazily on first use to avoid triggering upstream + # computation during construction. + self._cached_input_iterator: ( + Iterator[tuple[TagProtocol, PacketProtocol]] | None + ) = None + self._needs_iterator = True + + # PacketProtocol-level caching (for the output packets) + self._cached_output_packets: dict[ + int, tuple[TagProtocol, PacketProtocol | None] + ] = {} + self._cached_output_table: pa.Table | None = None + self._cached_content_hash_column: pa.Array | None = None + + @property + def producer(self) -> PodProtocol: + return self._function_pod + + @property + def executor(self) -> PacketFunctionExecutorProtocol | None: + """The executor set on the underlying packet function.""" + return self._function_pod.packet_function.executor + + @executor.setter + def executor(self, executor: PacketFunctionExecutorProtocol | None) -> None: + """Set or clear the executor on the underlying packet function.""" + self._function_pod.packet_function.executor = executor + + @property + def upstreams(self) -> tuple[StreamProtocol, ...]: + return (self._input_stream,) + + def identity_structure(self) -> Any: + return ( + self._function_pod, + self._function_pod.argument_symmetry((self._input_stream,)), + ) + + def pipeline_identity_structure(self) -> Any: + return self.identity_structure() + + def keys( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[tuple[str, ...], tuple[str, ...]]: + tag_schema, packet_schema = self.output_schema( + columns=columns, all_info=all_info + ) + + return tuple(tag_schema.keys()), tuple(packet_schema.keys()) + + def output_schema( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: + return self._function_pod.output_schema( + self._input_stream, columns=columns, all_info=all_info + ) + + def _ensure_iterator(self) -> None: + """Lazily acquire the upstream iterator on first use.""" + if self._needs_iterator: + self._cached_input_iterator = self._input_stream.iter_packets() + self._needs_iterator = False + self._update_modified_time() + + def clear_cache(self) -> None: + """Discard all in-memory cached state.""" + self._cached_input_iterator = None + self._needs_iterator = True + self._cached_output_packets.clear() + self._cached_output_table = None + self._cached_content_hash_column = None + self._update_modified_time() + + def __iter__(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: + return self.iter_packets() + + def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: + if self.is_stale: + self.clear_cache() + self._ensure_iterator() + if self._cached_input_iterator is not None: + if _executor_supports_concurrent(self._function_pod.packet_function): + yield from self._iter_packets_concurrent() + else: + yield from self._iter_packets_sequential() + else: + # Yield from snapshot of complete cache + for i in range(len(self._cached_output_packets)): + tag, packet = self._cached_output_packets[i] + if packet is not None: + yield tag, packet + + def _iter_packets_sequential( + self, + ) -> Iterator[tuple[TagProtocol, PacketProtocol]]: + input_iter = self._cached_input_iterator + assert input_iter is not None + for i, (tag, packet) in enumerate(input_iter): + if i in self._cached_output_packets: + # Use cached result + tag, packet = self._cached_output_packets[i] + if packet is not None: + yield tag, packet + else: + # Process packet + tag, output_packet = self._function_pod.process_packet(tag, packet) + self._cached_output_packets[i] = (tag, output_packet) + if output_packet is not None: + yield tag, output_packet + + # Mark completion by releasing the iterator + self._cached_input_iterator = None + + def _iter_packets_concurrent( + self, + ) -> Iterator[tuple[TagProtocol, PacketProtocol]]: + """Collect remaining inputs, execute concurrently, and yield results in order.""" + input_iter = self._cached_input_iterator + assert input_iter is not None + + # Materialise remaining inputs and separate cached from uncached. + all_inputs: list[tuple[int, TagProtocol, PacketProtocol]] = [] + to_compute: list[tuple[int, TagProtocol, PacketProtocol]] = [] + for i, (tag, packet) in enumerate(input_iter): + all_inputs.append((i, tag, packet)) + if i not in self._cached_output_packets: + to_compute.append((i, tag, packet)) + self._cached_input_iterator = None + + # Submit uncached packets concurrently via async_process_packet. + if to_compute: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is not None: + # Already in event loop — fall back to sequential sync + results = [ + self._function_pod.process_packet(tag, pkt) + for _, tag, pkt in to_compute + ] + else: + + async def _gather() -> list[tuple[TagProtocol, PacketProtocol | None]]: + return list( + await asyncio.gather( + *[ + self._function_pod.async_process_packet(tag, pkt) + for _, tag, pkt in to_compute + ] + ) + ) + + results = asyncio.run(_gather()) + + for (i, _, _), (tag, output_packet) in zip(to_compute, results): + self._cached_output_packets[i] = (tag, output_packet) + + # Yield everything in original order. + for i, *_ in all_inputs: + tag, packet = self._cached_output_packets[i] + if packet is not None: + yield tag, packet + + def as_table( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> "pa.Table": + if self._cached_output_table is None: + all_tags = [] + all_packets = [] + tag_schema, packet_schema = None, None + for tag, packet in self.iter_packets(): + if tag_schema is None: + tag_schema = tag.arrow_schema(all_info=True) + if packet_schema is None: + packet_schema = packet.arrow_schema(all_info=True) + # TODO: make use of arrow_compat dict + all_tags.append(tag.as_dict(all_info=True)) + all_packets.append(packet.as_dict(all_info=True)) + + # TODO: re-verify the implemetation of this conversion + converter = self.data_context.type_converter + + struct_packets = converter.python_dicts_to_struct_dicts(all_packets) + all_tags_as_tables: pa.Table = pa.Table.from_pylist( + all_tags, schema=tag_schema + ) + # drop context key column from tags table (guard: column absent on empty stream) + if constants.CONTEXT_KEY in all_tags_as_tables.column_names: + all_tags_as_tables = all_tags_as_tables.drop([constants.CONTEXT_KEY]) + all_packets_as_tables: pa.Table = pa.Table.from_pylist( + struct_packets, schema=packet_schema + ) + + self._cached_output_table = arrow_utils.hstack_tables( + all_tags_as_tables, all_packets_as_tables + ) + assert self._cached_output_table is not None, ( + "_cached_output_table should not be None here." + ) + + column_config = ColumnConfig.handle_config(columns, all_info=all_info) + + drop_columns = [] + if not column_config.system_tags: + # TODO: get system tags more effiicently + drop_columns.extend( + [ + c + for c in self._cached_output_table.column_names + if c.startswith(constants.SYSTEM_TAG_PREFIX) + ] + ) + if not column_config.source: + drop_columns.extend(f"{constants.SOURCE_PREFIX}{c}" for c in self.keys()[1]) + if not column_config.context: + drop_columns.append(constants.CONTEXT_KEY) + + output_table = self._cached_output_table.drop( + [c for c in drop_columns if c in self._cached_output_table.column_names] + ) + + # lazily prepare content hash column if requested + if column_config.content_hash: + if self._cached_content_hash_column is None: + content_hashes = [] + # TODO: verify that order will be preserved + for tag, packet in self.iter_packets(): + content_hashes.append(packet.content_hash().to_string()) + self._cached_content_hash_column = pa.array( + content_hashes, type=pa.large_string() + ) + assert self._cached_content_hash_column is not None, ( + "_cached_content_hash_column should not be None here." + ) + hash_column_name = ( + "_content_hash" + if column_config.content_hash is True + else column_config.content_hash + ) + output_table = output_table.append_column( + hash_column_name, self._cached_content_hash_column + ) + + if column_config.sort_by_tags: + # TODO: reimplement using polars natively + output_table = ( + pl.DataFrame(output_table) + .sort(by=self.keys()[0], descending=False) + .to_arrow() + ) + # output_table = output_table.sort_by( + # [(column, "ascending") for column in self.keys()[0]] + # ) + return output_table + + +class CallableWithPod(Protocol): + @property + def pod(self) -> _FunctionPodBase: + """Return the associated function pod.""" + ... + + def __call__(self, *args, **kwargs): + """Call the underlying function.""" + ... + + +def function_pod( + output_keys: str | Sequence[str] | None = None, + function_name: str | None = None, + version: str = "v0.0", + label: str | None = None, + result_database: ArrowDatabaseProtocol | None = None, + pod_cache_database: ArrowDatabaseProtocol | None = None, + executor: PacketFunctionExecutorProtocol | None = None, + **kwargs, +) -> Callable[..., CallableWithPod]: + """Decorator that attaches a ``FunctionPod`` as a ``pod`` attribute. + + Args: + output_keys: Keys for the function output(s). + function_name: Name of the function pod; defaults to ``func.__name__``. + version: Version string for the packet function. + label: Optional label for tracking. + result_database: Optional database for packet-level caching + (wraps the packet function in ``CachedPacketFunction``). + pod_cache_database: Optional database for pod-level caching + (wraps the pod in ``CachedFunctionPod``, which caches at the + ``process_packet`` level using input packet content hash). + executor: Optional executor for running the packet function. + **kwargs: Forwarded to ``PythonPacketFunction``. + + Returns: + A decorator that adds a ``pod`` attribute to the wrapped function. + """ + + def decorator(func: Callable) -> CallableWithPod: + if func.__name__ == "": + raise ValueError("Lambda functions cannot be used with function_pod") + + packet_function = PythonPacketFunction( + func, + output_keys=output_keys, + function_name=function_name or func.__name__, + version=version, + label=label, + executor=executor, + **kwargs, + ) + + # if database is provided, wrap in CachedPacketFunction + if result_database is not None: + packet_function = CachedPacketFunction( + packet_function, + result_database=result_database, + ) + + # Create a simple typed function pod + pod: _FunctionPodBase = FunctionPod( + packet_function=packet_function, + ) + + # if pod_cache_database is provided, wrap in CachedFunctionPod + if pod_cache_database is not None: + from orcapod.core.cached_function_pod import CachedFunctionPod + + pod = CachedFunctionPod( + function_pod=pod, + result_database=pod_cache_database, + ) + + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + setattr(wrapper, "pod", pod) + return cast(CallableWithPod, wrapper) + + return decorator + + +class WrappedFunctionPod(_FunctionPodBase): + """Wrapper for a function pod, delegating call logic to the inner pod.""" + + def __init__( + self, + function_pod: FunctionPodProtocol, + data_context: str | contexts.DataContext | None = None, + **kwargs, + ) -> None: + # if data_context is not explicitly given, use that of the contained pod + if data_context is None: + data_context = function_pod.data_context_key + super().__init__( + packet_function=function_pod.packet_function, + data_context=data_context, + **kwargs, + ) + self._function_pod = function_pod + + def computed_label(self) -> str | None: + return self._function_pod.label + + @property + def uri(self) -> tuple[str, ...]: + return self._function_pod.uri + + def validate_inputs(self, *streams: StreamProtocol) -> None: + self._function_pod.validate_inputs(*streams) + + def argument_symmetry(self, streams: Collection[StreamProtocol]) -> ArgumentGroup: + return self._function_pod.argument_symmetry(streams) + + def output_schema( + self, + *streams: StreamProtocol, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: + return self._function_pod.output_schema( + *streams, columns=columns, all_info=all_info + ) + + # TODO: reconsider whether to return FunctionPodStream here in the signature + def process( + self, *streams: StreamProtocol, label: str | None = None + ) -> StreamProtocol: + return self._function_pod.process(*streams, label=label) diff --git a/src/orcapod/core/kernels.py b/src/orcapod/core/kernels.py deleted file mode 100644 index 52e1f8c0..00000000 --- a/src/orcapod/core/kernels.py +++ /dev/null @@ -1,241 +0,0 @@ -from abc import abstractmethod -from collections.abc import Collection -from datetime import datetime, timezone -from typing import Any -from orcapod.protocols import core_protocols as cp -import logging -from orcapod.core.streams import KernelStream -from orcapod.core.base import LabeledContentIdentifiableBase -from orcapod.core.trackers import DEFAULT_TRACKER_MANAGER -from orcapod.types import PythonSchema - -logger = logging.getLogger(__name__) - - -class TrackedKernelBase(LabeledContentIdentifiableBase): - """ - Kernel defines the fundamental unit of computation that can be performed on zero, one or more streams of data. - It is the base class for all computations and transformations that can be performed on a collection of streams - (including an empty collection). - A kernel is defined as a callable that takes a (possibly empty) collection of streams as the input - and returns a new stream as output (note that output stream is always singular). - Each "invocation" of the kernel on a collection of streams is assigned a unique ID. - The corresponding invocation information is stored as Invocation object and attached to the output stream - for computational graph tracking. - """ - - def __init__( - self, - label: str | None = None, - skip_tracking: bool = False, - tracker_manager: cp.TrackerManager | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self._label = label - - self._skip_tracking = skip_tracking - self._tracker_manager = tracker_manager or DEFAULT_TRACKER_MANAGER - self._last_modified = None - self._kernel_hash = None - self._set_modified_time() - - @property - def reference(self) -> tuple[str, ...]: - """ - Returns a unique identifier for the kernel. - This is used to identify the kernel in the computational graph. - """ - return ( - f"{self.__class__.__name__}", - self.content_hash().to_hex(), - ) - - @property - def last_modified(self) -> datetime | None: - """ - When the kernel was last modified. For most kernels, this is the timestamp - of the kernel creation. - """ - return self._last_modified - - # TODO: reconsider making this a public method - def _set_modified_time( - self, timestamp: datetime | None = None, invalidate: bool = False - ) -> None: - """ - Sets the last modified time of the kernel. - If `invalidate` is True, it resets the last modified time to None to indicate unstable state that'd signal downstream - to recompute when using the kernel. Othewrise, sets the last modified time to the current time or to the provided timestamp. - """ - if invalidate: - self._last_modified = None - return - - if timestamp is not None: - self._last_modified = timestamp - else: - self._last_modified = datetime.now(timezone.utc) - - @abstractmethod - def kernel_output_types( - self, *streams: cp.Stream, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - """ - Return the output types of the kernel given the input streams. - """ - ... - - def output_types( - self, *streams: cp.Stream, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - processed_streams = self.pre_kernel_processing(*streams) - self.validate_inputs(*processed_streams) - return self.kernel_output_types( - *processed_streams, include_system_tags=include_system_tags - ) - - @abstractmethod - def kernel_identity_structure( - self, streams: Collection[cp.Stream] | None = None - ) -> Any: - """ - Identity structure for this kernel. Input stream(s), if present, have already been preprocessed - and validated. - """ - ... - - def identity_structure(self, streams: Collection[cp.Stream] | None = None) -> Any: - """ - Default implementation of identity_structure for the kernel only - concerns the kernel class and the streams if present. Subclasses of - Kernels should override this method to provide a more meaningful - representation of the kernel. Note that kernel must provide the notion - of identity under possibly two distinct contexts: - 1) identity of the kernel in itself when invoked without any stream - 2) identity of the specific invocation of the kernel with a collection of streams - While the latter technically corresponds to the identity of the invocation and not - the kernel, only kernel can provide meaningful information as to the uniqueness of - the invocation as only kernel would know if / how the input stream(s) alter the identity - of the invocation. For example, if the kernel corresponds to an commutative computation - and therefore kernel K(x, y) == K(y, x), then the identity structure must reflect the - equivalence of the two by returning the same identity structure for both invocations. - This can be achieved, for example, by returning a set over the streams instead of a tuple. - """ - if streams is not None: - streams = self.pre_kernel_processing(*streams) - self.validate_inputs(*streams) - return self.kernel_identity_structure(streams) - - @abstractmethod - def forward(self, *streams: cp.Stream) -> cp.Stream: - """ - Trigger the main computation of the kernel on a collection of streams. - This method is called when the kernel is invoked with a collection of streams. - Subclasses should override this method to provide the kernel with its unique behavior - """ - - def pre_kernel_processing(self, *streams: cp.Stream) -> tuple[cp.Stream, ...]: - """ - Pre-processing step that can be overridden by subclasses to perform any necessary pre-processing - on the input streams before the main computation. This is useful if you need to modify the input streams - or perform any other operations before the main computation. Critically, any Kernel/Pod invocations in the - pre-processing step will be tracked outside of the computation in the kernel. - Default implementation is a no-op, returning the input streams unchanged. - """ - return streams - - @abstractmethod - def validate_inputs(self, *streams: cp.Stream) -> None: - """ - Validate the input streams before the main computation but after the pre-kernel processing - """ - ... - - def prepare_output_stream( - self, *streams: cp.Stream, label: str | None = None - ) -> KernelStream: - """ - Prepare the output stream for the kernel invocation. - This method is called after the main computation is performed. - It creates a KernelStream with the provided streams and label. - """ - return KernelStream(source=self, upstreams=streams, label=label) - - def track_invocation(self, *streams: cp.Stream, label: str | None = None) -> None: - """ - Track the invocation of the kernel with the provided streams. - This is a convenience method that calls record_kernel_invocation. - """ - if not self._skip_tracking and self._tracker_manager is not None: - self._tracker_manager.record_kernel_invocation(self, streams, label=label) - - def __call__( - self, *streams: cp.Stream, label: str | None = None, **kwargs - ) -> KernelStream: - processed_streams = self.pre_kernel_processing(*streams) - self.validate_inputs(*processed_streams) - output_stream = self.prepare_output_stream(*processed_streams, label=label) - self.track_invocation(*processed_streams, label=label) - return output_stream - - def __repr__(self): - return self.__class__.__name__ - - def __str__(self): - if self._label is not None: - return f"{self.__class__.__name__}({self._label})" - return self.__class__.__name__ - - -class WrappedKernel(TrackedKernelBase): - """ - A wrapper for a kernels useful when you want to use an existing kernel - but need to provide some extra functionality. - - Default implementation provides a simple passthrough to the wrapped kernel. - If you want to provide a custom behavior, be sure to override the methods - that you want to change. Note that the wrapped kernel must implement the - `Kernel` protocol. Refer to `orcapod.protocols.data_protocols.Kernel` for more details. - """ - - def __init__(self, kernel: cp.Kernel, **kwargs) -> None: - # TODO: handle fixed input stream already set on the kernel - super().__init__(**kwargs) - self.kernel = kernel - - def computed_label(self) -> str | None: - """ - Compute a label for this kernel based on its content. - If label is not explicitly set for this kernel and computed_label returns a valid value, - it will be used as label of this kernel. - """ - return self.kernel.label - - @property - def reference(self) -> tuple[str, ...]: - return self.kernel.reference - - def kernel_output_types( - self, *streams: cp.Stream, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - return self.kernel.output_types( - *streams, include_system_tags=include_system_tags - ) - - def kernel_identity_structure( - self, streams: Collection[cp.Stream] | None = None - ) -> Any: - return self.kernel.identity_structure(streams) - - def validate_inputs(self, *streams: cp.Stream) -> None: - return self.kernel.validate_inputs(*streams) - - def forward(self, *streams: cp.Stream) -> cp.Stream: - return self.kernel.forward(*streams) - - def __repr__(self): - return f"WrappedKernel({self.kernel!r})" - - def __str__(self): - return f"WrappedKernel:{self.kernel!s}" diff --git a/src/orcapod/core/nodes/__init__.py b/src/orcapod/core/nodes/__init__.py new file mode 100644 index 00000000..5d2ef1ee --- /dev/null +++ b/src/orcapod/core/nodes/__init__.py @@ -0,0 +1,14 @@ +from typing import TypeAlias + +from .function_node import FunctionNode +from .operator_node import OperatorNode +from .source_node import SourceNode + +GraphNode: TypeAlias = SourceNode | FunctionNode | OperatorNode + +__all__ = [ + "FunctionNode", + "GraphNode", + "OperatorNode", + "SourceNode", +] diff --git a/src/orcapod/core/nodes/function_node.py b/src/orcapod/core/nodes/function_node.py new file mode 100644 index 00000000..630dac05 --- /dev/null +++ b/src/orcapod/core/nodes/function_node.py @@ -0,0 +1,1355 @@ +"""FunctionNode — stream node for packet function invocations with optional DB persistence.""" + +from __future__ import annotations + +import asyncio +import logging +from collections.abc import Iterator +from typing import TYPE_CHECKING, Any, cast + +from orcapod import contexts +from orcapod.channels import ReadableChannel, WritableChannel +from orcapod.config import Config +from orcapod.core.cached_function_pod import CachedFunctionPod +from orcapod.core.streams.arrow_table_stream import ArrowTableStream +from orcapod.core.streams.base import StreamBase +from orcapod.core.tracker import DEFAULT_TRACKER_MANAGER +from orcapod.protocols.core_protocols import ( + FunctionPodProtocol, + PacketFunctionExecutorProtocol, + PacketFunctionProtocol, + PacketProtocol, + StreamProtocol, + TagProtocol, + TrackerManagerProtocol, +) +from orcapod.protocols.database_protocols import ArrowDatabaseProtocol +from orcapod.system_constants import constants +from orcapod.types import ( + ColumnConfig, + ContentHash, + Schema, +) +from orcapod.utils import arrow_utils, schema_utils +from orcapod.utils.lazy_module import LazyModule + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + import polars as pl + import pyarrow as pa + + from orcapod.pipeline.logging_capture import CapturedLogs + from orcapod.protocols.observability_protocols import ExecutionObserverProtocol +else: + pa = LazyModule("pyarrow") + pl = LazyModule("polars") + + +def _executor_supports_concurrent( + packet_function: PacketFunctionProtocol, +) -> bool: + """Return True if the packet function's executor supports concurrent execution.""" + executor = packet_function.executor + return executor is not None and executor.supports_concurrent_execution + + +class FunctionNode(StreamBase): + """Stream node representing a packet function invocation with optional DB persistence. + + When constructed without database parameters, provides the core stream + interface (identity, schema, iteration) without any persistence. When + databases are provided (either at construction or via ``attach_databases``), + adds result caching via ``CachedFunctionPod``, pipeline record storage, + and two-phase iteration (cached first, then compute missing). + """ + + node_type = "function" + + def __init__( + self, + function_pod: FunctionPodProtocol, + input_stream: StreamProtocol, + tracker_manager: TrackerManagerProtocol | None = None, + label: str | None = None, + config: Config | None = None, + # Optional DB params for persistent mode: + pipeline_database: ArrowDatabaseProtocol | None = None, + result_database: ArrowDatabaseProtocol | None = None, + result_path_prefix: tuple[str, ...] | None = None, + pipeline_path_prefix: tuple[str, ...] = (), + ): + if tracker_manager is None: + tracker_manager = DEFAULT_TRACKER_MANAGER + self.tracker_manager = tracker_manager + self._packet_function = function_pod.packet_function + + # FunctionPod used for the `producer` property and pipeline identity + self._function_pod = function_pod + super().__init__(label=label, config=config) + + # validate the input stream + _, incoming_packet_types = input_stream.output_schema() + expected_packet_schema = self._packet_function.input_packet_schema + if not schema_utils.check_schema_compatibility( + incoming_packet_types, expected_packet_schema + ): + raise ValueError( + f"Incoming packet data type {incoming_packet_types} from {input_stream} " + f"is not compatible with expected input schema {expected_packet_schema}" + ) + + self._input_stream = input_stream + + # stream-level caching state (iterator acquired lazily on first use) + self._cached_input_iterator: ( + Iterator[tuple[TagProtocol, PacketProtocol]] | None + ) = None + self._needs_iterator = True + self._cached_output_packets: dict[ + int, tuple[TagProtocol, PacketProtocol | None] + ] = {} + self._cached_output_table: pa.Table | None = None + self._cached_content_hash_column: pa.Array | None = None + + # DB persistence state (initially None; set via __init__ params or attach_databases) + self._pipeline_database: ArrowDatabaseProtocol | None = None + self._cached_function_pod: CachedFunctionPod | None = None + self._pipeline_path_prefix: tuple[str, ...] = () + self._pipeline_node_hash: str | None = None + self._output_schema_hash: str | None = None + + if pipeline_database is not None: + self.attach_databases( + pipeline_database=pipeline_database, + result_database=result_database, + result_path_prefix=result_path_prefix, + pipeline_path_prefix=pipeline_path_prefix, + ) + + # ------------------------------------------------------------------ + # attach_databases + # ------------------------------------------------------------------ + + def attach_databases( + self, + pipeline_database: ArrowDatabaseProtocol, + result_database: ArrowDatabaseProtocol | None = None, + result_path_prefix: tuple[str, ...] | None = None, + pipeline_path_prefix: tuple[str, ...] = (), + ) -> None: + """Attach databases for persistent caching and pipeline records. + + Creates a ``CachedFunctionPod`` wrapping the original function pod + for result caching. The pipeline database is used separately for + pipeline-level provenance records (tag + packet hash). + + Args: + pipeline_database: Database for pipeline records. + result_database: Database for cached results. Defaults to + pipeline_database. + result_path_prefix: Path prefix for result records. + pipeline_path_prefix: Path prefix for pipeline records. + """ + computed_result_path_prefix: tuple[str, ...] = () + if result_database is None: + result_database = pipeline_database + computed_result_path_prefix = ( + result_path_prefix + if result_path_prefix is not None + else pipeline_path_prefix + ("_result",) + ) + elif result_path_prefix is not None: + computed_result_path_prefix = result_path_prefix + + # Always wrap the original function_pod (not a previous cached wrapper) + self._cached_function_pod = CachedFunctionPod( + self._function_pod, + result_database=result_database, + record_path_prefix=computed_result_path_prefix, + ) + + self._pipeline_database = pipeline_database + self._pipeline_path_prefix = pipeline_path_prefix + + # Clear all caches + self.clear_cache() + self._content_hash_cache.clear() + self._pipeline_hash_cache.clear() + + # Compute pipeline node hash + self._pipeline_node_hash = self.pipeline_hash().to_string() + self._output_schema_hash = self.data_context.semantic_hasher.hash_object( + self._packet_function.output_packet_schema + ).to_string() + + # ------------------------------------------------------------------ + # from_descriptor — reconstruct from a serialized pipeline descriptor + # ------------------------------------------------------------------ + + @classmethod + def from_descriptor( + cls, + descriptor: dict[str, Any], + function_pod: FunctionPodProtocol | None, + input_stream: StreamProtocol | None, + databases: dict[str, Any], + ) -> "FunctionNode": + """Construct a FunctionNode from a serialized descriptor. + + When *function_pod* and *input_stream* are both provided the node + operates in full mode -- constructed normally via ``__init__``. + When *function_pod* is ``None`` the node is created in read-only + mode with metadata from the descriptor; computation methods will + raise ``RuntimeError``. + + Args: + descriptor: The serialized node descriptor dict. + function_pod: An optional live function pod. ``None`` for + read-only mode. + input_stream: An optional live input stream. ``None`` for + read-only mode. + databases: Mapping of database role names (``"pipeline"``, + ``"result"``) to database instances. + + Returns: + A new ``FunctionNode`` instance. + """ + from orcapod.pipeline.serialization import LoadStatus + + pipeline_db = databases.get("pipeline") + result_db = databases.get("result", pipeline_db) + + if function_pod is not None and input_stream is not None: + # Full mode: construct normally + pipeline_path = tuple(descriptor.get("pipeline_path", ())) + # Derive pipeline_path_prefix by stripping the suffix that + # __init__ appends (packet_function.uri + node hash element). + # We pass the full pipeline_path_prefix from the descriptor. + # The descriptor stores the complete pipeline_path; we need + # to reconstruct the prefix that was originally passed to + # __init__. The suffix added is: pf.uri + (f"node:{hash}",). + # Instead of reverse-engineering, use the descriptor's path + # minus what __init__ will add. For full mode we let __init__ + # recompute pipeline_path from the prefix. + pf_uri_len = len(function_pod.packet_function.uri) + 1 # +1 for node:hash + prefix = ( + pipeline_path[:-pf_uri_len] if len(pipeline_path) > pf_uri_len else () + ) + + node = cls( + function_pod=function_pod, + input_stream=input_stream, + pipeline_database=pipeline_db, + result_database=result_db, + pipeline_path_prefix=prefix, + label=descriptor.get("label"), + ) + node._descriptor = descriptor + node._load_status = LoadStatus.FULL + return node + + # Read-only mode: bypass __init__, set minimum required state + node = cls.__new__(cls) + + # From LabelableMixin + node._label = descriptor.get("label") + + # From DataContextMixin + node._data_context = contexts.resolve_context( + descriptor.get("data_context_key") + ) + from orcapod.config import DEFAULT_CONFIG + + node._orcapod_config = DEFAULT_CONFIG + + # From ContentIdentifiableBase + node._content_hash_cache = {} + node._cached_int_hash = None + + # From PipelineElementBase + node._pipeline_hash_cache = {} + + # From TemporalMixin + node._modified_time = None + + # From FunctionNode + node._function_pod = None + node._packet_function = None + node._input_stream = None + node.tracker_manager = DEFAULT_TRACKER_MANAGER + node._cached_input_iterator = None + node._needs_iterator = True + node._cached_output_packets = {} + node._cached_output_table = None + node._cached_content_hash_column = None + + # DB persistence state + node._pipeline_database = pipeline_db + node._cached_function_pod = None + node._pipeline_path_prefix = () + node._pipeline_node_hash = None + node._output_schema_hash = None + + # Descriptor metadata for read-only access + node._descriptor = descriptor + node._stored_schema = descriptor.get("output_schema", {}) + node._stored_content_hash = descriptor.get("content_hash") + node._stored_pipeline_hash = descriptor.get("pipeline_hash") + node._stored_pipeline_path = tuple(descriptor.get("pipeline_path", ())) + node._stored_result_record_path = tuple( + descriptor.get("result_record_path", ()) + ) + + # Determine load status based on DB availability + node._load_status = LoadStatus.UNAVAILABLE + if pipeline_db is not None: + node._load_status = LoadStatus.READ_ONLY + + return node + + # ------------------------------------------------------------------ + # load_status + # ------------------------------------------------------------------ + + @property + def load_status(self) -> Any: + """Return the load status of this node. + + Returns: + The ``LoadStatus`` enum value indicating how this node was + loaded. Defaults to ``FULL`` for nodes created via + ``__init__``. + """ + from orcapod.pipeline.serialization import LoadStatus + + return getattr(self, "_load_status", LoadStatus.FULL) + + # ------------------------------------------------------------------ + # Core properties + # ------------------------------------------------------------------ + + @property + def producer(self) -> FunctionPodProtocol: + return self._function_pod + + @property + def data_context(self) -> contexts.DataContext: + return contexts.resolve_context(self._function_pod.data_context_key) + + @property + def data_context_key(self) -> str: + return self._function_pod.data_context_key + + @property + def executor(self) -> PacketFunctionExecutorProtocol | None: + """The executor set on the underlying packet function.""" + return self._packet_function.executor + + @executor.setter + def executor(self, executor: PacketFunctionExecutorProtocol | None) -> None: + """Set or clear the executor on the underlying packet function.""" + self._packet_function.executor = executor + + @property + def upstreams(self) -> tuple[StreamProtocol, ...]: + return (self._input_stream,) + + @upstreams.setter + def upstreams(self, value: tuple[StreamProtocol, ...]) -> None: + if len(value) != 1: + raise ValueError("FunctionPod can only have one upstream") + self._input_stream = value[0] + + # ------------------------------------------------------------------ + # Read-only overrides (for deserialized nodes without live function_pod) + # ------------------------------------------------------------------ + + def content_hash(self, hasher=None) -> ContentHash: + """Return the content hash, using stored value in read-only mode.""" + stored = getattr(self, "_stored_content_hash", None) + if self._function_pod is None and stored is not None: + from orcapod.types import ContentHash as CH + + return CH.from_string(stored) + return super().content_hash(hasher) + + def pipeline_hash(self, hasher=None) -> ContentHash: + """Return the pipeline hash, using stored value in read-only mode.""" + stored = getattr(self, "_stored_pipeline_hash", None) + if self._function_pod is None and stored is not None: + from orcapod.types import ContentHash as CH + + return CH.from_string(stored) + return super().pipeline_hash(hasher) + + def output_schema( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: + """Return output schema, using stored value in read-only mode.""" + if self._function_pod is None: + stored = getattr(self, "_stored_schema", {}) + tag = Schema(stored.get("tag", {})) + packet = Schema(stored.get("packet", {})) + return tag, packet + tag_schema = self._input_stream.output_schema( + columns=columns, all_info=all_info + )[0] + return tag_schema, self._packet_function.output_packet_schema + + def keys( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[tuple[str, ...], tuple[str, ...]]: + if self._function_pod is None: + stored = getattr(self, "_stored_schema", {}) + tag_keys = tuple(stored.get("tag", {}).keys()) + packet_keys = tuple(stored.get("packet", {}).keys()) + return tag_keys, packet_keys + tag_schema, packet_schema = self.output_schema( + columns=columns, all_info=all_info + ) + return tuple(tag_schema.keys()), tuple(packet_schema.keys()) + + # ------------------------------------------------------------------ + # Pipeline path + # ------------------------------------------------------------------ + + @property + def pipeline_path(self) -> tuple[str, ...]: + """Return the pipeline path for DB record scoping. + + Raises: + RuntimeError: If no database is attached and this is not a + read-only deserialized node. + """ + stored = getattr(self, "_stored_pipeline_path", None) + if self._packet_function is None and stored is not None: + return stored + if self._pipeline_database is None: + raise RuntimeError( + "Cannot compute pipeline_path without an attached database. " + "Call attach_databases() first." + ) + return ( + self._pipeline_path_prefix + + self._packet_function.uri + + (f"node:{self._pipeline_node_hash}",) + ) + + # ------------------------------------------------------------------ + # Caching + # ------------------------------------------------------------------ + + def _ensure_iterator(self) -> None: + """Lazily acquire the upstream iterator on first use.""" + if self._needs_iterator: + self._cached_input_iterator = self._input_stream.iter_packets() + self._needs_iterator = False + self._update_modified_time() + + def clear_cache(self) -> None: + self._cached_input_iterator = None + self._needs_iterator = True + self._cached_output_packets.clear() + self._cached_output_table = None + self._cached_content_hash_column = None + self._update_modified_time() + + # ------------------------------------------------------------------ + # Packet processing + # ------------------------------------------------------------------ + + def execute_packet( + self, + tag: TagProtocol, + packet: PacketProtocol, + ) -> tuple[TagProtocol, PacketProtocol | None]: + """Execute a single packet: compute, persist, and cache. + + Internal method for orchestrators. The caller must guarantee that + the tag and packet conform to the expected input schema (matching + ``self._input_stream``). No validation is performed. + + Args: + tag: The tag associated with the packet. + packet: The input packet to process. + + Returns: + A ``(tag, output_packet)`` tuple. CapturedLogs are discarded. + """ + tag_out, result, _captured = self._process_packet_internal(tag, packet) + return tag_out, result + + def execute( + self, + input_stream: StreamProtocol, + *, + observer: "ExecutionObserverProtocol | None" = None, + error_policy: str = "continue", + ) -> list[tuple[TagProtocol, PacketProtocol]]: + """Execute all packets from a stream: compute, persist, and cache. + + Args: + input_stream: The input stream to process. + observer: Optional execution observer for hooks. + error_policy: ``"continue"`` (default) skips failed packets; + ``"fail_fast"`` re-raises on the first failure. + + Returns: + Materialized list of (tag, output_packet) pairs, excluding + None outputs and failed packets. + """ + from orcapod.pipeline.observer import _NOOP_LOGGER + + if observer is not None: + observer.on_node_start(self) + + # Gather entry IDs and check cache + upstream_entries = [ + (tag, packet, self.compute_pipeline_entry_id(tag, packet)) + for tag, packet in input_stream.iter_packets() + ] + entry_ids = [eid for _, _, eid in upstream_entries] + cached = self.get_cached_results(entry_ids=entry_ids) + + pp = self.pipeline_path if self._pipeline_database is not None else () + + output: list[tuple[TagProtocol, PacketProtocol]] = [] + for tag, packet, entry_id in upstream_entries: + if observer is not None: + observer.on_packet_start(self, tag, packet) + pkt_logger = observer.create_packet_logger( + self, tag, packet, pipeline_path=pp + ) + else: + pkt_logger = _NOOP_LOGGER + + if entry_id in cached: + tag_out, result = cached[entry_id] + if observer is not None: + observer.on_packet_end(self, tag, packet, result, cached=True) + output.append((tag_out, result)) + else: + tag_out, result, captured = self._process_packet_internal(tag, packet) + pkt_logger.record(captured) + if not captured.success: + if observer is not None: + observer.on_packet_crash( + self, + tag, + packet, + RuntimeError( + captured.traceback or "packet function failed" + ), + ) + if error_policy == "fail_fast": + if observer is not None: + observer.on_node_end(self) + raise RuntimeError( + captured.traceback or "packet function failed" + ) + else: + if observer is not None: + observer.on_packet_end( + self, tag, packet, result, cached=False + ) + if result is not None: + output.append((tag_out, result)) + + if observer is not None: + observer.on_node_end(self) + return output + + def _process_packet_internal( + self, + tag: TagProtocol, + packet: PacketProtocol, + cache_index: int | None = None, + ) -> "tuple[TagProtocol, PacketProtocol | None, CapturedLogs]": + """Core compute + persist + cache. + + Used by ``execute_packet``, ``execute``, and ``iter_packets``. + No input validation is performed — the caller guarantees correctness. + + Returns: + A ``(tag, output_packet, captured_logs)`` 3-tuple. + + Args: + tag: The input tag. + packet: The input packet. + cache_index: Optional explicit index for the internal cache. + When ``None``, auto-assigns at ``len(_cached_output_packets)``. + """ + if self._cached_function_pod is not None: + tag_out, output_packet, captured = ( + self._cached_function_pod.process_packet_with_capture(tag, packet) + ) + + if output_packet is not None and captured.success: + result_computed = bool( + output_packet.get_meta_value( + self._cached_function_pod.RESULT_COMPUTED_FLAG, False + ) + ) + self.add_pipeline_record( + tag, + packet, + packet_record_id=output_packet.datagram_id, + computed=result_computed, + ) + else: + tag_out, output_packet, captured = ( + self._function_pod.process_packet_with_capture(tag, packet) + ) + + # Cache internally and invalidate derived caches + idx = ( + cache_index if cache_index is not None else len(self._cached_output_packets) + ) + self._cached_output_packets[idx] = (tag_out, output_packet) + self._cached_input_iterator = None + self._needs_iterator = False + self._cached_output_table = None + self._cached_content_hash_column = None + + return tag_out, output_packet, captured + + def get_cached_results( + self, entry_ids: list[str] + ) -> dict[str, tuple[TagProtocol, PacketProtocol]]: + """Retrieve cached results for specific pipeline entry IDs. + + Looks up the pipeline DB and result DB, joins them, and filters + to the requested entry IDs. Returns a mapping from entry ID to + (tag, output_packet). + + Args: + entry_ids: Pipeline entry IDs to look up. + + Returns: + Mapping from entry_id to (tag, output_packet) for found entries. + Empty dict if no DB is attached or no matches found. + """ + if self._cached_function_pod is None or not entry_ids: + return {} + + PIPELINE_ENTRY_ID_COL = "__pipeline_entry_id" + entry_id_set = set(entry_ids) + + taginfo = self._pipeline_database.get_all_records( + self.pipeline_path, + record_id_column=PIPELINE_ENTRY_ID_COL, + ) + results = self._cached_function_pod._result_database.get_all_records( + self._cached_function_pod.record_path, + record_id_column=constants.PACKET_RECORD_ID, + ) + + if taginfo is None or results is None: + return {} + + joined = ( + pl.DataFrame(taginfo) + .join( + pl.DataFrame(results), + on=constants.PACKET_RECORD_ID, + how="inner", + ) + .to_arrow() + ) + + if joined.num_rows == 0: + return {} + + # Filter to requested entry IDs + all_entry_ids = joined.column(PIPELINE_ENTRY_ID_COL).to_pylist() + mask = [eid in entry_id_set for eid in all_entry_ids] + filtered = joined.filter(pa.array(mask)) + + if filtered.num_rows == 0: + return {} + + tag_keys = self._input_stream.keys()[0] + drop_cols = [ + c + for c in filtered.column_names + if c.startswith(constants.META_PREFIX) or c == PIPELINE_ENTRY_ID_COL + ] + data_table = filtered.drop([c for c in drop_cols if c in filtered.column_names]) + + stream = ArrowTableStream(data_table, tag_columns=tag_keys) + filtered_entry_ids = [eid for eid, m in zip(all_entry_ids, mask) if m] + + result_dict: dict[str, tuple[TagProtocol, PacketProtocol]] = {} + for entry_id, (tag, packet) in zip(filtered_entry_ids, stream.iter_packets()): + result_dict[entry_id] = (tag, packet) + + # Populate internal cache with retrieved results (clear first to + # avoid duplicates on repeated orchestrator runs) + self._cached_output_packets.clear() + self._cached_output_table = None + self._cached_content_hash_column = None + for entry_id, (tag, packet) in result_dict.items(): + next_idx = len(self._cached_output_packets) + self._cached_output_packets[next_idx] = (tag, packet) + self._cached_input_iterator = None + self._needs_iterator = False + + return result_dict + + async def _async_process_packet_internal( + self, + tag: TagProtocol, + packet: PacketProtocol, + cache_index: int | None = None, + ) -> "tuple[TagProtocol, PacketProtocol | None, CapturedLogs]": + """Async counterpart of ``_process_packet_internal``. + + Computes via async path, writes pipeline provenance, and caches + internally — no schema validation. + + Returns: + A ``(tag, output_packet, captured_logs)`` 3-tuple. + + Args: + tag: The input tag. + packet: The input packet. + cache_index: Optional explicit index for the internal cache. + When ``None``, auto-assigns at ``len(_cached_output_packets)``. + """ + if self._cached_function_pod is not None: + tag_out, output_packet, captured = ( + await self._cached_function_pod.async_process_packet_with_capture( + tag, packet + ) + ) + + if output_packet is not None and captured.success: + result_computed = bool( + output_packet.get_meta_value( + self._cached_function_pod.RESULT_COMPUTED_FLAG, False + ) + ) + self.add_pipeline_record( + tag, + packet, + packet_record_id=output_packet.datagram_id, + computed=result_computed, + ) + else: + tag_out, output_packet, captured = ( + await self._function_pod.async_process_packet_with_capture(tag, packet) + ) + + # Cache internally and invalidate derived caches + idx = ( + cache_index if cache_index is not None else len(self._cached_output_packets) + ) + self._cached_output_packets[idx] = (tag_out, output_packet) + self._cached_input_iterator = None + self._needs_iterator = False + self._cached_output_table = None + self._cached_content_hash_column = None + + return tag_out, output_packet, captured + + def compute_pipeline_entry_id( + self, tag: TagProtocol, input_packet: PacketProtocol + ) -> str: + """Compute a unique pipeline entry ID from tag + system tags + input packet hash. + + This ID uniquely identifies a (tag, system_tags, input_packet) combination + and is used as the record ID in the pipeline database. + + Args: + tag: The tag (including system tags). + input_packet: The input packet. + + Returns: + A hash string uniquely identifying this combination. + """ + tag_with_hash = tag.as_table(columns={"system_tags": True}).append_column( + constants.INPUT_PACKET_HASH_COL, + pa.array([input_packet.content_hash().to_string()], type=pa.large_string()), + ) + return self.data_context.arrow_hasher.hash_table(tag_with_hash).to_string() + + def add_pipeline_record( + self, + tag: TagProtocol, + input_packet: PacketProtocol, + packet_record_id: str, + computed: bool, + skip_cache_lookup: bool = False, + ) -> None: + """Add a pipeline record to the database for a processed packet. + + The pipeline record stores: + - Tag columns (including system tags) + - All source columns of the input packet (provenance, not data) + - Output packet record ID (for joining with result records) + - Input packet data context key + - Whether the result was freshly computed or cached + """ + entry_id = self.compute_pipeline_entry_id(tag, input_packet) + + # Check for existing entry + existing_record = None + if not skip_cache_lookup: + existing_record = self._pipeline_database.get_record_by_id( + self.pipeline_path, + entry_id, + ) + + if existing_record is not None: + logger.debug( + f"Record with entry_id {entry_id} already exists. Skipping addition." + ) + return + + # Extract source columns only (no data columns) from the input packet + input_table_with_source = input_packet.as_table(columns={"source": True}) + source_col_names = [ + c + for c in input_table_with_source.column_names + if c.startswith(constants.SOURCE_PREFIX) + ] + input_source_table = input_table_with_source.select(source_col_names) + + # Build the meta columns table + meta_table = pa.table( + { + constants.PACKET_RECORD_ID: pa.array( + [packet_record_id], type=pa.large_string() + ), + f"{constants.META_PREFIX}input_packet{constants.CONTEXT_KEY}": pa.array( + [input_packet.data_context_key], type=pa.large_string() + ), + f"{constants.META_PREFIX}computed": pa.array( + [computed], type=pa.bool_() + ), + } + ) + + # Combine: tag (with system tags) + input source columns + meta columns + combined_record = arrow_utils.hstack_tables( + tag.as_table(columns={"system_tags": True}), + input_source_table, + meta_table, + ) + + self._pipeline_database.add_record( + self.pipeline_path, + entry_id, + combined_record, + skip_duplicates=False, + ) + + # ------------------------------------------------------------------ + # Records and sources + # ------------------------------------------------------------------ + + def get_all_records( + self, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> "pa.Table | None": + """Return all computed results joined with their pipeline tag records. + + Args: + columns: Column configuration controlling which groups are included. + all_info: Shorthand to include all info columns. + + Returns: + A PyArrow table of joined results, or ``None`` if no database is + attached or no records exist. + """ + if self._cached_function_pod is None: + return None + + results = self._cached_function_pod._result_database.get_all_records( + self._cached_function_pod.record_path, + record_id_column=constants.PACKET_RECORD_ID, + ) + taginfo = self._pipeline_database.get_all_records(self.pipeline_path) + + if results is None or taginfo is None: + return None + + joined = ( + pl.DataFrame(taginfo) + .join(pl.DataFrame(results), on=constants.PACKET_RECORD_ID, how="inner") + .to_arrow() + ) + + column_config = ColumnConfig.handle_config(columns, all_info=all_info) + + drop_columns = [] + if not column_config.meta and not column_config.all_info: + drop_columns.extend( + c for c in joined.column_names if c.startswith(constants.META_PREFIX) + ) + if not column_config.source and not column_config.all_info: + drop_columns.extend( + c for c in joined.column_names if c.startswith(constants.SOURCE_PREFIX) + ) + if not column_config.system_tags and not column_config.all_info: + drop_columns.extend( + c + for c in joined.column_names + if c.startswith(constants.SYSTEM_TAG_PREFIX) + ) + if drop_columns: + joined = joined.drop([c for c in drop_columns if c in joined.column_names]) + + return joined if joined.num_rows > 0 else None + + def as_source(self): + """Return a DerivedSource backed by the DB records of this node. + + Raises: + RuntimeError: If no database is attached. + """ + if self._pipeline_database is None: + raise RuntimeError("Cannot create a DerivedSource without a database") + + from orcapod.core.sources.derived_source import DerivedSource + + path_str = "/".join(self.pipeline_path) + content_frag = self.content_hash().to_string()[:16] + source_id = f"{path_str}:{content_frag}" + return DerivedSource( + origin=self, + source_id=source_id, + data_context=self.data_context_key, + config=self.orcapod_config, + ) + + # ------------------------------------------------------------------ + # Iteration + # ------------------------------------------------------------------ + + def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: + if self.is_stale: + self.clear_cache() + self._ensure_iterator() + + if self._cached_function_pod is not None: + # Two-phase iteration with DB backing + if self._cached_input_iterator is not None: + input_iter = self._cached_input_iterator + # --- Phase 1: yield already-computed results from the databases --- + # Retrieve pipeline records with their entry_ids (record IDs) + # and join with result records to reconstruct (tag, output_packet). + PIPELINE_ENTRY_ID_COL = "__pipeline_entry_id" + existing_entry_ids: set[str] = set() + + taginfo = self._pipeline_database.get_all_records( + self.pipeline_path, + record_id_column=PIPELINE_ENTRY_ID_COL, + ) + results = self._cached_function_pod._result_database.get_all_records( + self._cached_function_pod.record_path, + record_id_column=constants.PACKET_RECORD_ID, + ) + + if taginfo is not None and results is not None: + joined = ( + pl.DataFrame(taginfo) + .join( + pl.DataFrame(results), + on=constants.PACKET_RECORD_ID, + how="inner", + ) + .to_arrow() + ) + if joined.num_rows > 0: + tag_keys = self._input_stream.keys()[0] + # Collect pipeline entry_ids for Phase 2 skip check + existing_entry_ids = set( + cast( + list[str], + joined.column(PIPELINE_ENTRY_ID_COL).to_pylist(), + ) + ) + # Drop internal columns before yielding as stream + drop_cols = [ + c + for c in joined.column_names + if c.startswith(constants.META_PREFIX) + or c == PIPELINE_ENTRY_ID_COL + ] + data_table = joined.drop( + [c for c in drop_cols if c in joined.column_names] + ) + existing_stream = ArrowTableStream( + data_table, tag_columns=tag_keys + ) + for i, (tag, packet) in enumerate( + existing_stream.iter_packets() + ): + self._cached_output_packets[i] = (tag, packet) + yield tag, packet + + # --- Phase 2: process only missing input packets --- + # Skip inputs whose pipeline entry_id (tag+system_tags+packet_hash) + # already exists in the pipeline database. + for tag, packet in input_iter: + entry_id = self.compute_pipeline_entry_id(tag, packet) + if entry_id in existing_entry_ids: + continue + tag, output_packet, _captured = self._process_packet_internal( + tag, packet + ) + if output_packet is not None: + yield tag, output_packet + + self._cached_input_iterator = None + else: + # Yield from snapshot of complete cache + for i in range(len(self._cached_output_packets)): + tag, packet = self._cached_output_packets[i] + if packet is not None: + yield tag, packet + else: + # Simple iteration without DB + if self._cached_input_iterator is not None: + if _executor_supports_concurrent(self._packet_function): + yield from self._iter_packets_concurrent( + self._cached_input_iterator + ) + else: + yield from self._iter_packets_sequential( + self._cached_input_iterator + ) + else: + for i in range(len(self._cached_output_packets)): + tag, packet = self._cached_output_packets[i] + if packet is not None: + yield tag, packet + + def _iter_packets_sequential( + self, input_iter: Iterator[tuple[TagProtocol, PacketProtocol]] + ) -> Iterator[tuple[TagProtocol, PacketProtocol]]: + for i, (tag, packet) in enumerate(input_iter): + if i in self._cached_output_packets: + tag, packet = self._cached_output_packets[i] + if packet is not None: + yield tag, packet + else: + tag, output_packet, _captured = self._process_packet_internal( + tag, packet + ) + if output_packet is not None: + yield tag, output_packet + self._cached_input_iterator = None + + def _iter_packets_concurrent( + self, + input_iter: Iterator[tuple[TagProtocol, PacketProtocol]], + ) -> Iterator[tuple[TagProtocol, PacketProtocol]]: + """Collect remaining inputs, execute concurrently, and yield results in order.""" + + all_inputs: list[tuple[int, TagProtocol, PacketProtocol]] = [] + to_compute: list[tuple[int, TagProtocol, PacketProtocol]] = [] + for i, (tag, packet) in enumerate(input_iter): + all_inputs.append((i, tag, packet)) + if i not in self._cached_output_packets: + to_compute.append((i, tag, packet)) + self._cached_input_iterator = None + + if to_compute: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + if loop is not None: + # Already in event loop — fall back to sequential sync + for i, tag, pkt in to_compute: + self._process_packet_internal(tag, pkt, cache_index=i) + else: + + async def _gather() -> list[tuple[TagProtocol, PacketProtocol | None]]: + return list( + await asyncio.gather( + *[ + self._async_process_packet_internal( + tag, pkt, cache_index=i + ) + for i, tag, pkt in to_compute + ] + ) + ) + + asyncio.run(_gather()) + + # Yield all results in order from internal cache + for idx in sorted(self._cached_output_packets.keys()): + tag, packet = self._cached_output_packets[idx] + if packet is not None: + yield tag, packet + + def run(self) -> None: + """Eagerly process all input packets, filling the pipeline and result databases.""" + for _ in self.iter_packets(): + pass + + # ------------------------------------------------------------------ + # as_table + # ------------------------------------------------------------------ + + def as_table( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> "pa.Table": + if self._cached_output_table is None: + all_tags = [] + all_packets = [] + tag_schema, packet_schema = None, None + for tag, packet in self.iter_packets(): + if tag_schema is None: + tag_schema = tag.arrow_schema(all_info=True) + if packet_schema is None: + packet_schema = packet.arrow_schema(all_info=True) + all_tags.append(tag.as_dict(all_info=True)) + all_packets.append(packet.as_dict(all_info=True)) + + converter = self.data_context.type_converter + + struct_packets = converter.python_dicts_to_struct_dicts(all_packets) + all_tags_as_tables: pa.Table = pa.Table.from_pylist( + all_tags, schema=tag_schema + ) + if constants.CONTEXT_KEY in all_tags_as_tables.column_names: + all_tags_as_tables = all_tags_as_tables.drop([constants.CONTEXT_KEY]) + all_packets_as_tables: pa.Table = pa.Table.from_pylist( + struct_packets, schema=packet_schema + ) + + self._cached_output_table = arrow_utils.hstack_tables( + all_tags_as_tables, all_packets_as_tables + ) + assert self._cached_output_table is not None, ( + "_cached_output_table should not be None here." + ) + + column_config = ColumnConfig.handle_config(columns, all_info=all_info) + + drop_columns = [] + if not column_config.system_tags: + drop_columns.extend( + [ + c + for c in self._cached_output_table.column_names + if c.startswith(constants.SYSTEM_TAG_PREFIX) + ] + ) + if not column_config.source: + drop_columns.extend(f"{constants.SOURCE_PREFIX}{c}" for c in self.keys()[1]) + if not column_config.context: + drop_columns.append(constants.CONTEXT_KEY) + if not column_config.meta: + drop_columns.extend( + c + for c in self._cached_output_table.column_names + if c.startswith(constants.META_PREFIX) + ) + elif not isinstance(column_config.meta, bool): + # Collection[str]: keep only meta columns matching the specified prefixes + drop_columns.extend( + c + for c in self._cached_output_table.column_names + if c.startswith(constants.META_PREFIX) + and not any(c.startswith(p) for p in column_config.meta) + ) + output_table = self._cached_output_table.drop( + [c for c in drop_columns if c in self._cached_output_table.column_names] + ) + + if column_config.content_hash: + if self._cached_content_hash_column is None: + content_hashes = [] + for tag, packet in self.iter_packets(): + content_hashes.append(packet.content_hash().to_string()) + self._cached_content_hash_column = pa.array( + content_hashes, type=pa.large_string() + ) + assert self._cached_content_hash_column is not None, ( + "_cached_content_hash_column should not be None here." + ) + hash_column_name = ( + "_content_hash" + if column_config.content_hash is True + else column_config.content_hash + ) + output_table = output_table.append_column( + hash_column_name, self._cached_content_hash_column + ) + + if column_config.sort_by_tags: + output_table = ( + pl.DataFrame(output_table) + .sort(by=self.keys()[0], descending=False) + .to_arrow() + ) + return output_table + + # ------------------------------------------------------------------ + # Async channel execution + # ------------------------------------------------------------------ + + async def async_execute( + self, + input_channel: ReadableChannel[tuple[TagProtocol, PacketProtocol]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + *, + observer: "ExecutionObserverProtocol | None" = None, + ) -> None: + """Streaming async execution for FunctionNode. + + When a database is attached, uses two-phase execution: replay cached + results first, then compute missing packets concurrently. Otherwise, + routes each packet through ``async_process_packet`` directly. + + Args: + input_channel: Single readable channel of (tag, packet) pairs. + output: Writable channel for output (tag, packet) pairs. + observer: Optional execution observer for hooks. + """ + # TODO(PLT-930): Restore concurrency limiting (semaphore) via node-level config. + # Currently all packets are processed sequentially in async_execute. + try: + if observer is not None: + observer.on_node_start(self) + + if self._cached_function_pod is not None: + # DB-backed async execution: + # Phase 1: build cache lookup from pipeline DB + PIPELINE_ENTRY_ID_COL = "__pipeline_entry_id" + cached_by_entry_id: dict[ + str, tuple[TagProtocol, PacketProtocol] + ] = {} + + taginfo = self._pipeline_database.get_all_records( + self.pipeline_path, + record_id_column=PIPELINE_ENTRY_ID_COL, + ) + results = self._cached_function_pod._result_database.get_all_records( + self._cached_function_pod.record_path, + record_id_column=constants.PACKET_RECORD_ID, + ) + + if taginfo is not None and results is not None: + joined = ( + pl.DataFrame(taginfo) + .join( + pl.DataFrame(results), + on=constants.PACKET_RECORD_ID, + how="inner", + ) + .to_arrow() + ) + if joined.num_rows > 0: + tag_keys = self._input_stream.keys()[0] + entry_ids_col = joined.column( + PIPELINE_ENTRY_ID_COL + ).to_pylist() + drop_cols = [ + c + for c in joined.column_names + if c.startswith(constants.META_PREFIX) + or c == PIPELINE_ENTRY_ID_COL + ] + data_table = joined.drop( + [c for c in drop_cols if c in joined.column_names] + ) + existing_stream = ArrowTableStream( + data_table, tag_columns=tag_keys + ) + for eid, (tag_out, pkt_out) in zip( + entry_ids_col, existing_stream.iter_packets() + ): + cached_by_entry_id[eid] = (tag_out, pkt_out) + + # Phase 2: drive output from input channel — cached or compute + async for tag, packet in input_channel: + entry_id = self.compute_pipeline_entry_id(tag, packet) + if entry_id in cached_by_entry_id: + tag_out, result_packet = cached_by_entry_id[entry_id] + if observer is not None: + observer.on_packet_start(self, tag, packet) + observer.on_packet_end( + self, tag, packet, result_packet, cached=True + ) + await output.send((tag_out, result_packet)) + else: + await self._async_execute_one_packet( + tag, packet, output, observer=observer + ) + else: + # Simple async execution without DB + async for tag, packet in input_channel: + await self._async_execute_one_packet( + tag, packet, output, observer=observer + ) + + if observer is not None: + observer.on_node_end(self) + finally: + await output.close() + + async def _async_execute_one_packet( + self, + tag: TagProtocol, + packet: PacketProtocol, + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + *, + observer: "ExecutionObserverProtocol | None" = None, + ) -> None: + """Process one non-cached packet in the async execute path.""" + from orcapod.pipeline.observer import _NOOP_LOGGER + + pp = self.pipeline_path if self._pipeline_database is not None else () + + if observer is not None: + observer.on_packet_start(self, tag, packet) + pkt_logger = observer.create_packet_logger( + self, tag, packet, pipeline_path=pp + ) + else: + pkt_logger = _NOOP_LOGGER + + tag_out, result_packet, captured = await self._async_process_packet_internal( + tag, packet + ) + pkt_logger.record(captured) + if not captured.success: + if observer is not None: + observer.on_packet_crash( + self, + tag, + packet, + RuntimeError(captured.traceback or "packet function failed"), + ) + else: + if observer is not None: + observer.on_packet_end( + self, tag, packet, result_packet, cached=False + ) + if result_packet is not None: + await output.send((tag_out, result_packet)) + + def __repr__(self) -> str: + return ( + f"{type(self).__name__}(packet_function={self._packet_function!r}, " + f"input_stream={self._input_stream!r})" + ) diff --git a/src/orcapod/core/nodes/operator_node.py b/src/orcapod/core/nodes/operator_node.py new file mode 100644 index 00000000..e16da95a --- /dev/null +++ b/src/orcapod/core/nodes/operator_node.py @@ -0,0 +1,723 @@ +"""OperatorNode — stream node for operator invocations with optional DB persistence.""" + +from __future__ import annotations + +import asyncio +import logging +from collections.abc import Iterator, Sequence +from typing import TYPE_CHECKING, Any + +from orcapod import contexts +from orcapod.channels import Channel, ReadableChannel, WritableChannel +from orcapod.config import Config +from orcapod.core.operators.static_output_pod import StaticOutputOperatorPod +from orcapod.core.streams.arrow_table_stream import ArrowTableStream +from orcapod.core.streams.base import StreamBase +from orcapod.core.tracker import DEFAULT_TRACKER_MANAGER +from orcapod.protocols.core_protocols import ( + PacketProtocol, + StreamProtocol, + TagProtocol, + TrackerManagerProtocol, +) +from orcapod.protocols.core_protocols.operator_pod import OperatorPodProtocol +from orcapod.protocols.database_protocols import ArrowDatabaseProtocol +from orcapod.system_constants import constants +from orcapod.types import CacheMode, ColumnConfig, ContentHash, Schema +from orcapod.utils.lazy_module import LazyModule + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + import pyarrow as pa + + from orcapod.protocols.observability_protocols import ExecutionObserverProtocol +else: + pa = LazyModule("pyarrow") + + +class OperatorNode(StreamBase): + """Stream node representing an operator invocation with optional DB persistence. + + When constructed without database parameters, provides the core stream + interface (identity, schema, iteration) without any persistence. When + databases are provided (either at construction or via ``attach_databases``), + adds pipeline record storage with per-row deduplication, ``get_all_records()`` + for retrieving stored results, ``as_source()`` for creating a + ``DerivedSource`` from DB records, and three-tier cache mode + (OFF / LOG / REPLAY). + + Pipeline path structure:: + + pipeline_path_prefix / operator.uri / node:{content_hash} + + Where ``content_hash`` is the data-inclusive hash that encodes both + pipeline structure and upstream source identities, ensuring each + unique source combination gets its own cache table. + + Cache modes: + - **OFF** (default): compute, don't write to DB. + - **LOG**: compute AND write to DB (append-only historical record). + - **REPLAY**: skip computation, flow cached results downstream. + """ + + node_type = "operator" + HASH_COLUMN_NAME = "_record_hash" + + def __init__( + self, + operator: OperatorPodProtocol, + input_streams: tuple[StreamProtocol, ...] | list[StreamProtocol], + tracker_manager: TrackerManagerProtocol | None = None, + label: str | None = None, + config: Config | None = None, + # Optional DB params for persistent mode: + pipeline_database: ArrowDatabaseProtocol | None = None, + cache_mode: CacheMode = CacheMode.OFF, + pipeline_path_prefix: tuple[str, ...] = (), + ): + if tracker_manager is None: + tracker_manager = DEFAULT_TRACKER_MANAGER + self.tracker_manager = tracker_manager + + self._operator = operator + self._input_streams = tuple(input_streams) + + super().__init__(label=label, config=config) + + # Validate inputs eagerly + self._operator.validate_inputs(*self._input_streams) + + # Stream-level caching state + self._cached_output_stream: StreamProtocol | None = None + self._cached_output_table: pa.Table | None = None + self._set_modified_time(None) + + # DB persistence state (initially None; set via __init__ params or attach_databases) + self._pipeline_database: ArrowDatabaseProtocol | None = None + self._pipeline_path_prefix: tuple[str, ...] = () + self._cache_mode = CacheMode.OFF + self._pipeline_node_hash: str | None = None + + if pipeline_database is not None: + self.attach_databases( + pipeline_database=pipeline_database, + cache_mode=cache_mode, + pipeline_path_prefix=pipeline_path_prefix, + ) + + # ------------------------------------------------------------------ + # attach_databases + # ------------------------------------------------------------------ + + def attach_databases( + self, + pipeline_database: ArrowDatabaseProtocol, + cache_mode: CacheMode = CacheMode.OFF, + pipeline_path_prefix: tuple[str, ...] = (), + ) -> None: + """Attach a database for persistent caching and pipeline records. + + Args: + pipeline_database: Database for pipeline records. + cache_mode: Caching behaviour (OFF, LOG, or REPLAY). + pipeline_path_prefix: Path prefix for pipeline records. + """ + self._pipeline_database = pipeline_database + self._pipeline_path_prefix = pipeline_path_prefix + self._cache_mode = cache_mode + + # Clear caches + self.clear_cache() + self._content_hash_cache.clear() + self._pipeline_hash_cache.clear() + + # Use content_hash (data-inclusive) for pipeline node hash + self._pipeline_node_hash = self.content_hash().to_string() + + # ------------------------------------------------------------------ + # from_descriptor — reconstruct from a serialized pipeline descriptor + # ------------------------------------------------------------------ + + @classmethod + def from_descriptor( + cls, + descriptor: dict[str, Any], + operator: OperatorPodProtocol | None, + input_streams: tuple[StreamProtocol, ...] | list[StreamProtocol], + databases: dict[str, Any], + ) -> "OperatorNode": + """Construct an OperatorNode from a serialized descriptor. + + When *operator* and *input_streams* are provided the node operates + in full mode — constructed normally via ``__init__``. When + *operator* is ``None`` the node is created in read-only mode with + metadata from the descriptor; computation methods will raise + ``RuntimeError``. + + Args: + descriptor: The serialized node descriptor dict. + operator: An optional live operator instance. ``None`` for + read-only mode. + input_streams: Input streams for the operator. Empty tuple + for read-only mode. + databases: Mapping of database role names (``"pipeline"``) + to database instances. + + Returns: + A new ``OperatorNode`` instance. + """ + from orcapod.pipeline.serialization import LoadStatus + + pipeline_db = databases.get("pipeline") + cache_mode_str = descriptor.get("cache_mode", "OFF") + cache_mode = ( + CacheMode[cache_mode_str] + if isinstance(cache_mode_str, str) + else CacheMode.OFF + ) + + if operator is not None and input_streams: + # Full mode: construct normally + pipeline_path = tuple(descriptor.get("pipeline_path", ())) + # Derive pipeline_path_prefix by stripping the suffix that + # __init__ appends: operator.uri (2 elements) + node:{hash} (1 element). + uri_len = len(operator.uri) + 1 # +1 for node:{hash} + prefix = pipeline_path[:-uri_len] if len(pipeline_path) > uri_len else () + + node = cls( + operator=operator, + input_streams=input_streams, + pipeline_database=pipeline_db, + cache_mode=cache_mode, + pipeline_path_prefix=prefix, + label=descriptor.get("label"), + ) + node._descriptor = descriptor + node._load_status = LoadStatus.FULL + return node + + # Read-only mode: bypass __init__, set minimum required state + node = cls.__new__(cls) + + # From LabelableMixin + node._label = descriptor.get("label") + + # From DataContextMixin + from orcapod.config import DEFAULT_CONFIG + + node._data_context = contexts.resolve_context( + descriptor.get("data_context_key") + ) + node._orcapod_config = DEFAULT_CONFIG + + # From ContentIdentifiableBase + node._content_hash_cache = {} + node._cached_int_hash = None + + # From PipelineElementBase + node._pipeline_hash_cache = {} + + # From TemporalMixin + node._modified_time = None + + # From OperatorNode + node._operator = None + node._input_streams = () + node.tracker_manager = DEFAULT_TRACKER_MANAGER + node._cached_output_stream = None + node._cached_output_table = None + + # DB persistence state + node._pipeline_database = pipeline_db + node._pipeline_path_prefix = () + node._cache_mode = cache_mode + node._pipeline_node_hash = None + + # Descriptor metadata for read-only access + node._descriptor = descriptor + node._stored_schema = descriptor.get("output_schema", {}) + node._stored_content_hash = descriptor.get("content_hash") + node._stored_pipeline_hash = descriptor.get("pipeline_hash") + node._stored_pipeline_path = tuple(descriptor.get("pipeline_path", ())) + + # Determine load status based on DB availability + node._load_status = LoadStatus.UNAVAILABLE + if pipeline_db is not None: + node._load_status = LoadStatus.READ_ONLY + + return node + + # ------------------------------------------------------------------ + # load_status + # ------------------------------------------------------------------ + + @property + def load_status(self) -> Any: + """Return the load status of this node. + + Returns: + The ``LoadStatus`` enum value indicating how this node was + loaded. Defaults to ``FULL`` for nodes created via + ``__init__``. + """ + from orcapod.pipeline.serialization import LoadStatus + + return getattr(self, "_load_status", LoadStatus.FULL) + + # ------------------------------------------------------------------ + # Identity + # ------------------------------------------------------------------ + + def identity_structure(self) -> Any: + return (self._operator, self._operator.argument_symmetry(self._input_streams)) + + def pipeline_identity_structure(self) -> Any: + return (self._operator, self._operator.argument_symmetry(self._input_streams)) + + # ------------------------------------------------------------------ + # Read-only overrides (for deserialized nodes without live operator) + # ------------------------------------------------------------------ + + def content_hash(self, hasher=None) -> "ContentHash": + """Return the content hash, using stored value in read-only mode.""" + stored = getattr(self, "_stored_content_hash", None) + if self._operator is None and stored is not None: + return ContentHash.from_string(stored) + return super().content_hash(hasher) + + def pipeline_hash(self, hasher=None) -> "ContentHash": + """Return the pipeline hash, using stored value in read-only mode.""" + stored = getattr(self, "_stored_pipeline_hash", None) + if self._operator is None and stored is not None: + return ContentHash.from_string(stored) + return super().pipeline_hash(hasher) + + # ------------------------------------------------------------------ + # Stream interface + # ------------------------------------------------------------------ + + @property + def producer(self) -> OperatorPodProtocol: + return self._operator + + @property + def data_context(self) -> contexts.DataContext: + return contexts.resolve_context(self._operator.data_context_key) + + @property + def data_context_key(self) -> str: + return self._operator.data_context_key + + @property + def upstreams(self) -> tuple[StreamProtocol, ...]: + return self._input_streams + + @upstreams.setter + def upstreams(self, value: tuple[StreamProtocol, ...]) -> None: + self._input_streams = value + + def keys( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[tuple[str, ...], tuple[str, ...]]: + if self._operator is None: + stored = getattr(self, "_stored_schema", {}) + tag_keys = tuple(stored.get("tag", {}).keys()) + packet_keys = tuple(stored.get("packet", {}).keys()) + return tag_keys, packet_keys + tag_schema, packet_schema = self.output_schema( + columns=columns, all_info=all_info + ) + return tuple(tag_schema.keys()), tuple(packet_schema.keys()) + + def output_schema( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: + """Return output schema, using stored value in read-only mode.""" + if self._operator is None: + stored = getattr(self, "_stored_schema", {}) + tag = Schema(stored.get("tag", {})) + packet = Schema(stored.get("packet", {})) + return tag, packet + return self._operator.output_schema( + *self._input_streams, + columns=columns, + all_info=all_info, + ) + + # ------------------------------------------------------------------ + # Pipeline path + # ------------------------------------------------------------------ + + @property + def pipeline_path(self) -> tuple[str, ...]: + """Return the pipeline path for DB record scoping. + + Raises: + RuntimeError: If no database is attached and this is not a + read-only deserialized node. + """ + stored = getattr(self, "_stored_pipeline_path", None) + if self._operator is None and stored is not None: + return stored + if self._pipeline_database is None: + raise RuntimeError( + "pipeline_path requires a database. Call attach_databases() first." + ) + return ( + self._pipeline_path_prefix + + self._operator.uri + + (f"node:{self._pipeline_node_hash}",) + ) + + # ------------------------------------------------------------------ + # Computation and caching + # ------------------------------------------------------------------ + + def clear_cache(self) -> None: + """Discard all in-memory cached state.""" + self._cached_output_stream = None + self._cached_output_table = None + self._update_modified_time() + + def _store_output_stream(self, stream: StreamProtocol) -> None: + """Materialize stream and store in the pipeline database with per-row dedup.""" + output_table = stream.as_table( + columns={"source": True, "system_tags": True}, + ) + + # Per-row record hashes for dedup: hash(tag + packet + system_tag) + arrow_hasher = self.data_context.arrow_hasher + record_hashes = [] + for batch in output_table.to_batches(): + for i in range(len(batch)): + record_hashes.append( + arrow_hasher.hash_table(batch.slice(i, 1)).to_hex() + ) + + output_table = output_table.add_column( + 0, + self.HASH_COLUMN_NAME, + pa.array(record_hashes, type=pa.large_string()), + ) + + # Store (identical rows across runs naturally deduplicate) + self._pipeline_database.add_records( + self.pipeline_path, + output_table, + record_id_column=self.HASH_COLUMN_NAME, + skip_duplicates=True, + ) + + self._cached_output_table = output_table.drop(self.HASH_COLUMN_NAME) + + def get_cached_output(self) -> "StreamProtocol | None": + """Return cached output stream in REPLAY mode, else None. + + Returns: + The cached stream if REPLAY mode and DB records exist, + otherwise None. + """ + if self._pipeline_database is None: + return None + if self._cache_mode != CacheMode.REPLAY: + return None + self._replay_from_cache() + return self._cached_output_stream + + def execute( + self, + *input_streams: StreamProtocol, + observer: "ExecutionObserver | None" = None, + ) -> list[tuple[TagProtocol, PacketProtocol]]: + """Execute input streams: compute, persist, and cache. + + Args: + *input_streams: Input streams to execute. + observer: Optional execution observer for hooks. + + Returns: + Materialized list of (tag, packet) pairs. + """ + if observer is not None: + observer.on_node_start(self) + + # Check REPLAY cache first + cached_output = self.get_cached_output() + if cached_output is not None: + output = list(cached_output.iter_packets()) + if observer is not None: + observer.on_node_end(self) + return output + + # Compute + result_stream = self._operator.process(*input_streams) + + # Materialize + output = list(result_stream.iter_packets()) + + # Cache + if output: + self._cached_output_stream = StaticOutputOperatorPod._materialize_to_stream( + output + ) + else: + self._cached_output_stream = result_stream + + self._update_modified_time() + + # Persist to DB only in LOG mode + if ( + self._pipeline_database is not None + and self._cache_mode == CacheMode.LOG + and self._cached_output_stream is not None + ): + self._store_output_stream(self._cached_output_stream) + + if observer is not None: + observer.on_node_end(self) + return output + + def _compute_and_store(self) -> None: + """Compute operator output, optionally store in DB.""" + self._cached_output_stream = self._operator.process( + *self._input_streams, + ) + + if self._cache_mode == CacheMode.OFF: + self._update_modified_time() + return + + self._store_output_stream(self._cached_output_stream) + self._update_modified_time() + + def _replay_from_cache(self) -> None: + """Load cached results from DB, skip computation. + + If no cached records exist yet, produces an empty stream with + the correct schema (zero rows, correct columns). + """ + records = self._pipeline_database.get_all_records(self.pipeline_path) + if records is None: + # Build an empty table with the correct schema + tag_schema, packet_schema = self.output_schema() + type_converter = self.data_context.type_converter + empty_fields = {} + for name, py_type in {**tag_schema, **packet_schema}.items(): + arrow_type = type_converter.python_type_to_arrow_type(py_type) + empty_fields[name] = pa.array([], type=arrow_type) + records = pa.table(empty_fields) + + tag_keys = self.keys()[0] + self._cached_output_stream = ArrowTableStream(records, tag_columns=tag_keys) + self._update_modified_time() + + def run(self) -> None: + """Execute the operator according to the current cache mode. + + Without a database: + Always compute via the operator's ``process()`` method. + + With a database: + - **OFF**: always compute, no DB writes. + - **LOG**: always compute, write results to DB. + - **REPLAY**: skip computation, load from DB. + """ + if self.is_stale: + self.clear_cache() + + if self._cached_output_stream is not None: + return + + if self._pipeline_database is not None: + if self._cache_mode == CacheMode.REPLAY: + self._replay_from_cache() + else: + self._compute_and_store() + else: + self._cached_output_stream = self._operator.process( + *self._input_streams, + ) + self._update_modified_time() + + def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: + self.run() + assert self._cached_output_stream is not None + return self._cached_output_stream.iter_packets() + + def as_table( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> "pa.Table": + self.run() + assert self._cached_output_stream is not None + return self._cached_output_stream.as_table(columns=columns, all_info=all_info) + + # ------------------------------------------------------------------ + # DB retrieval + # ------------------------------------------------------------------ + + def get_all_records( + self, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> "pa.Table | None": + """Retrieve all stored records from the pipeline database. + + Returns the stored output table with column filtering applied + per ``ColumnConfig`` conventions. Returns ``None`` when no + database is attached or no records exist. + """ + if self._pipeline_database is None: + return None + + results = self._pipeline_database.get_all_records(self.pipeline_path) + if results is None: + return None + + column_config = ColumnConfig.handle_config(columns, all_info=all_info) + + drop_columns = [] + if not column_config.meta and not column_config.all_info: + drop_columns.extend( + c for c in results.column_names if c.startswith(constants.META_PREFIX) + ) + if not column_config.source and not column_config.all_info: + drop_columns.extend( + c for c in results.column_names if c.startswith(constants.SOURCE_PREFIX) + ) + if not column_config.system_tags and not column_config.all_info: + drop_columns.extend( + c + for c in results.column_names + if c.startswith(constants.SYSTEM_TAG_PREFIX) + ) + if drop_columns: + results = results.drop( + [c for c in drop_columns if c in results.column_names] + ) + + return results if results.num_rows > 0 else None + + # ------------------------------------------------------------------ + # DerivedSource + # ------------------------------------------------------------------ + + def as_source(self): + """Return a DerivedSource backed by the DB records of this node. + + Raises: + RuntimeError: If no database is attached. + """ + if self._pipeline_database is None: + raise RuntimeError("Cannot create a DerivedSource without a database") + + from orcapod.core.sources.derived_source import DerivedSource + + path_str = "/".join(self.pipeline_path) + content_frag = self.content_hash().to_string()[:16] + source_id = f"{path_str}:{content_frag}" + return DerivedSource( + origin=self, + source_id=source_id, + data_context=self.data_context_key, + config=self.orcapod_config, + ) + + # ------------------------------------------------------------------ + # Async channel execution + # ------------------------------------------------------------------ + + async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + *, + observer: "ExecutionObserver | None" = None, + ) -> None: + """Async execution with cache mode handling when DB is attached. + + Without a database, delegates to the wrapped operator's + ``async_execute``. + + With a database: + - REPLAY: emit from DB, close output. + - OFF: delegate to operator, forward results. + - LOG: delegate to operator, forward + collect results, then store in DB. + + Args: + inputs: Sequence of readable channels from upstream nodes. + output: Writable channel for output (tag, packet) pairs. + observer: Optional execution observer for hooks. + """ + if self._pipeline_database is None: + # Simple delegation without DB + if observer is not None: + observer.on_node_start(self) + hashes = [s.pipeline_hash() for s in self._input_streams] + await self._operator.async_execute( + inputs, output, input_pipeline_hashes=hashes + ) + if observer is not None: + observer.on_node_end(self) + return + + try: + if observer is not None: + observer.on_node_start(self) + + if self._cache_mode == CacheMode.REPLAY: + self._replay_from_cache() + assert self._cached_output_stream is not None + for tag, packet in self._cached_output_stream.iter_packets(): + await output.send((tag, packet)) + return # finally block closes output + + # OFF or LOG: delegate to operator, forward results downstream + intermediate: Channel[tuple[TagProtocol, PacketProtocol]] = Channel() + should_collect = self._cache_mode == CacheMode.LOG + collected: list[tuple[TagProtocol, PacketProtocol]] = [] + + async def forward() -> None: + async for item in intermediate.reader: + if should_collect: + collected.append(item) + await output.send(item) + + hashes = [s.pipeline_hash() for s in self._input_streams] + async with asyncio.TaskGroup() as tg: + tg.create_task( + self._operator.async_execute( + inputs, + intermediate.writer, + input_pipeline_hashes=hashes, + ) + ) + tg.create_task(forward()) + + # TaskGroup has completed — store if LOG mode (sync DB write, post-hoc) + if should_collect and collected: + stream = StaticOutputOperatorPod._materialize_to_stream(collected) + self._cached_output_stream = stream + self._store_output_stream(stream) + + self._update_modified_time() + + if observer is not None: + observer.on_node_end(self) + finally: + await output.close() + + def __repr__(self) -> str: + return ( + f"{type(self).__name__}(operator={self._operator!r}, " + f"upstreams={self._input_streams!r})" + ) diff --git a/src/orcapod/core/nodes/source_node.py b/src/orcapod/core/nodes/source_node.py new file mode 100644 index 00000000..504f29f7 --- /dev/null +++ b/src/orcapod/core/nodes/source_node.py @@ -0,0 +1,291 @@ +"""SourceNode — wraps a root source stream in the computation graph.""" + +from __future__ import annotations + +from collections.abc import Iterator +from typing import TYPE_CHECKING, Any + +from orcapod import contexts +from orcapod.channels import WritableChannel +from orcapod.config import Config, DEFAULT_CONFIG +from orcapod.core.streams.base import StreamBase +from orcapod.protocols import core_protocols as cp +from orcapod.types import ColumnConfig, ContentHash, Schema + +if TYPE_CHECKING: + import pyarrow as pa + + from orcapod.protocols.observability_protocols import ExecutionObserverProtocol + + +class SourceNode(StreamBase): + """Represents a root source stream in the computation graph.""" + + node_type = "source" + + def __init__( + self, + stream: cp.StreamProtocol, + label: str | None = None, + config: Config | None = None, + ): + super().__init__(label=label, config=config) + self.stream = stream + self._cached_results: list[tuple[cp.TagProtocol, cp.PacketProtocol]] | None = ( + None + ) + + # ------------------------------------------------------------------ + # from_descriptor — reconstruct from a serialized pipeline descriptor + # ------------------------------------------------------------------ + + @classmethod + def from_descriptor( + cls, + descriptor: dict[str, Any], + stream: cp.StreamProtocol | None, + databases: dict[str, Any], + ) -> SourceNode: + """Construct a SourceNode from a serialized descriptor. + + When *stream* is provided the node operates in full mode — all + delegation goes through the live stream. When *stream* is ``None`` + the node is created in read-only mode with metadata from the + descriptor; data-access methods (``iter_packets``, ``as_table``) + will raise ``RuntimeError``. + + Args: + descriptor: The serialized node descriptor dict. + stream: An optional live stream to wrap. ``None`` for + read-only mode. + databases: Mapping of database role names to database + instances (currently unused for source nodes but kept + for interface consistency with other node types). + + Returns: + A new ``SourceNode`` instance. + """ + from orcapod.pipeline.serialization import LoadStatus + + if stream is not None: + node = cls(stream=stream, label=descriptor.get("label")) + node._descriptor = descriptor + node._load_status = LoadStatus.FULL + return node + + # Read-only mode: bypass __init__, set minimum required state + node = cls.__new__(cls) + + # From LabelableMixin + node._label = descriptor.get("label") + + # From DataContextMixin + node._data_context = contexts.resolve_context( + descriptor.get("data_context_key") + ) + node._orcapod_config = DEFAULT_CONFIG + + # From ContentIdentifiableBase + node._content_hash_cache = {} + node._cached_int_hash = None + + # From PipelineElementBase + node._pipeline_hash_cache = {} + + # From TemporalMixin + node._modified_time = None + + # SourceNode's own state + node.stream = None + node._cached_results = None + node._descriptor = descriptor + node._load_status = LoadStatus.UNAVAILABLE + node._stored_schema = descriptor.get("output_schema", {}) + node._stored_content_hash = descriptor.get("content_hash") + node._stored_pipeline_hash = descriptor.get("pipeline_hash") + + return node + + # ------------------------------------------------------------------ + # load_status + # ------------------------------------------------------------------ + + @property + def load_status(self) -> Any: + """Return the load status of this node. + + Returns: + The ``LoadStatus`` enum value indicating how this node was + loaded. Defaults to ``FULL`` for nodes created via + ``__init__``. + """ + from orcapod.pipeline.serialization import LoadStatus + + return getattr(self, "_load_status", LoadStatus.FULL) + + # ------------------------------------------------------------------ + # Delegation — with read-only guards + # ------------------------------------------------------------------ + + @property + def data_context(self) -> contexts.DataContext: + if self.stream is None: + return self._data_context + return contexts.resolve_context(self.stream.data_context_key) + + @property + def data_context_key(self) -> str: + if self.stream is None: + return self._data_context.context_key + return self.stream.data_context_key + + def computed_label(self) -> str | None: + if self.stream is None: + return None + return self.stream.label + + def identity_structure(self) -> Any: + if self.stream is None: + raise RuntimeError( + "SourceNode in read-only mode has no stream data available" + ) + # TODO: revisit this logic for case where stream is not a root source + return self.stream.identity_structure() + + def pipeline_identity_structure(self) -> Any: + if self.stream is None: + raise RuntimeError( + "SourceNode in read-only mode has no stream data available" + ) + return self.stream.pipeline_identity_structure() + + def content_hash(self, hasher=None) -> ContentHash: + """Return the content hash, using stored value in read-only mode.""" + stored = getattr(self, "_stored_content_hash", None) + if self.stream is None and stored is not None: + return ContentHash.from_string(stored) + return super().content_hash(hasher) + + def pipeline_hash(self, hasher=None) -> ContentHash: + """Return the pipeline hash, using stored value in read-only mode.""" + stored = getattr(self, "_stored_pipeline_hash", None) + if self.stream is None and stored is not None: + return ContentHash.from_string(stored) + return super().pipeline_hash(hasher) + + def keys( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[tuple[str, ...], tuple[str, ...]]: + if self.stream is None: + stored = getattr(self, "_stored_schema", {}) + tag_keys = tuple(stored.get("tag", {}).keys()) + packet_keys = tuple(stored.get("packet", {}).keys()) + return tag_keys, packet_keys + return self.stream.keys(columns=columns, all_info=all_info) + + def output_schema( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: + if self.stream is None: + stored = getattr(self, "_stored_schema", {}) + tag = Schema(stored.get("tag", {})) + packet = Schema(stored.get("packet", {})) + return tag, packet + return self.stream.output_schema(columns=columns, all_info=all_info) + + @property + def producer(self) -> None: + return None + + @property + def upstreams(self) -> tuple[cp.StreamProtocol, ...]: + return () + + @upstreams.setter + def upstreams(self, value: tuple[cp.StreamProtocol, ...]) -> None: + if len(value) != 0: + raise ValueError("SourceNode upstreams must be empty") + + def __repr__(self) -> str: + return f"SourceNode(stream={self.stream!r}, label={self.label!r})" + + def as_table( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> "pa.Table": + if self.stream is None: + raise RuntimeError( + "SourceNode in read-only mode has no stream data available" + ) + return self.stream.as_table(columns=columns, all_info=all_info) + + def iter_packets(self) -> Iterator[tuple[cp.TagProtocol, cp.PacketProtocol]]: + if self.stream is None: + raise RuntimeError( + "SourceNode in read-only mode has no stream data available" + ) + if self._cached_results is not None: + return iter(self._cached_results) + return self.stream.iter_packets() + + def execute( + self, + *, + observer: "ExecutionObserver | None" = None, + ) -> list[tuple[cp.TagProtocol, cp.PacketProtocol]]: + """Execute this source: materialize packets and return. + + Args: + observer: Optional execution observer for hooks. + + Returns: + List of (tag, packet) tuples. + """ + if self.stream is None: + raise RuntimeError( + "SourceNode in read-only mode has no stream data available" + ) + if observer is not None: + observer.on_node_start(self) + result = list(self.stream.iter_packets()) + self._cached_results = result + if observer is not None: + observer.on_node_end(self) + return result + + def run(self) -> None: + """No-op for source nodes — data is already available.""" + + async def async_execute( + self, + output: WritableChannel[tuple[cp.TagProtocol, cp.PacketProtocol]], + *, + observer: "ExecutionObserver | None" = None, + ) -> None: + """Push all (tag, packet) pairs from the wrapped stream to the output channel. + + Args: + output: Channel to write results to. + observer: Optional execution observer for hooks. + """ + if self.stream is None: + raise RuntimeError( + "SourceNode in read-only mode has no stream data available" + ) + try: + if observer is not None: + observer.on_node_start(self) + for tag, packet in self.stream.iter_packets(): + await output.send((tag, packet)) + if observer is not None: + observer.on_node_end(self) + finally: + await output.close() diff --git a/src/orcapod/core/operators/__init__.py b/src/orcapod/core/operators/__init__.py index b1f05443..15d6a641 100644 --- a/src/orcapod/core/operators/__init__.py +++ b/src/orcapod/core/operators/__init__.py @@ -1,17 +1,19 @@ -from .join import Join -from .semijoin import SemiJoin -from .mappers import MapTags, MapPackets from .batch import Batch from .column_selection import ( - SelectTagColumns, - SelectPacketColumns, - DropTagColumns, DropPacketColumns, + DropTagColumns, + SelectPacketColumns, + SelectTagColumns, ) from .filters import PolarsFilter +from .join import Join +from .mappers import MapPackets, MapTags +from .merge_join import MergeJoin +from .semijoin import SemiJoin __all__ = [ "Join", + "MergeJoin", "SemiJoin", "MapTags", "MapPackets", diff --git a/src/orcapod/core/operators/base.py b/src/orcapod/core/operators/base.py index b87748c2..09a59462 100644 --- a/src/orcapod/core/operators/base.py +++ b/src/orcapod/core/operators/base.py @@ -1,290 +1,200 @@ -from orcapod.core.kernels import TrackedKernelBase -from orcapod.protocols import core_protocols as cp -from orcapod.types import PythonSchema +from __future__ import annotations + +import asyncio from abc import abstractmethod +from collections.abc import Collection, Sequence from typing import Any -from collections.abc import Collection - - -class Operator(TrackedKernelBase): - """ - Base class for all operators. - Operators are a special type of kernel that can be used to perform operations on streams. - - They are defined as a callable that takes a (possibly empty) collection of streams as the input - and returns a new stream as output (note that output stream is always singular). - """ - - -class UnaryOperator(Operator): - """ - Base class for all operators. - """ - - def check_unary_input( - self, - streams: Collection[cp.Stream], - ) -> None: - """ - Check that the inputs to the unary operator are valid. - """ - if len(streams) != 1: - raise ValueError("UnaryOperator requires exactly one input stream.") - - def validate_inputs(self, *streams: cp.Stream) -> None: - self.check_unary_input(streams) - stream = streams[0] - return self.op_validate_inputs(stream) - - def forward(self, *streams: cp.Stream) -> cp.Stream: - """ - Forward method for unary operators. - It expects exactly one stream as input. - """ - stream = streams[0] - return self.op_forward(stream) - # TODO: complete substream implementation - # Substream implementation pending - # stream = streams[0] - # # visit each substream - # output_substreams = [] - # for substream_id in stream.substream_identities: - # substream = stream.get_substream(substream_id) - # output_substreams.append(self.op_forward(substream)) +from orcapod.channels import ReadableChannel, WritableChannel +from orcapod.core.operators.static_output_pod import StaticOutputOperatorPod +from orcapod.protocols.core_protocols import ( + ArgumentGroup, + PacketProtocol, + StreamProtocol, + TagProtocol, +) +from orcapod.types import ColumnConfig, ContentHash, Schema - # # at the moment only single output substream is supported - # if len(output_substreams) != 1: - # raise NotImplementedError( - # "Support for multiple output substreams is not implemented yet." - # ) - # return output_substreams[0] - def kernel_output_types( - self, *streams: cp.Stream, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - stream = streams[0] - return self.op_output_types(stream, include_system_tags=include_system_tags) - - def kernel_identity_structure( - self, streams: Collection[cp.Stream] | None = None - ) -> Any: - """ - Return a structure that represents the identity of this operator. - This is used to ensure that the operator can be uniquely identified in the computational graph. - """ - if streams is not None: - stream = list(streams)[0] - return self.op_identity_structure(stream) - return self.op_identity_structure() +class UnaryOperator(StaticOutputOperatorPod): + """Base class for all unary operators.""" @abstractmethod - def op_validate_inputs(self, stream: cp.Stream) -> None: - """ - This method should be implemented by subclasses to validate the inputs to the operator. - It takes two streams as input and raises an error if the inputs are not valid. - """ - ... + def validate_unary_input(self, stream: StreamProtocol) -> None: + """Validate the single input stream. - @abstractmethod - def op_forward(self, stream: cp.Stream) -> cp.Stream: - """ - This method should be implemented by subclasses to define the specific behavior of the binary operator. - It takes two streams as input and returns a new stream as output. + Raises: + ValueError: If the input stream is not valid for this operator. """ ... @abstractmethod - def op_output_types( - self, stream: cp.Stream, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - """ - This method should be implemented by subclasses to return the typespecs of the input and output streams. - It takes two streams as input and returns a tuple of typespecs. - """ + def unary_static_process(self, stream: StreamProtocol) -> StreamProtocol: + """Process a single input stream and return a new output stream.""" ... @abstractmethod - def op_identity_structure(self, stream: cp.Stream | None = None) -> Any: - """ - This method should be implemented by subclasses to return a structure that represents the identity of the operator. - It takes two streams as input and returns a tuple containing the operator name and a set of streams. - """ + def unary_output_schema( + self, + stream: StreamProtocol, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: + """Return the (tag, packet) output schemas for the given input stream.""" ... + def validate_inputs(self, *streams: StreamProtocol) -> None: + if len(streams) != 1: + raise ValueError("UnaryOperator requires exactly one input stream.") + stream = streams[0] + return self.validate_unary_input(stream) -class BinaryOperator(Operator): - """ - Base class for all operators. - """ + def static_process(self, *streams: StreamProtocol) -> StreamProtocol: + """Forward to ``unary_static_process`` with the single input stream.""" + stream = streams[0] + return self.unary_static_process(stream) - def check_binary_inputs( + def output_schema( self, - streams: Collection[cp.Stream], - ) -> None: - """ - Check that the inputs to the binary operator are valid. - This method is called before the forward method to ensure that the inputs are valid. - """ - if len(streams) != 2: - raise ValueError("BinaryOperator requires exactly two input streams.") + *streams: StreamProtocol, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: + stream = streams[0] + return self.unary_output_schema(stream, columns=columns, all_info=all_info) - def validate_inputs(self, *streams: cp.Stream) -> None: - self.check_binary_inputs(streams) - left_stream, right_stream = streams - return self.op_validate_inputs(left_stream, right_stream) + def argument_symmetry(self, streams: Collection[StreamProtocol]) -> ArgumentGroup: + # return single stream as a tuple + return (tuple(streams)[0],) - def forward(self, *streams: cp.Stream) -> cp.Stream: - """ - Forward method for binary operators. - It expects exactly two streams as input. - """ - left_stream, right_stream = streams - return self.op_forward(left_stream, right_stream) + async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + *, + input_pipeline_hashes: Sequence[ContentHash] | None = None, + ) -> None: + """Barrier-mode: collect single input, run unary_static_process, emit.""" + try: + rows = await inputs[0].collect() + stream = self._materialize_to_stream(rows) + result = self.static_process(stream) + for tag, packet in result.iter_packets(): + await output.send((tag, packet)) + finally: + await output.close() - def kernel_output_types( - self, *streams: cp.Stream, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - left_stream, right_stream = streams - return self.op_output_types( - left_stream, right_stream, include_system_tags=include_system_tags - ) - def kernel_identity_structure( - self, streams: Collection[cp.Stream] | None = None - ) -> Any: - """ - Return a structure that represents the identity of this operator. - This is used to ensure that the operator can be uniquely identified in the computational graph. - """ - if streams is not None: - left_stream, right_stream = streams - self.op_identity_structure(left_stream, right_stream) - return self.op_identity_structure() +class BinaryOperator(StaticOutputOperatorPod): + """Base class for all binary operators.""" @abstractmethod - def op_validate_inputs( - self, left_stream: cp.Stream, right_stream: cp.Stream + def validate_binary_inputs( + self, left_stream: StreamProtocol, right_stream: StreamProtocol ) -> None: - """ - This method should be implemented by subclasses to validate the inputs to the operator. - It takes two streams as input and raises an error if the inputs are not valid. + """Validate the two input streams. + + Raises: + ValueError: If the inputs are not valid for this operator. """ ... @abstractmethod - def op_forward(self, left_stream: cp.Stream, right_stream: cp.Stream) -> cp.Stream: - """ - This method should be implemented by subclasses to define the specific behavior of the binary operator. - It takes two streams as input and returns a new stream as output. - """ + def binary_static_process( + self, left_stream: StreamProtocol, right_stream: StreamProtocol + ) -> StreamProtocol: + """Process two input streams and return a new output stream.""" ... @abstractmethod - def op_output_types( + def binary_output_schema( self, - left_stream: cp.Stream, - right_stream: cp.Stream, - include_system_tags: bool = False, - ) -> tuple[PythonSchema, PythonSchema]: - """ - This method should be implemented by subclasses to return the typespecs of the input and output streams. - It takes two streams as input and returns a tuple of typespecs. - """ - ... + left_stream: StreamProtocol, + right_stream: StreamProtocol, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: ... @abstractmethod - def op_identity_structure( - self, - left_stream: cp.Stream | None = None, - right_stream: cp.Stream | None = None, - ) -> Any: - """ - This method should be implemented by subclasses to return a structure that represents the identity of the operator. - It takes two streams as input and returns a tuple containing the operator name and a set of streams. - """ + def is_commutative(self) -> bool: + """Return True if the operator is commutative (order of inputs does not matter).""" ... + def static_process(self, *streams: StreamProtocol) -> StreamProtocol: + """Forward to ``binary_static_process`` with two input streams.""" + left_stream, right_stream = streams + return self.binary_static_process(left_stream, right_stream) + + def output_schema( + self, + *streams: StreamProtocol, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: + left_stream, right_stream = streams + return self.binary_output_schema( + left_stream, right_stream, columns=columns, all_info=all_info + ) -class NonZeroInputOperator(Operator): - """ - Operators that work with at least one input stream. - This is useful for operators that can take a variable number of (but at least one ) input streams, - such as joins, unions, etc. - """ + def validate_inputs(self, *streams: StreamProtocol) -> None: + if len(streams) != 2: + raise ValueError("BinaryOperator requires exactly two input streams.") + left_stream, right_stream = streams + self.validate_binary_inputs(left_stream, right_stream) - def verify_non_zero_input( + def argument_symmetry(self, streams: Collection[StreamProtocol]) -> ArgumentGroup: + if self.is_commutative(): + # return as symmetric group + return frozenset(streams) + else: + # return as ordered group + return tuple(streams) + + async def async_execute( self, - streams: Collection[cp.Stream], + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + *, + input_pipeline_hashes: Sequence[ContentHash] | None = None, ) -> None: - """ - Check that the inputs to the variable inputs operator are valid. - This method is called before the forward method to ensure that the inputs are valid. - """ - if len(streams) == 0: - raise ValueError( - f"Operator {self.__class__.__name__} requires at least one input stream." + """Barrier-mode: collect both inputs concurrently, run binary_static_process, emit.""" + try: + left_rows, right_rows = await asyncio.gather( + inputs[0].collect(), inputs[1].collect() ) + left_stream = self._materialize_to_stream(left_rows) + right_stream = self._materialize_to_stream(right_rows) + result = self.static_process(left_stream, right_stream) + for tag, packet in result.iter_packets(): + await output.send((tag, packet)) + finally: + await output.close() - def validate_inputs(self, *streams: cp.Stream) -> None: - self.verify_non_zero_input(streams) - return self.op_validate_inputs(*streams) - - def forward(self, *streams: cp.Stream) -> cp.Stream: - """ - Forward method for variable inputs operators. - It expects at least one stream as input. - """ - return self.op_forward(*streams) - def kernel_output_types( - self, *streams: cp.Stream, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - return self.op_output_types(*streams, include_system_tags=include_system_tags) - - def kernel_identity_structure( - self, streams: Collection[cp.Stream] | None = None - ) -> Any: - """ - Return a structure that represents the identity of this operator. - This is used to ensure that the operator can be uniquely identified in the computational graph. - """ - return self.op_identity_structure(streams) +class NonZeroInputOperator(StaticOutputOperatorPod): + """Base class for operators that require at least one input stream. - @abstractmethod - def op_validate_inputs(self, *streams: cp.Stream) -> None: - """ - This method should be implemented by subclasses to validate the inputs to the operator. - It takes two streams as input and raises an error if the inputs are not valid. - """ - ... + Useful for operators that accept a variable number of input streams, + such as joins and unions. + """ @abstractmethod - def op_forward(self, *streams: cp.Stream) -> cp.Stream: - """ - This method should be implemented by subclasses to define the specific behavior of the non-zero input operator. - It takes variable number of streams as input and returns a new stream as output. - """ - ... + def validate_nonzero_inputs( + self, + *streams: StreamProtocol, + ) -> None: + """Validate the input streams. - @abstractmethod - def op_output_types( - self, *streams: cp.Stream, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - """ - This method should be implemented by subclasses to return the typespecs of the input and output streams. - It takes at least one stream as input and returns a tuple of typespecs. + Raises: + ValueError: If the inputs are not valid for this operator. """ ... - @abstractmethod - def op_identity_structure( - self, streams: Collection[cp.Stream] | None = None - ) -> Any: - """ - This method should be implemented by subclasses to return a structure that represents the identity of the operator. - It takes zero or more streams as input and returns a tuple containing the operator name and a set of streams. - If zero, it should return identity of the operator itself. - If one or more, it should return a identity structure approrpiate for the operator invoked on the given streams. - """ - ... + def validate_inputs(self, *streams: StreamProtocol) -> None: + if len(streams) == 0: + raise ValueError( + f"Operator {self.__class__.__name__} requires at least one input stream." + ) + self.validate_nonzero_inputs(*streams) diff --git a/src/orcapod/core/operators/batch.py b/src/orcapod/core/operators/batch.py index be48b3c8..8018d2bb 100644 --- a/src/orcapod/core/operators/batch.py +++ b/src/orcapod/core/operators/batch.py @@ -1,18 +1,21 @@ +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +from orcapod.channels import ReadableChannel, WritableChannel from orcapod.core.operators.base import UnaryOperator -from collections.abc import Collection -from orcapod.protocols import core_protocols as cp -from typing import Any, TYPE_CHECKING +from orcapod.core.streams import ArrowTableStream +from orcapod.protocols.core_protocols import PacketProtocol, StreamProtocol, TagProtocol +from orcapod.types import ColumnConfig from orcapod.utils.lazy_module import LazyModule -from orcapod.core.streams import TableStream if TYPE_CHECKING: - import pyarrow as pa import polars as pl + import pyarrow as pa else: pa = LazyModule("pyarrow") pl = LazyModule("polars") -from orcapod.types import PythonSchema +from orcapod.types import Schema class Batch(UnaryOperator): @@ -29,34 +32,18 @@ def __init__(self, batch_size: int = 0, drop_partial_batch: bool = False, **kwar self.batch_size = batch_size self.drop_partial_batch = drop_partial_batch - def check_unary_input( - self, - streams: Collection[cp.Stream], - ) -> None: - """ - Check that the inputs to the unary operator are valid. - """ - if len(streams) != 1: - raise ValueError("UnaryOperator requires exactly one input stream.") - - def validate_inputs(self, *streams: cp.Stream) -> None: - self.check_unary_input(streams) - stream = streams[0] - return self.op_validate_inputs(stream) - - def op_validate_inputs(self, stream: cp.Stream) -> None: + def validate_unary_input(self, stream: StreamProtocol) -> None: """ - This method should be implemented by subclasses to validate the inputs to the operator. - It takes two streams as input and raises an error if the inputs are not valid. + Batch works on any input stream, so no validation is needed. """ return None - def op_forward(self, stream: cp.Stream) -> cp.Stream: + def unary_static_process(self, stream: StreamProtocol) -> StreamProtocol: """ This method should be implemented by subclasses to define the specific behavior of the binary operator. It takes two streams as input and returns a new stream as output. """ - table = stream.as_table(include_source=True, include_system_tags=True) + table = stream.as_table(columns={"source": True, "system_tags": True}) tag_columns, packet_columns = stream.keys() @@ -81,26 +68,87 @@ def op_forward(self, stream: cp.Stream) -> cp.Stream: batched_data.append(next_batch) batched_table = pa.Table.from_pylist(batched_data) - return TableStream(batched_table, tag_columns=tag_columns) + return ArrowTableStream( + batched_table, + tag_columns=tag_columns, + ) - def op_output_types( - self, stream: cp.Stream, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: + def unary_output_schema( + self, + stream: StreamProtocol, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: """ - This method should be implemented by subclasses to return the typespecs of the input and output streams. - It takes two streams as input and returns a tuple of typespecs. + This method should be implemented by subclasses to return the schemas of the input and output streams. + It takes two streams as input and returns a tuple of schemas. """ - tag_types, packet_types = stream.types(include_system_tags=include_system_tags) + tag_types, packet_types = stream.output_schema( + columns=columns, all_info=all_info + ) batched_tag_types = {k: list[v] for k, v in tag_types.items()} batched_packet_types = {k: list[v] for k, v in packet_types.items()} # TODO: check if this is really necessary - return PythonSchema(batched_tag_types), PythonSchema(batched_packet_types) - - def op_identity_structure(self, stream: cp.Stream | None = None) -> Any: - return ( - (self.__class__.__name__, self.batch_size, self.drop_partial_batch) - + (stream,) - if stream is not None - else () - ) + return Schema(batched_tag_types), Schema(batched_packet_types) + + async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + **kwargs: Any, + ) -> None: + """Streaming batch: emit full batches as they accumulate. + + When ``batch_size > 0``, each group of ``batch_size`` rows is + materialized and emitted immediately, allowing downstream consumers + to start processing before all input is consumed. When + ``batch_size == 0`` (batch everything), falls back to barrier mode. + """ + try: + if self.batch_size == 0: + # Must collect all rows — barrier fallback + rows = await inputs[0].collect() + if rows: + stream = self._materialize_to_stream(rows) + result = self.unary_static_process(stream) + for tag, packet in result.iter_packets(): + await output.send((tag, packet)) + return + + batch: list[tuple[TagProtocol, PacketProtocol]] = [] + async for tag, packet in inputs[0]: + batch.append((tag, packet)) + if len(batch) >= self.batch_size: + stream = self._materialize_to_stream(batch) + result = self.unary_static_process(stream) + for out_tag, out_packet in result.iter_packets(): + await output.send((out_tag, out_packet)) + batch = [] + + # Flush partial batch + if batch and not self.drop_partial_batch: + stream = self._materialize_to_stream(batch) + result = self.unary_static_process(stream) + for out_tag, out_packet in result.iter_packets(): + await output.send((out_tag, out_packet)) + finally: + await output.close() + + def to_config(self) -> dict[str, Any]: + """Serialize this Batch operator to a config dict. + + Returns: + A dict with ``class_name``, ``module_path``, and ``config`` keys, + where ``config`` contains ``batch_size`` and ``drop_partial_batch``. + """ + config = super().to_config() + config["config"] = { + "batch_size": self.batch_size, + "drop_partial_batch": self.drop_partial_batch, + } + return config + + def identity_structure(self) -> Any: + return (self.__class__.__name__, self.batch_size, self.drop_partial_batch) diff --git a/src/orcapod/core/operators/column_selection.py b/src/orcapod/core/operators/column_selection.py index 4140db8e..37a0663d 100644 --- a/src/orcapod/core/operators/column_selection.py +++ b/src/orcapod/core/operators/column_selection.py @@ -1,14 +1,15 @@ -from orcapod.protocols import core_protocols as cp -from orcapod.core.streams import TableStream -from orcapod.types import PythonSchema -from typing import Any, TYPE_CHECKING -from orcapod.utils.lazy_module import LazyModule -from collections.abc import Collection, Mapping -from orcapod.errors import InputValidationError -from orcapod.core.system_constants import constants -from orcapod.core.operators.base import UnaryOperator import logging +from collections.abc import Collection, Mapping, Sequence +from typing import TYPE_CHECKING, Any +from orcapod.channels import ReadableChannel, WritableChannel +from orcapod.core.operators.base import UnaryOperator +from orcapod.core.streams import ArrowTableStream +from orcapod.errors import InputValidationError +from orcapod.protocols.core_protocols import PacketProtocol, StreamProtocol, TagProtocol +from orcapod.system_constants import constants +from orcapod.types import ColumnConfig, Schema +from orcapod.utils.lazy_module import LazyModule if TYPE_CHECKING: import pyarrow as pa @@ -30,7 +31,21 @@ def __init__(self, columns: str | Collection[str], strict: bool = True, **kwargs self.strict = strict super().__init__(**kwargs) - def op_forward(self, stream: cp.Stream) -> cp.Stream: + def to_config(self) -> dict[str, Any]: + """Serialize this SelectTagColumns operator to a config dict. + + Returns: + A dict with ``class_name``, ``module_path``, and ``config`` keys, + where ``config`` contains ``columns`` and ``strict``. + """ + config = super().to_config() + config["config"] = { + "columns": list(self.columns), + "strict": self.strict, + } + return config + + def unary_static_process(self, stream: StreamProtocol) -> StreamProtocol: tag_columns, packet_columns = stream.keys() tags_to_drop = [c for c in tag_columns if c not in self.columns] new_tag_columns = [c for c in tag_columns if c not in tags_to_drop] @@ -40,19 +55,17 @@ def op_forward(self, stream: cp.Stream) -> cp.Stream: return stream table = stream.as_table( - include_source=True, include_system_tags=True, sort_by_tags=False + columns={"source": True, "system_tags": True, "sort_by_tags": False} ) modified_table = table.drop_columns(list(tags_to_drop)) - return TableStream( + return ArrowTableStream( modified_table, tag_columns=new_tag_columns, - source=self, - upstreams=(stream,), ) - def op_validate_inputs(self, stream: cp.Stream) -> None: + def validate_unary_input(self, stream: StreamProtocol) -> None: """ This method should be implemented by subclasses to validate the inputs to the operator. It takes two streams as input and raises an error if the inputs are not valid. @@ -66,11 +79,15 @@ def op_validate_inputs(self, stream: cp.Stream) -> None: f"Missing tag columns: {missing_columns}. Make sure all specified columns to select are present or use strict=False to ignore missing columns" ) - def op_output_types( - self, stream: cp.Stream, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - tag_schema, packet_schema = stream.types( - include_system_tags=include_system_tags + def unary_output_schema( + self, + stream: StreamProtocol, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: + tag_schema, packet_schema = stream.output_schema( + columns=columns, all_info=all_info ) tag_columns, _ = stream.keys() tags_to_drop = [tc for tc in tag_columns if tc not in self.columns] @@ -78,14 +95,42 @@ def op_output_types( # this ensures all system tag columns are preserved new_tag_schema = {k: v for k, v in tag_schema.items() if k not in tags_to_drop} - return new_tag_schema, packet_schema - - def op_identity_structure(self, stream: cp.Stream | None = None) -> Any: + return Schema(new_tag_schema), packet_schema + + async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + **kwargs: Any, + ) -> None: + """Streaming: select tag columns per row without materializing.""" + try: + tags_to_drop: list[str] | None = None + async for tag, packet in inputs[0]: + if tags_to_drop is None: + tag_keys = tag.keys() + if self.strict: + missing = set(self.columns) - set(tag_keys) + if missing: + raise InputValidationError( + f"Missing tag columns: {missing}. Make sure all " + f"specified columns to select are present or use " + f"strict=False to ignore missing columns" + ) + tags_to_drop = [c for c in tag_keys if c not in self.columns] + if not tags_to_drop: + await output.send((tag, packet)) + else: + await output.send((tag.drop(*tags_to_drop), packet)) + finally: + await output.close() + + def identity_structure(self) -> Any: return ( self.__class__.__name__, self.columns, self.strict, - ) + ((stream,) if stream is not None else ()) + ) class SelectPacketColumns(UnaryOperator): @@ -100,7 +145,21 @@ def __init__(self, columns: str | Collection[str], strict: bool = True, **kwargs self.strict = strict super().__init__(**kwargs) - def op_forward(self, stream: cp.Stream) -> cp.Stream: + def to_config(self) -> dict[str, Any]: + """Serialize this SelectPacketColumns operator to a config dict. + + Returns: + A dict with ``class_name``, ``module_path``, and ``config`` keys, + where ``config`` contains ``columns`` and ``strict``. + """ + config = super().to_config() + config["config"] = { + "columns": list(self.columns), + "strict": self.strict, + } + return config + + def unary_static_process(self, stream: StreamProtocol) -> StreamProtocol: tag_columns, packet_columns = stream.keys() packet_columns_to_drop = [c for c in packet_columns if c not in self.columns] new_packet_columns = [ @@ -112,7 +171,7 @@ def op_forward(self, stream: cp.Stream) -> cp.Stream: return stream table = stream.as_table( - include_source=True, include_system_tags=True, sort_by_tags=False + columns={"source": True, "system_tags": True, "sort_by_tags": False}, ) # make sure to drop associated source fields associated_source_fields = [ @@ -122,14 +181,12 @@ def op_forward(self, stream: cp.Stream) -> cp.Stream: modified_table = table.drop_columns(packet_columns_to_drop) - return TableStream( + return ArrowTableStream( modified_table, tag_columns=tag_columns, - source=self, - upstreams=(stream,), ) - def op_validate_inputs(self, stream: cp.Stream) -> None: + def validate_unary_input(self, stream: StreamProtocol) -> None: """ This method should be implemented by subclasses to validate the inputs to the operator. It takes two streams as input and raises an error if the inputs are not valid. @@ -143,11 +200,15 @@ def op_validate_inputs(self, stream: cp.Stream) -> None: f"Missing packet columns: {missing_columns}. Make sure all specified columns to select are present or use strict=False to ignore missing columns" ) - def op_output_types( - self, stream: cp.Stream, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - tag_schema, packet_schema = stream.types( - include_system_tags=include_system_tags + def unary_output_schema( + self, + stream: StreamProtocol, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: + tag_schema, packet_schema = stream.output_schema( + columns=columns, all_info=all_info ) _, packet_columns = stream.keys() packets_to_drop = [pc for pc in packet_columns if pc not in self.columns] @@ -157,14 +218,42 @@ def op_output_types( k: v for k, v in packet_schema.items() if k not in packets_to_drop } - return tag_schema, new_packet_schema - - def op_identity_structure(self, stream: cp.Stream | None = None) -> Any: + return tag_schema, Schema(new_packet_schema) + + async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + **kwargs: Any, + ) -> None: + """Streaming: select packet columns per row without materializing.""" + try: + pkts_to_drop: list[str] | None = None + async for tag, packet in inputs[0]: + if pkts_to_drop is None: + pkt_keys = packet.keys() + if self.strict: + missing = set(self.columns) - set(pkt_keys) + if missing: + raise InputValidationError( + f"Missing packet columns: {missing}. Make sure all " + f"specified columns to select are present or use " + f"strict=False to ignore missing columns" + ) + pkts_to_drop = [c for c in pkt_keys if c not in self.columns] + if not pkts_to_drop: + await output.send((tag, packet)) + else: + await output.send((tag, packet.drop(*pkts_to_drop))) + finally: + await output.close() + + def identity_structure(self) -> Any: return ( self.__class__.__name__, self.columns, self.strict, - ) + ((stream,) if stream is not None else ()) + ) class DropTagColumns(UnaryOperator): @@ -179,7 +268,21 @@ def __init__(self, columns: str | Collection[str], strict: bool = True, **kwargs self.strict = strict super().__init__(**kwargs) - def op_forward(self, stream: cp.Stream) -> cp.Stream: + def to_config(self) -> dict[str, Any]: + """Serialize this DropTagColumns operator to a config dict. + + Returns: + A dict with ``class_name``, ``module_path``, and ``config`` keys, + where ``config`` contains ``columns`` and ``strict``. + """ + config = super().to_config() + config["config"] = { + "columns": list(self.columns), + "strict": self.strict, + } + return config + + def unary_static_process(self, stream: StreamProtocol) -> StreamProtocol: tag_columns, packet_columns = stream.keys() columns_to_drop = self.columns if not self.strict: @@ -192,19 +295,17 @@ def op_forward(self, stream: cp.Stream) -> cp.Stream: return stream table = stream.as_table( - include_source=True, include_system_tags=True, sort_by_tags=False + columns={"source": True, "system_tags": True, "sort_by_tags": False} ) modified_table = table.drop_columns(list(columns_to_drop)) - return TableStream( + return ArrowTableStream( modified_table, tag_columns=new_tag_columns, - source=self, - upstreams=(stream,), ) - def op_validate_inputs(self, stream: cp.Stream) -> None: + def validate_unary_input(self, stream: StreamProtocol) -> None: """ This method should be implemented by subclasses to validate the inputs to the operator. It takes two streams as input and raises an error if the inputs are not valid. @@ -218,25 +319,61 @@ def op_validate_inputs(self, stream: cp.Stream) -> None: f"Missing tag columns: {missing_columns}. Make sure all specified columns to drop are present or use strict=False to ignore missing columns" ) - def op_output_types( - self, stream: cp.Stream, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - tag_schema, packet_schema = stream.types( - include_system_tags=include_system_tags + def unary_output_schema( + self, + stream: StreamProtocol, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: + tag_schema, packet_schema = stream.output_schema( + columns=columns, all_info=all_info ) tag_columns, _ = stream.keys() new_tag_columns = [c for c in tag_columns if c not in self.columns] new_tag_schema = {k: v for k, v in tag_schema.items() if k in new_tag_columns} - return new_tag_schema, packet_schema - - def op_identity_structure(self, stream: cp.Stream | None = None) -> Any: + return Schema(new_tag_schema), packet_schema + + async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + **kwargs: Any, + ) -> None: + """Streaming: drop tag columns per row without materializing.""" + try: + effective_drops: list[str] | None = None + async for tag, packet in inputs[0]: + if effective_drops is None: + tag_keys = tag.keys() + if self.strict: + missing = set(self.columns) - set(tag_keys) + if missing: + raise InputValidationError( + f"Missing tag columns: {missing}. Make sure all " + f"specified columns to drop are present or use " + f"strict=False to ignore missing columns" + ) + effective_drops = ( + list(self.columns) + if self.strict + else [c for c in self.columns if c in tag_keys] + ) + if not effective_drops: + await output.send((tag, packet)) + else: + await output.send((tag.drop(*effective_drops), packet)) + finally: + await output.close() + + def identity_structure(self) -> Any: return ( self.__class__.__name__, self.columns, self.strict, - ) + ((stream,) if stream is not None else ()) + ) class DropPacketColumns(UnaryOperator): @@ -251,7 +388,21 @@ def __init__(self, columns: str | Collection[str], strict: bool = True, **kwargs self.strict = strict super().__init__(**kwargs) - def op_forward(self, stream: cp.Stream) -> cp.Stream: + def to_config(self) -> dict[str, Any]: + """Serialize this DropPacketColumns operator to a config dict. + + Returns: + A dict with ``class_name``, ``module_path``, and ``config`` keys, + where ``config`` contains ``columns`` and ``strict``. + """ + config = super().to_config() + config["config"] = { + "columns": list(self.columns), + "strict": self.strict, + } + return config + + def unary_static_process(self, stream: StreamProtocol) -> StreamProtocol: tag_columns, packet_columns = stream.keys() columns_to_drop = list(self.columns) if not self.strict: @@ -268,19 +419,17 @@ def op_forward(self, stream: cp.Stream) -> cp.Stream: columns_to_drop.extend(associated_source_columns) table = stream.as_table( - include_source=True, include_system_tags=True, sort_by_tags=False + columns={"source": True, "system_tags": True, "sort_by_tags": False} ) modified_table = table.drop_columns(columns_to_drop) - return TableStream( + return ArrowTableStream( modified_table, tag_columns=tag_columns, - source=self, - upstreams=(stream,), ) - def op_validate_inputs(self, stream: cp.Stream) -> None: + def validate_unary_input(self, stream: StreamProtocol) -> None: """ This method should be implemented by subclasses to validate the inputs to the operator. It takes two streams as input and raises an error if the inputs are not valid. @@ -293,24 +442,61 @@ def op_validate_inputs(self, stream: cp.Stream) -> None: f"Missing packet columns: {missing_columns}. Make sure all specified columns to drop are present or use strict=False to ignore missing columns" ) - def op_output_types( - self, stream: cp.Stream, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - tag_schema, packet_schema = stream.types( - include_system_tags=include_system_tags + def unary_output_schema( + self, + stream: StreamProtocol, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: + tag_schema, packet_schema = stream.output_schema( + columns=columns, all_info=all_info ) + new_packet_schema = { k: v for k, v in packet_schema.items() if k not in self.columns } - return tag_schema, new_packet_schema - - def op_identity_structure(self, stream: cp.Stream | None = None) -> Any: + return tag_schema, Schema(new_packet_schema) + + async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + **kwargs: Any, + ) -> None: + """Streaming: drop packet columns per row without materializing.""" + try: + effective_drops: list[str] | None = None + async for tag, packet in inputs[0]: + if effective_drops is None: + pkt_keys = packet.keys() + if self.strict: + missing = set(self.columns) - set(pkt_keys) + if missing: + raise InputValidationError( + f"Missing packet columns: {missing}. Make sure all " + f"specified columns to drop are present or use " + f"strict=False to ignore missing columns" + ) + effective_drops = ( + list(self.columns) + if self.strict + else [c for c in self.columns if c in pkt_keys] + ) + if not effective_drops: + await output.send((tag, packet)) + else: + await output.send((tag, packet.drop(*effective_drops))) + finally: + await output.close() + + def identity_structure(self) -> Any: return ( self.__class__.__name__, self.columns, self.strict, - ) + ((stream,) if stream is not None else ()) + ) class MapTags(UnaryOperator): @@ -327,7 +513,7 @@ def __init__( self.drop_unmapped = drop_unmapped super().__init__(**kwargs) - def op_forward(self, stream: cp.Stream) -> cp.Stream: + def unary_static_process(self, stream: StreamProtocol) -> StreamProtocol: tag_columns, packet_columns = stream.keys() missing_tags = set(tag_columns) - set(self.name_map.keys()) @@ -335,7 +521,7 @@ def op_forward(self, stream: cp.Stream) -> cp.Stream: # nothing to rename in the tags, return stream as is return stream - table = stream.as_table(include_source=True, include_system_tags=True) + table = stream.as_table(columns={"source": True, "system_tags": True}) name_map = { tc: self.name_map.get(tc, tc) for tc in tag_columns @@ -350,11 +536,12 @@ def op_forward(self, stream: cp.Stream) -> cp.Stream: # drop any tags that are not in the name map renamed_table = renamed_table.drop_columns(list(missing_tags)) - return TableStream( - renamed_table, tag_columns=new_tag_columns, source=self, upstreams=(stream,) + return ArrowTableStream( + renamed_table, + tag_columns=new_tag_columns, ) - def op_validate_inputs(self, stream: cp.Stream) -> None: + def validate_unary_input(self, stream: StreamProtocol) -> None: """ This method should be implemented by subclasses to validate the inputs to the operator. It takes two streams as input and raises an error if the inputs are not valid. @@ -379,21 +566,25 @@ def op_validate_inputs(self, stream: cp.Stream) -> None: message += f"overlapping packet columns: {overlapping_packet_columns}." raise InputValidationError(message) - def op_output_types( - self, stream: cp.Stream, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - tag_typespec, packet_typespec = stream.types( - include_system_tags=include_system_tags + def unary_output_schema( + self, + stream: StreamProtocol, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: + tag_schema, packet_schema = stream.output_schema( + columns=columns, all_info=all_info ) - # Create new packet typespec with renamed keys - new_tag_typespec = {self.name_map.get(k, k): v for k, v in tag_typespec.items()} + # Create new packet schema with renamed keys + new_tag_schema = {self.name_map.get(k, k): v for k, v in tag_schema.items()} - return new_tag_typespec, packet_typespec + return Schema(new_tag_schema), packet_schema - def op_identity_structure(self, stream: cp.Stream | None = None) -> Any: + def identity_structure(self) -> Any: return ( self.__class__.__name__, self.name_map, self.drop_unmapped, - ) + ((stream,) if stream is not None else ()) + ) diff --git a/src/orcapod/core/operators/filters.py b/src/orcapod/core/operators/filters.py index 2edf4f7c..15529f44 100644 --- a/src/orcapod/core/operators/filters.py +++ b/src/orcapod/core/operators/filters.py @@ -1,21 +1,20 @@ -from orcapod.protocols import core_protocols as cp -from orcapod.core.streams import TableStream -from orcapod.types import PythonSchema -from typing import Any, TYPE_CHECKING, TypeAlias -from orcapod.utils.lazy_module import LazyModule -from collections.abc import Collection, Mapping -from orcapod.errors import InputValidationError -from orcapod.core.system_constants import constants -from orcapod.core.operators.base import UnaryOperator import logging -from collections.abc import Iterable +from collections.abc import Collection, Iterable, Mapping +from typing import TYPE_CHECKING, Any, TypeAlias +from orcapod.core.operators.base import UnaryOperator +from orcapod.core.streams import ArrowTableStream +from orcapod.errors import InputValidationError +from orcapod.protocols.core_protocols import StreamProtocol +from orcapod.system_constants import constants +from orcapod.types import ColumnConfig, Schema +from orcapod.utils.lazy_module import LazyModule if TYPE_CHECKING: - import pyarrow as pa + import numpy as np import polars as pl import polars._typing as pl_type - import numpy as np + import pyarrow as pa else: pa = LazyModule("pyarrow") pl = LazyModule("polars") @@ -43,7 +42,7 @@ def __init__( self.constraints = constraints if constraints is not None else {} super().__init__(**kwargs) - def op_forward(self, stream: cp.Stream) -> cp.Stream: + def unary_static_process(self, stream: StreamProtocol) -> StreamProtocol: if len(self.predicates) == 0 and len(self.constraints) == 0: logger.info( "No predicates or constraints specified. Returning stream unaltered." @@ -52,39 +51,103 @@ def op_forward(self, stream: cp.Stream) -> cp.Stream: # TODO: improve efficiency here... table = stream.as_table( - include_source=True, include_system_tags=True, sort_by_tags=False + columns={"source": True, "system_tags": True, "sort_by_tags": False} ) df = pl.DataFrame(table) filtered_table = df.filter(*self.predicates, **self.constraints).to_arrow() - return TableStream( + return ArrowTableStream( filtered_table, - tag_columns=stream.tag_keys(), - source=self, - upstreams=(stream,), + tag_columns=stream.keys()[0], ) - def op_validate_inputs(self, stream: cp.Stream) -> None: + def validate_unary_input(self, stream: StreamProtocol) -> None: """ This method should be implemented by subclasses to validate the inputs to the operator. It takes two streams as input and raises an error if the inputs are not valid. """ - # Any valid stream would work return - def op_output_types( - self, stream: cp.Stream, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: + def unary_output_schema( + self, + stream: StreamProtocol, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + include_system_tags: bool = False, + ) -> tuple[Schema, Schema]: # data types are not modified - return stream.types(include_system_tags=include_system_tags) + return stream.output_schema(columns=columns, all_info=all_info) + + def to_config(self) -> dict[str, Any]: + """Serialize this PolarsFilter operator to a config dict. + + Polars ``Expr`` predicates are serialized to JSON strings when possible. + If any predicate cannot be serialized, ``reconstructable`` is set to + ``False`` and ``predicates`` is set to ``None``. - def op_identity_structure(self, stream: cp.Stream | None = None) -> Any: + Returns: + A dict with ``class_name``, ``module_path``, and ``config`` keys. + """ + config = super().to_config() + serialized_predicates = [] + reconstructable = True + for pred in self.predicates: + if hasattr(pred, "meta") and hasattr(pred.meta, "serialize"): + serialized = pred.meta.serialize(format="json") + # serialize() returns bytes in some Polars versions, str in others + if isinstance(serialized, bytes): + serialized = serialized.decode() + serialized_predicates.append(serialized) + else: + reconstructable = False + break + config["config"] = { + "constraints": dict(self.constraints) if self.constraints else None, + "predicates": serialized_predicates if reconstructable else None, + "reconstructable": reconstructable, + } + return config + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "PolarsFilter": + """Reconstruct a PolarsFilter from a config dict. + + Args: + config: Dict as returned by ``to_config()``. + + Returns: + A new ``PolarsFilter`` instance. + + Raises: + NotImplementedError: If ``reconstructable`` is ``False``. + """ + inner = config.get("config", {}) + if not inner.get("reconstructable", True): + raise NotImplementedError( + "PolarsFilter with non-serializable predicates cannot be reconstructed" + ) + predicates = [] + if inner.get("predicates"): + import polars as pl + + predicates = [] + for p in inner["predicates"]: + # deserialize() accepts bytes in some Polars versions, str in others + try: + predicates.append(pl.Expr.deserialize(p.encode(), format="json")) + except TypeError: + predicates.append(pl.Expr.deserialize(p, format="json")) + constraints = inner.get("constraints") + return cls(predicates=predicates, constraints=constraints) + + def identity_structure(self) -> Any: return ( self.__class__.__name__, self.predicates, self.constraints, - ) + ((stream,) if stream is not None else ()) + ) class SelectPacketColumns(UnaryOperator): @@ -99,7 +162,7 @@ def __init__(self, columns: str | Collection[str], strict: bool = True, **kwargs self.strict = strict super().__init__(**kwargs) - def op_forward(self, stream: cp.Stream) -> cp.Stream: + def unary_static_process(self, stream: StreamProtocol) -> StreamProtocol: tag_columns, packet_columns = stream.keys() packet_columns_to_drop = [c for c in packet_columns if c not in self.columns] new_packet_columns = [ @@ -111,7 +174,7 @@ def op_forward(self, stream: cp.Stream) -> cp.Stream: return stream table = stream.as_table( - include_source=True, include_system_tags=True, sort_by_tags=False + columns={"source": True, "system_tags": True, "sort_by_tags": False} ) # make sure to drop associated source fields associated_source_fields = [ @@ -121,20 +184,18 @@ def op_forward(self, stream: cp.Stream) -> cp.Stream: modified_table = table.drop_columns(packet_columns_to_drop) - return TableStream( + return ArrowTableStream( modified_table, tag_columns=tag_columns, - source=self, - upstreams=(stream,), ) - def op_validate_inputs(self, stream: cp.Stream) -> None: + def validate_unary_input(self, stream: StreamProtocol) -> None: """ This method should be implemented by subclasses to validate the inputs to the operator. It takes two streams as input and raises an error if the inputs are not valid. """ # TODO: remove redundant logic - tag_columns, packet_columns = stream.keys() + _, packet_columns = stream.keys() columns_to_select = self.columns missing_columns = set(columns_to_select) - set(packet_columns) if missing_columns and self.strict: @@ -142,11 +203,16 @@ def op_validate_inputs(self, stream: cp.Stream) -> None: f"Missing packet columns: {missing_columns}. Make sure all specified columns to select are present or use strict=False to ignore missing columns" ) - def op_output_types( - self, stream: cp.Stream, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - tag_schema, packet_schema = stream.types( - include_system_tags=include_system_tags + def unary_output_schema( + self, + stream: StreamProtocol, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + include_system_tags: bool = False, + ) -> tuple[Schema, Schema]: + tag_schema, packet_schema = stream.output_schema( + columns=columns, all_info=all_info ) _, packet_columns = stream.keys() packets_to_drop = [pc for pc in packet_columns if pc not in self.columns] @@ -158,9 +224,9 @@ def op_output_types( return tag_schema, new_packet_schema - def op_identity_structure(self, stream: cp.Stream | None = None) -> Any: + def identity_structure(self) -> Any: return ( self.__class__.__name__, self.columns, self.strict, - ) + ((stream,) if stream is not None else ()) + ) diff --git a/src/orcapod/core/operators/join.py b/src/orcapod/core/operators/join.py index 04c65ee5..ed10ef78 100644 --- a/src/orcapod/core/operators/join.py +++ b/src/orcapod/core/operators/join.py @@ -1,17 +1,25 @@ -from orcapod.protocols import core_protocols as cp -from orcapod.core.streams import TableStream -from orcapod.types import PythonSchema -from orcapod.utils import types_utils -from typing import Any, TYPE_CHECKING -from orcapod.utils.lazy_module import LazyModule -from collections.abc import Collection -from orcapod.errors import InputValidationError +import asyncio +from collections.abc import Collection, Sequence +from typing import TYPE_CHECKING, Any + +from orcapod.channels import ReadableChannel, WritableChannel from orcapod.core.operators.base import NonZeroInputOperator -from orcapod.core import arrow_data_utils +from orcapod.core.streams import ArrowTableStream +from orcapod.errors import InputValidationError +from orcapod.protocols.core_protocols import ( + ArgumentGroup, + PacketProtocol, + StreamProtocol, + TagProtocol, +) +from orcapod.system_constants import constants +from orcapod.types import ColumnConfig, ContentHash, Schema +from orcapod.utils import arrow_data_utils, schema_utils +from orcapod.utils.lazy_module import LazyModule if TYPE_CHECKING: - import pyarrow as pa import polars as pl + import pyarrow as pa else: pa = LazyModule("pyarrow") pl = LazyModule("polars") @@ -26,50 +34,83 @@ def kernel_id(self) -> tuple[str, ...]: """ return (f"{self.__class__.__name__}",) - def op_validate_inputs(self, *streams: cp.Stream) -> None: + def validate_nonzero_inputs(self, *streams: StreamProtocol) -> None: + """Validate that input streams are compatible for joining.""" + # TODO: add more helpful validation try: - self.op_output_types(*streams) + self.output_schema(*streams) except Exception as e: - # raise InputValidationError(f"Input streams are not compatible: {e}") from e - raise e + raise InputValidationError(f"Input streams are not compatible: {e}") from e - def order_input_streams(self, *streams: cp.Stream) -> list[cp.Stream]: - # order the streams based on their hashes to offer deterministic operation - return sorted(streams, key=lambda s: s.content_hash().to_hex()) + def order_input_streams(self, *streams: StreamProtocol) -> list[StreamProtocol]: + # Canonically order by pipeline_hash for deterministic operation. + # pipeline_hash is structure-only, so streams with the same schema+topology + # get the same ordering regardless of data content. + return sorted(streams, key=lambda s: s.pipeline_hash().to_string()) - def op_output_types( - self, *streams: cp.Stream, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - if len(streams) == 1: - # If only one stream is provided, return its typespecs - return streams[0].types(include_system_tags=include_system_tags) + def argument_symmetry(self, streams: Collection) -> ArgumentGroup: + return frozenset(streams) - # output type computation does NOT require consistent ordering of streams + def output_schema( + self, + *streams: StreamProtocol, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: + columns_config = ColumnConfig.handle_config(columns, all_info=all_info) + + if len(streams) == 1: + # if single stream, simply return the output schema of the single input stream + return streams[0].output_schema(columns=columns, all_info=all_info) - # TODO: consider performing the check always with system tags on + # Always get input schemas WITHOUT system tags for the base computation. + # System tags are computed separately because the join renames them. stream = streams[0] - tag_typespec, packet_typespec = stream.types( - include_system_tags=include_system_tags - ) + tag_schema, packet_schema = stream.output_schema() for other_stream in streams[1:]: - other_tag_typespec, other_packet_typespec = other_stream.types( - include_system_tags=include_system_tags - ) - tag_typespec = types_utils.union_typespecs(tag_typespec, other_tag_typespec) - intersection_packet_typespec = types_utils.intersection_typespecs( - packet_typespec, other_packet_typespec + other_tag_schema, other_packet_schema = other_stream.output_schema() + tag_schema = schema_utils.union_schemas(tag_schema, other_tag_schema) + intersection_packet_schema = schema_utils.intersection_schemas( + packet_schema, other_packet_schema ) - packet_typespec = types_utils.union_typespecs( - packet_typespec, other_packet_typespec + packet_schema = schema_utils.union_schemas( + packet_schema, other_packet_schema ) - if intersection_packet_typespec: + if intersection_packet_schema: raise InputValidationError( - f"Packets should not have overlapping keys, but {packet_typespec.keys()} found in {stream} and {other_stream}." + f"Packets should not have overlapping keys, but {packet_schema.keys()} found in {stream} and {other_stream}." ) - return tag_typespec, packet_typespec + # Add system tag columns if requested + if columns_config.system_tags: + system_tag_schema = self._predict_system_tag_schema(*streams) + tag_schema = schema_utils.union_schemas(tag_schema, system_tag_schema) + + return tag_schema, packet_schema - def op_forward(self, *streams: cp.Stream) -> cp.Stream: + def _predict_system_tag_schema(self, *streams: StreamProtocol) -> Schema: + """Predict the system tag columns that the join would produce. + + Each input stream's existing system tag columns get renamed by + appending ::{pipeline_hash}:{canonical_position}. This method + computes those output column names without performing the join. + """ + n_char = self.orcapod_config.system_tag_hash_n_char + ordered_streams = self.order_input_streams(*streams) + + system_tag_fields: dict[str, type] = {} + for idx, stream in enumerate(ordered_streams): + stream_tag_schema, _ = stream.output_schema(columns={"system_tags": True}) + for col_name in stream_tag_schema: + if col_name.startswith(constants.SYSTEM_TAG_PREFIX): + new_name = ( + f"{col_name}{constants.BLOCK_SEPARATOR}" + f"{stream.pipeline_hash().to_hex(n_char)}:{idx}" + ) + system_tag_fields[new_name] = str + return Schema(system_tag_fields) + + def static_process(self, *streams: StreamProtocol) -> StreamProtocol: """ Joins two streams together based on their tags. The resulting stream will contain all the tags from both streams. @@ -77,35 +118,62 @@ def op_forward(self, *streams: cp.Stream) -> cp.Stream: if len(streams) == 1: return streams[0] + # Canonically order streams by pipeline_hash for deterministic + # system tag column names regardless of input order (Join is commutative) + ordered_streams = self.order_input_streams(*streams) + COMMON_JOIN_KEY = "_common" - stream = streams[0] + n_char = self.orcapod_config.system_tag_hash_n_char + + stream = ordered_streams[0] tag_keys, _ = [set(k) for k in stream.keys()] - table = stream.as_table(include_source=True, include_system_tags=True) + table = stream.as_table( + columns={"source": True, "system_tags": True, "meta": True} + ) # trick to get cartesian product table = table.add_column(0, COMMON_JOIN_KEY, pa.array([0] * len(table))) table = arrow_data_utils.append_to_system_tags( table, - stream.content_hash().to_hex(self.orcapod_config.system_tag_hash_n_char), + f"{stream.pipeline_hash().to_hex(n_char)}:0", ) - for next_stream in streams[1:]: + for idx, next_stream in enumerate(ordered_streams[1:], start=1): next_tag_keys, _ = next_stream.keys() next_table = next_stream.as_table( - include_source=True, include_system_tags=True + columns={"source": True, "system_tags": True, "meta": True} ) next_table = arrow_data_utils.append_to_system_tags( next_table, - next_stream.content_hash().to_hex( - char_count=self.orcapod_config.system_tag_hash_n_char - ), + f"{next_stream.pipeline_hash().to_hex(n_char)}:{idx}", ) # trick to ensure that there will always be at least one shared key # this ensure that no overlap in keys lead to full caretesian product next_table = next_table.add_column( 0, COMMON_JOIN_KEY, pa.array([0] * len(next_table)) ) + + # Rename any non-key columns in next_table that would collide with + # the accumulated table, using stream-index-based suffixes instead of + # Polars' default ``_right`` suffix which causes cascading collisions + # on 3+ stream joins. The only legitimately shared column names are + # the tag join keys; everything else (meta columns, their derived + # source-info columns, etc.) must be unique. + join_key_set = tag_keys.intersection(next_tag_keys) | {COMMON_JOIN_KEY} + existing_names = set(table.column_names) + rename_map = {} + for col in next_table.column_names: + if col not in join_key_set and col in existing_names: + new_name = f"{col}_{idx}" + counter = idx + while new_name in existing_names or new_name in rename_map.values(): + counter += 1 + new_name = f"{col}_{counter}" + rename_map[col] = new_name + if rename_map: + next_table = pl.DataFrame(next_table).rename(rename_map).to_arrow() + common_tag_keys = tag_keys.intersection(next_tag_keys) common_tag_keys.add(COMMON_JOIN_KEY) @@ -120,22 +188,293 @@ def op_forward(self, *streams: cp.Stream) -> cp.Stream: # reorder columns to bring tag columns to the front # TODO: come up with a better algorithm table = table.drop(COMMON_JOIN_KEY) + + # Sort system tag values for same-pipeline-hash streams to ensure commutativity + table = arrow_data_utils.sort_system_tag_values(table) + reordered_columns = [col for col in table.column_names if col in tag_keys] reordered_columns += [col for col in table.column_names if col not in tag_keys] - return TableStream( + return ArrowTableStream( table.select(reordered_columns), tag_columns=tuple(tag_keys), - source=self, - upstreams=streams, ) - def op_identity_structure( - self, streams: Collection[cp.Stream] | None = None - ) -> Any: - return ( - (self.__class__.__name__,) + (set(streams),) if streams is not None else () - ) + # ------------------------------------------------------------------ + # Async execution + # ------------------------------------------------------------------ + + def _compute_system_tag_suffixes( + self, + input_pipeline_hashes: Sequence[ContentHash], + ) -> list[str]: + """Compute per-input system-tag suffixes from pipeline hashes. + + Each suffix is ``{truncated_hash}:{canonical_position}`` where + canonical position is determined by sorting the hashes (matching + the deterministic ordering used by ``static_process``). + + Args: + input_pipeline_hashes: Pipeline hash per input, positionally + matching the input channels. + + Returns: + List of suffix strings, one per input position. + """ + n_char = self.orcapod_config.system_tag_hash_n_char + hex_strings = [h.to_hex() for h in input_pipeline_hashes] + + # Canonical order: sorted by full hex (same as order_input_streams) + sorted_hexes = sorted(hex_strings) + + suffixes: list[str] = [] + for orig_idx, hex_str in enumerate(hex_strings): + canon_idx = sorted_hexes.index(hex_str) + truncated = input_pipeline_hashes[orig_idx].to_hex(n_char) + suffixes.append(f"{truncated}:{canon_idx}") + return suffixes + + async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + *, + input_pipeline_hashes: Sequence[ContentHash] | None = None, + ) -> None: + """Async join with streaming symmetric hash join for two inputs. + + Single input: streams through directly without any buffering. + + Two inputs: symmetric hash join — each arriving row is + immediately probed against the opposite side's buffer, emitting + matches as soon as found. System-tag columns are correctly + renamed using the ``input_pipeline_hashes``. + + Three or more inputs: collects all inputs concurrently, then + delegates to ``static_process`` for the Polars N-way join. + + Args: + inputs: Readable channels, one per upstream. + output: Writable channel for downstream. + input_pipeline_hashes: Pipeline hash for each input, + positionally matching ``inputs``. Required for + correct system-tag renaming with 2+ inputs. + """ + try: + if len(inputs) == 1: + async for tag, packet in inputs[0]: + await output.send((tag, packet)) + return + + # TODO: carefully revisit the logic behind system tag handling + if len(inputs) == 2: + suffixes = ( + self._compute_system_tag_suffixes(input_pipeline_hashes) + if input_pipeline_hashes is not None + else ["0", "1"] + ) + await self._symmetric_hash_join(inputs[0], inputs[1], output, suffixes) + return + + # N > 2: concurrent collection + static_process + all_rows = await asyncio.gather(*(ch.collect() for ch in inputs)) + + # Guard against empty inputs — join with an empty side is empty + if any(len(rows) == 0 for rows in all_rows): + return + + streams = [self._materialize_to_stream(rows) for rows in all_rows] + result = self.static_process(*streams) + for tag, packet in result.iter_packets(): + await output.send((tag, packet)) + finally: + await output.close() + + async def _symmetric_hash_join( + self, + left_ch: ReadableChannel[tuple[TagProtocol, PacketProtocol]], + right_ch: ReadableChannel[tuple[TagProtocol, PacketProtocol]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + suffixes: list[str], + ) -> None: + """Symmetric hash join for two inputs. + + Both sides are read concurrently via a merged bounded queue. + Each arriving row is added to its side's index and immediately + probed against the opposite side. Matched rows are emitted to + ``output`` as soon as found, so downstream consumers can begin + work before either input is fully consumed. + + Args: + left_ch: Left input channel. + right_ch: Right input channel. + output: Output channel for matched rows. + suffixes: Per-input system-tag suffixes (positional), + computed from pipeline hashes and canonical ordering. + """ + # Bounded queue preserves backpressure — producers block when full. + _SENTINEL = object() + queue: asyncio.Queue = asyncio.Queue(maxsize=64) + + async def _drain( + ch: ReadableChannel[tuple[TagProtocol, PacketProtocol]], + side: int, + ) -> None: + async for item in ch: + await queue.put((side, item)) + await queue.put((side, _SENTINEL)) + + block_sep = constants.BLOCK_SEPARATOR + + async with asyncio.TaskGroup() as tg: + tg.create_task(_drain(left_ch, 0)) + tg.create_task(_drain(right_ch, 1)) + + # buffers[i] holds all rows seen so far from input i + buffers: list[list[tuple[TagProtocol, PacketProtocol]]] = [[], []] + # indexes[i] maps shared-key tuple → list of indices into buffers[i] + indexes: list[dict[tuple, list[int]]] = [{}, {}] + + shared_keys: tuple[str, ...] | None = None + needs_reindex = False + closed_count = 0 + + while closed_count < 2: + side, item = await queue.get() + + if item is _SENTINEL: + closed_count += 1 + continue + + tag, pkt = item + other = 1 - side + + # Determine shared tag keys once we have rows from both sides + if shared_keys is None: + if not buffers[other]: + # Other side empty — just buffer this row for later + buffers[side].append((tag, pkt)) + continue + + # We have data from both sides; compute shared keys + this_keys = set(tag.keys()) + other_keys = set(buffers[other][0][0].keys()) + shared_keys = tuple(sorted(this_keys & other_keys)) + needs_reindex = True + + # One-time re-index of all rows buffered before shared_keys + if needs_reindex: + needs_reindex = False + for buf_side in (0, 1): + for j, (bt, _bp) in enumerate(buffers[buf_side]): + btd = bt.as_dict() + k = ( + tuple(btd[sk] for sk in shared_keys) + if shared_keys + else (0,) + ) + indexes[buf_side].setdefault(k, []).append(j) + + # Emit matches for all already-buffered rows across sides + for li, (lt, lp) in enumerate(buffers[0]): + ltd = lt.as_dict() + lk = ( + tuple(ltd[sk] for sk in shared_keys) + if shared_keys + else (0,) + ) + for ri in indexes[1].get(lk, []): + rt, rp = buffers[1][ri] + await output.send( + self._merge_row_pair( + lt, lp, rt, rp, suffixes, block_sep + ) + ) + + # Index the new row + td = tag.as_dict() + key = tuple(td[sk] for sk in shared_keys) if shared_keys else (0,) + row_idx = len(buffers[side]) + buffers[side].append((tag, pkt)) + indexes[side].setdefault(key, []).append(row_idx) + + # Probe the opposite buffer for matches + matching_indices = indexes[other].get(key, []) + for mi in matching_indices: + other_tag, other_pkt = buffers[other][mi] + if side == 0: + merged = self._merge_row_pair( + tag, + pkt, + other_tag, + other_pkt, + suffixes, + block_sep, + ) + else: + merged = self._merge_row_pair( + other_tag, + other_pkt, + tag, + pkt, + suffixes, + block_sep, + ) + await output.send(merged) + + @staticmethod + def _merge_row_pair( + left_tag: TagProtocol, + left_pkt: PacketProtocol, + right_tag: TagProtocol, + right_pkt: PacketProtocol, + suffixes: list[str], + block_sep: str, + ) -> tuple[TagProtocol, PacketProtocol]: + """Merge a matched pair of rows into one joined (Tag, Packet). + + System-tag keys are renamed by appending + ``{block_sep}{suffix}`` to match the canonical name-extending + scheme used by ``static_process``. System-tag values sharing + the same provenance path are sorted for commutativity. + """ + from orcapod.core.datagrams import Packet, Tag + + sys_prefix = constants.SYSTEM_TAG_PREFIX + + # Merge tag dicts (shared keys come from left) + merged_tag_d: dict = {} + merged_tag_d.update(left_tag.as_dict()) + for k, v in right_tag.as_dict().items(): + if k not in merged_tag_d: + merged_tag_d[k] = v + + # Rename and merge system tags with canonical suffixes + merged_sys: dict = {} + for k, v in left_tag.system_tags().items(): + new_key = f"{k}{block_sep}{suffixes[0]}" if k.startswith(sys_prefix) else k + merged_sys[new_key] = v + for k, v in right_tag.system_tags().items(): + new_key = f"{k}{block_sep}{suffixes[1]}" if k.startswith(sys_prefix) else k + merged_sys[new_key] = v + + merged_tag = Tag(merged_tag_d, system_tags=merged_sys) + + # Merge packet dicts (non-overlapping by Join's validation) + merged_pkt_d: dict = {} + merged_pkt_d.update(left_pkt.as_dict()) + merged_pkt_d.update(right_pkt.as_dict()) + + merged_si: dict = {} + merged_si.update(left_pkt.source_info()) + merged_si.update(right_pkt.source_info()) + + merged_pkt = Packet(merged_pkt_d, source_info=merged_si) + + return merged_tag, merged_pkt + + def identity_structure(self) -> Any: + return self.__class__.__name__ def __repr__(self) -> str: return "Join()" diff --git a/src/orcapod/core/operators/mappers.py b/src/orcapod/core/operators/mappers.py index 5500e1bd..a7a32727 100644 --- a/src/orcapod/core/operators/mappers.py +++ b/src/orcapod/core/operators/mappers.py @@ -1,12 +1,14 @@ -from orcapod.protocols import core_protocols as cp -from orcapod.core.streams import TableStream -from orcapod.types import PythonSchema -from typing import Any, TYPE_CHECKING -from orcapod.utils.lazy_module import LazyModule -from collections.abc import Mapping -from orcapod.errors import InputValidationError -from orcapod.core.system_constants import constants +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Any + +from orcapod.channels import ReadableChannel, WritableChannel from orcapod.core.operators.base import UnaryOperator +from orcapod.core.streams import ArrowTableStream +from orcapod.errors import InputValidationError +from orcapod.protocols.core_protocols import PacketProtocol, StreamProtocol, TagProtocol +from orcapod.system_constants import constants +from orcapod.types import ColumnConfig, Schema +from orcapod.utils.lazy_module import LazyModule if TYPE_CHECKING: import pyarrow as pa @@ -28,7 +30,21 @@ def __init__( self.drop_unmapped = drop_unmapped super().__init__(**kwargs) - def op_forward(self, stream: cp.Stream) -> cp.Stream: + def to_config(self) -> dict[str, Any]: + """Serialize this MapPackets operator to a config dict. + + Returns: + A dict with ``class_name``, ``module_path``, and ``config`` keys, + where ``config`` contains ``name_map`` and ``drop_unmapped``. + """ + config = super().to_config() + config["config"] = { + "name_map": dict(self.name_map), + "drop_unmapped": self.drop_unmapped, + } + return config + + def unary_static_process(self, stream: StreamProtocol) -> StreamProtocol: tag_columns, packet_columns = stream.keys() unmapped_columns = set(packet_columns) - set(self.name_map.keys()) @@ -37,7 +53,7 @@ def op_forward(self, stream: cp.Stream) -> cp.Stream: return stream table = stream.as_table( - include_source=True, include_system_tags=True, sort_by_tags=False + columns={"source": True, "system_tags": True, "sort_by_tags": False} ) name_map = { @@ -64,15 +80,9 @@ def op_forward(self, stream: cp.Stream) -> cp.Stream: if self.drop_unmapped and unmapped_columns: renamed_table = renamed_table.drop_columns(list(unmapped_columns)) - return TableStream( - renamed_table, tag_columns=tag_columns, source=self, upstreams=(stream,) - ) + return ArrowTableStream(renamed_table, tag_columns=tag_columns) - def op_validate_inputs(self, stream: cp.Stream) -> None: - """ - This method should be implemented by subclasses to validate the inputs to the operator. - It takes two streams as input and raises an error if the inputs are not valid. - """ + def validate_unary_input(self, stream: StreamProtocol) -> None: # verify that renamed value does NOT collide with other columns tag_columns, packet_columns = stream.keys() relevant_source = [] @@ -95,28 +105,60 @@ def op_validate_inputs(self, stream: cp.Stream) -> None: message += f"overlapping tag columns: {overlapping_tag_columns}." raise InputValidationError(message) - def op_output_types( - self, stream: cp.Stream, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - tag_typespec, packet_typespec = stream.types( - include_system_tags=include_system_tags + def unary_output_schema( + self, + stream: StreamProtocol, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: + tag_schema, packet_schema = stream.output_schema( + columns=columns, all_info=all_info ) - # Create new packet typespec with renamed keys - new_packet_typespec = { + # Create new packet schema with renamed keys + new_packet_schema = { self.name_map.get(k, k): v - for k, v in packet_typespec.items() + for k, v in packet_schema.items() if k in self.name_map or not self.drop_unmapped } - return tag_typespec, new_packet_typespec - - def op_identity_structure(self, stream: cp.Stream | None = None) -> Any: + return tag_schema, Schema(new_packet_schema) + + async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + **kwargs: Any, + ) -> None: + """Streaming: rename packet columns per row without materializing.""" + try: + rename_map: dict[str, str] | None = None + unmapped: list[str] | None = None + async for tag, packet in inputs[0]: + if rename_map is None: + pkt_keys = packet.keys() + rename_map = { + k: self.name_map[k] for k in pkt_keys if k in self.name_map + } + if self.drop_unmapped: + unmapped = [k for k in pkt_keys if k not in self.name_map] + if not rename_map: + await output.send((tag, packet)) + else: + new_pkt = packet.rename(rename_map) + if unmapped: + new_pkt = new_pkt.drop(*unmapped) + await output.send((tag, new_pkt)) + finally: + await output.close() + + def identity_structure(self) -> Any: return ( self.__class__.__name__, self.name_map, self.drop_unmapped, - ) + ((stream,) if stream is not None else ()) + ) class MapTags(UnaryOperator): @@ -133,7 +175,21 @@ def __init__( self.drop_unmapped = drop_unmapped super().__init__(**kwargs) - def op_forward(self, stream: cp.Stream) -> cp.Stream: + def to_config(self) -> dict[str, Any]: + """Serialize this MapTags operator to a config dict. + + Returns: + A dict with ``class_name``, ``module_path``, and ``config`` keys, + where ``config`` contains ``name_map`` and ``drop_unmapped``. + """ + config = super().to_config() + config["config"] = { + "name_map": dict(self.name_map), + "drop_unmapped": self.drop_unmapped, + } + return config + + def unary_static_process(self, stream: StreamProtocol) -> StreamProtocol: tag_columns, packet_columns = stream.keys() missing_tags = set(tag_columns) - set(self.name_map.keys()) @@ -141,7 +197,9 @@ def op_forward(self, stream: cp.Stream) -> cp.Stream: # nothing to rename in the tags, return stream as is return stream - table = stream.as_table(include_source=True, include_system_tags=True) + table = stream.as_table( + columns={"source": True, "system_tags": True, "sort_by_tags": False} + ) name_map = { tc: self.name_map.get(tc, tc) @@ -158,11 +216,12 @@ def op_forward(self, stream: cp.Stream) -> cp.Stream: # drop any tags that are not in the name map renamed_table = renamed_table.drop_columns(list(missing_tags)) - return TableStream( - renamed_table, tag_columns=new_tag_columns, source=self, upstreams=(stream,) + return ArrowTableStream( + renamed_table, + tag_columns=new_tag_columns, ) - def op_validate_inputs(self, stream: cp.Stream) -> None: + def validate_unary_input(self, stream: StreamProtocol) -> None: """ This method should be implemented by subclasses to validate the inputs to the operator. It takes two streams as input and raises an error if the inputs are not valid. @@ -187,30 +246,56 @@ def op_validate_inputs(self, stream: cp.Stream) -> None: message += f"overlapping packet columns: {overlapping_packet_columns}." raise InputValidationError(message) - def op_output_types( - self, stream: cp.Stream, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - tag_typespec, packet_typespec = stream.types( - include_system_tags=include_system_tags + def unary_output_schema( + self, + stream: StreamProtocol, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: + tag_schema, packet_schema = stream.output_schema( + columns=columns, all_info=all_info ) - # Create new packet typespec with renamed keys - new_tag_typespec = {self.name_map.get(k, k): v for k, v in tag_typespec.items()} - - # Create new packet typespec with renamed keys - new_tag_typespec = { + new_tag_schema = { self.name_map.get(k, k): v - for k, v in tag_typespec.items() + for k, v in tag_schema.items() if k in self.name_map or not self.drop_unmapped } - return new_tag_typespec, packet_typespec - - return new_tag_typespec, packet_typespec - - def op_identity_structure(self, stream: cp.Stream | None = None) -> Any: + return Schema(new_tag_schema), packet_schema + + async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + **kwargs: Any, + ) -> None: + """Streaming: rename tag columns per row without materializing.""" + try: + rename_map: dict[str, str] | None = None + unmapped: list[str] | None = None + async for tag, packet in inputs[0]: + if rename_map is None: + tag_keys = tag.keys() + rename_map = { + k: self.name_map[k] for k in tag_keys if k in self.name_map + } + if self.drop_unmapped: + unmapped = [k for k in tag_keys if k not in self.name_map] + if not rename_map: + await output.send((tag, packet)) + else: + new_tag = tag.rename(rename_map) + if unmapped: + new_tag = new_tag.drop(*unmapped) + await output.send((new_tag, packet)) + finally: + await output.close() + + def identity_structure(self) -> Any: return ( self.__class__.__name__, self.name_map, self.drop_unmapped, - ) + ((stream,) if stream is not None else ()) + ) diff --git a/src/orcapod/core/operators/merge_join.py b/src/orcapod/core/operators/merge_join.py new file mode 100644 index 00000000..8edcaec0 --- /dev/null +++ b/src/orcapod/core/operators/merge_join.py @@ -0,0 +1,296 @@ +from typing import TYPE_CHECKING, Any + +from orcapod.core.operators.base import BinaryOperator +from orcapod.core.streams import ArrowTableStream +from orcapod.errors import InputValidationError +from orcapod.protocols.core_protocols import StreamProtocol +from orcapod.system_constants import constants +from orcapod.types import ColumnConfig, Schema +from orcapod.utils import arrow_data_utils, schema_utils +from orcapod.utils.lazy_module import LazyModule + +if TYPE_CHECKING: + import polars as pl + import pyarrow as pa +else: + pa = LazyModule("pyarrow") + pl = LazyModule("polars") + + +class MergeJoin(BinaryOperator): + """ + Binary operator that joins two streams, merging colliding packet columns + into sorted lists. + + For packet columns that exist in both streams: + - Values are combined into a list and sorted independently per column. + - Corresponding source columns are reordered to match the sort order of + their packet column. + + For non-colliding columns, values are kept as scalars (same as regular Join). + + Tag columns use inner join on shared tags, with union of tag schemas. + + MergeJoin is commutative: MergeJoin(A, B) produces the same result as + MergeJoin(B, A), achieved by sorting merged values and system tag values. + """ + + @property + def kernel_id(self) -> tuple[str, ...]: + return (f"{self.__class__.__name__}",) + + def is_commutative(self) -> bool: + return True + + def validate_binary_inputs( + self, left_stream: StreamProtocol, right_stream: StreamProtocol + ) -> None: + _, left_packet_schema = left_stream.output_schema() + _, right_packet_schema = right_stream.output_schema() + + # Colliding packet columns must have identical types since they are + # merged into list[T] — both sides must contribute the same T. + colliding_keys = set(left_packet_schema.keys()) & set( + right_packet_schema.keys() + ) + for key in colliding_keys: + left_type = left_packet_schema[key] + right_type = right_packet_schema[key] + if left_type != right_type: + raise InputValidationError( + f"Colliding packet column '{key}' has incompatible types: " + f"{left_type} (left) vs {right_type} (right). " + f"MergeJoin requires colliding columns to have identical types." + ) + + try: + self.binary_output_schema(left_stream, right_stream) + except InputValidationError: + raise + except Exception as e: + raise InputValidationError( + f"Input streams are not compatible for merge join: {e}" + ) from e + + def binary_output_schema( + self, + left_stream: StreamProtocol, + right_stream: StreamProtocol, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: + columns_config = ColumnConfig.handle_config(columns, all_info=all_info) + + # Always get input schemas WITHOUT system tags for the base computation. + # System tags are computed separately because the join renames them. + left_tag_schema, left_packet_schema = left_stream.output_schema() + right_tag_schema, right_packet_schema = right_stream.output_schema() + + # Tag schema: union of both tag schemas + tag_schema = schema_utils.union_schemas(left_tag_schema, right_tag_schema) + + # Packet schema: colliding columns become list[T], non-colliding stay scalar + colliding_schema = schema_utils.intersection_schemas( + left_packet_schema, right_packet_schema + ) + + merged_packet_schema = {} + all_packet_keys = set(left_packet_schema.keys()) | set( + right_packet_schema.keys() + ) + for key in all_packet_keys: + if key in colliding_schema: + merged_packet_schema[key] = list[colliding_schema[key]] + elif key in left_packet_schema: + merged_packet_schema[key] = left_packet_schema[key] + else: + merged_packet_schema[key] = right_packet_schema[key] + + # Add system tag columns if requested + if columns_config.system_tags: + system_tag_schema = self._predict_system_tag_schema( + left_stream, right_stream + ) + tag_schema = schema_utils.union_schemas(tag_schema, system_tag_schema) + + return tag_schema, Schema(merged_packet_schema) + + def _canonical_order( + self, left_stream: StreamProtocol, right_stream: StreamProtocol + ) -> list[tuple[StreamProtocol, int]]: + """ + Determine canonical ordering of the two input streams by stable-sorting + on pipeline_hash. Returns list of (stream, original_index) tuples in + canonical order. + """ + streams_with_idx = [(left_stream, 0), (right_stream, 1)] + # Python's sorted is stable, so equal pipeline_hashes preserve input order + return sorted(streams_with_idx, key=lambda s: s[0].pipeline_hash().to_hex()) + + def _predict_system_tag_schema( + self, left_stream: StreamProtocol, right_stream: StreamProtocol + ) -> Schema: + """Predict the system tag columns that the join would produce. + + Each input stream's existing system tag columns get renamed by + appending ::{pipeline_hash}:{canonical_position}. This method + computes those output column names without performing the join. + """ + n_char = self.orcapod_config.system_tag_hash_n_char + canonical = self._canonical_order(left_stream, right_stream) + + system_tag_fields: dict[str, type] = {} + for stream, orig_idx in canonical: + canon_pos = canonical.index((stream, orig_idx)) + stream_tag_schema, _ = stream.output_schema(columns={"system_tags": True}) + for col_name in stream_tag_schema: + if col_name.startswith(constants.SYSTEM_TAG_PREFIX): + new_name = ( + f"{col_name}{constants.BLOCK_SEPARATOR}" + f"{stream.pipeline_hash().to_hex(n_char)}:{canon_pos}" + ) + system_tag_fields[new_name] = str + return Schema(system_tag_fields) + + def binary_static_process( + self, left_stream: StreamProtocol, right_stream: StreamProtocol + ) -> StreamProtocol: + n_char = self.orcapod_config.system_tag_hash_n_char + + # Determine canonical ordering for system tag positions + canonical = self._canonical_order(left_stream, right_stream) + + # Get tables with source + system_tags, append system tag blocks + tables = {} + for stream, orig_idx in canonical: + canon_pos = canonical.index((stream, orig_idx)) + table = stream.as_table(columns={"source": True, "system_tags": True}) + table = arrow_data_utils.append_to_system_tags( + table, f"{stream.pipeline_hash().to_hex(n_char)}:{canon_pos}" + ) + tables[orig_idx] = table + + left_table = tables[0] + right_table = tables[1] + + # Determine shared tag keys for inner join + left_tag_keys, left_packet_keys = left_stream.keys() + right_tag_keys, right_packet_keys = right_stream.keys() + shared_tag_keys = set(left_tag_keys) & set(right_tag_keys) + + # Find colliding packet columns + colliding_keys = set(left_packet_keys) & set(right_packet_keys) + + # Perform inner join via Polars on shared tag keys + # Use a common key trick to ensure cartesian product if no shared tags + COMMON_JOIN_KEY = "_common" + left_table = left_table.add_column( + 0, COMMON_JOIN_KEY, pa.array([0] * len(left_table)) + ) + right_table = right_table.add_column( + 0, COMMON_JOIN_KEY, pa.array([0] * len(right_table)) + ) + + join_keys = list(shared_tag_keys | {COMMON_JOIN_KEY}) + + # Track which columns Polars will auto-suffix with _right + # (right-table columns that collide with left, excluding join keys) + left_col_set = set(left_table.column_names) - {COMMON_JOIN_KEY} + right_col_set = set(right_table.column_names) - {COMMON_JOIN_KEY} + join_key_set = set(join_keys) - {COMMON_JOIN_KEY} + polars_suffixed_bases = (right_col_set & left_col_set) - join_key_set + + joined = ( + pl.DataFrame(left_table) + .join(pl.DataFrame(right_table), on=join_keys, how="inner") + .to_arrow() + ) + joined = joined.drop(COMMON_JOIN_KEY) + + # Process colliding packet columns: merge into sorted lists + for col in colliding_keys: + left_col_name = col + right_col_name = f"{col}_right" + + left_source_col = f"{constants.SOURCE_PREFIX}{col}" + right_source_col = f"{left_source_col}_right" + + if right_col_name not in joined.column_names: + continue + + left_vals = joined.column(left_col_name).to_pylist() + right_vals = joined.column(right_col_name).to_pylist() + + # Also handle corresponding source columns + has_source = ( + left_source_col in joined.column_names + and right_source_col in joined.column_names + ) + if has_source: + left_sources = joined.column(left_source_col).to_pylist() + right_sources = joined.column(right_source_col).to_pylist() + + merged_vals = [] + merged_sources = [] if has_source else None + for i in range(len(left_vals)): + lv, rv = left_vals[i], right_vals[i] + if has_source: + ls, rs = left_sources[i], right_sources[i] + # Sort by packet value, carry source along + pairs = sorted(zip([lv, rv], [ls, rs]), key=lambda p: p[0]) + merged_vals.append([p[0] for p in pairs]) + merged_sources.append([p[1] for p in pairs]) + else: + merged_vals.append(sorted([lv, rv])) + + # Replace the left column with merged list, drop right column + col_idx = joined.column_names.index(left_col_name) + joined = joined.drop(left_col_name) + joined = joined.drop(right_col_name) + + merged_array = pa.array(merged_vals) + joined = joined.add_column(col_idx, left_col_name, merged_array) + + if has_source: + source_idx = joined.column_names.index(left_source_col) + joined = joined.drop(left_source_col) + joined = joined.drop(right_source_col) + source_array = pa.array(merged_sources) + joined = joined.add_column(source_idx, left_source_col, source_array) + + # Handle remaining Polars-generated _right suffixed columns + # (only from columns we know Polars auto-suffixed, not original names) + for base_name in polars_suffixed_bases: + suffixed_name = f"{base_name}_right" + if suffixed_name not in joined.column_names: + continue # Already handled during colliding column processing + if base_name not in joined.column_names: + # Left version was removed, rename right to original + idx = joined.column_names.index(suffixed_name) + col_data = joined.column(suffixed_name) + joined = joined.drop(suffixed_name) + joined = joined.add_column(idx, base_name, col_data) + else: + # Both versions exist, drop the right one + joined = joined.drop(suffixed_name) + + # Sort system tag values for same-pipeline-hash streams to ensure commutativity + joined = arrow_data_utils.sort_system_tag_values(joined) + + # Reorder: tag columns first, then packet columns + all_tag_keys = set(left_tag_keys) | set(right_tag_keys) + tag_cols = [c for c in joined.column_names if c in all_tag_keys] + other_cols = [c for c in joined.column_names if c not in all_tag_keys] + joined = joined.select(tag_cols + other_cols) + + return ArrowTableStream( + joined, + tag_columns=tuple(all_tag_keys), + ) + + def identity_structure(self) -> Any: + return self.__class__.__name__ + + def __repr__(self) -> str: + return "MergeJoin()" diff --git a/src/orcapod/core/operators/semijoin.py b/src/orcapod/core/operators/semijoin.py index 6cdff4cc..0507a900 100644 --- a/src/orcapod/core/operators/semijoin.py +++ b/src/orcapod/core/operators/semijoin.py @@ -1,11 +1,14 @@ -from orcapod.protocols import core_protocols as cp -from orcapod.core.streams import TableStream -from orcapod.utils import types_utils -from orcapod.types import PythonSchema -from typing import Any, TYPE_CHECKING -from orcapod.utils.lazy_module import LazyModule -from orcapod.errors import InputValidationError +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +from orcapod.channels import ReadableChannel, WritableChannel from orcapod.core.operators.base import BinaryOperator +from orcapod.core.streams import ArrowTableStream +from orcapod.errors import InputValidationError +from orcapod.protocols.core_protocols import PacketProtocol, StreamProtocol, TagProtocol +from orcapod.types import ColumnConfig, Schema +from orcapod.utils import schema_utils +from orcapod.utils.lazy_module import LazyModule if TYPE_CHECKING: import pyarrow as pa @@ -27,57 +30,34 @@ class SemiJoin(BinaryOperator): The output stream preserves the schema of the left stream exactly. """ - @property - def kernel_id(self) -> tuple[str, ...]: - """ - Returns a unique identifier for the kernel. - This is used to identify the kernel in the computational graph. - """ - return (f"{self.__class__.__name__}",) - - def op_identity_structure( - self, - left_stream: cp.Stream | None = None, - right_stream: cp.Stream | None = None, - ) -> Any: - """ - Return a structure that represents the identity of this operator. - Unlike Join, SemiJoin depends on the order of streams (left vs right). - """ - id_struct = (self.__class__.__name__,) - if left_stream is not None and right_stream is not None: - # Order matters for semi-join: (left_stream, right_stream) - id_struct += (left_stream, right_stream) - return id_struct - - def op_forward(self, left_stream: cp.Stream, right_stream: cp.Stream) -> cp.Stream: + def binary_static_process( + self, left_stream: StreamProtocol, right_stream: StreamProtocol + ) -> StreamProtocol: """ Performs a semi-join between left and right streams. Returns entries from left stream that have matching entries in right stream. """ - left_tag_typespec, left_packet_typespec = left_stream.types() - right_tag_typespec, right_packet_typespec = right_stream.types() + left_tag_schema, left_packet_schema = left_stream.output_schema() + right_tag_schema, right_packet_schema = right_stream.output_schema() # Find overlapping columns across all columns (tags + packets) - left_all_typespec = types_utils.union_typespecs( - left_tag_typespec, left_packet_typespec + left_all_schema = schema_utils.union_schemas( + left_tag_schema, left_packet_schema ) - right_all_typespec = types_utils.union_typespecs( - right_tag_typespec, right_packet_typespec + right_all_schema = schema_utils.union_schemas( + right_tag_schema, right_packet_schema ) common_keys = tuple( - types_utils.intersection_typespecs( - left_all_typespec, right_all_typespec - ).keys() + schema_utils.intersection_schemas(left_all_schema, right_all_schema).keys() ) # If no overlapping columns, return the left stream unmodified if not common_keys: return left_stream - # include source info for left stream - left_table = left_stream.as_table(include_source=True) + # include source info and system tags for left stream + left_table = left_stream.as_table(columns={"source": True, "system_tags": True}) # Get the right table for matching right_table = right_stream.as_table() @@ -89,52 +69,144 @@ def op_forward(self, left_stream: cp.Stream, right_stream: cp.Stream) -> cp.Stre join_type="left semi", ) - return TableStream( + return ArrowTableStream( semi_joined_table, - tag_columns=tuple(left_tag_typespec.keys()), - source=self, - upstreams=(left_stream, right_stream), + tag_columns=tuple(left_tag_schema.keys()), ) - def op_output_types( + def binary_output_schema( self, - left_stream: cp.Stream, - right_stream: cp.Stream, - include_system_tags: bool = False, - ) -> tuple[PythonSchema, PythonSchema]: + left_stream: StreamProtocol, + right_stream: StreamProtocol, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: """ Returns the output types for the semi-join operation. The output preserves the exact schema of the left stream. """ # Semi-join preserves the left stream's schema exactly - return left_stream.types(include_system_tags=include_system_tags) + return left_stream.output_schema(columns=columns, all_info=all_info) - def op_validate_inputs( - self, left_stream: cp.Stream, right_stream: cp.Stream + def validate_binary_inputs( + self, left_stream: StreamProtocol, right_stream: StreamProtocol ) -> None: """ Validates that the input streams are compatible for semi-join. Checks that overlapping columns have compatible types. + + Stores the common keys so that ``async_execute`` can use them + to determine the correct empty-right behavior without data. """ try: - left_tag_typespec, left_packet_typespec = left_stream.types() - right_tag_typespec, right_packet_typespec = right_stream.types() + left_tag_schema, left_packet_schema = left_stream.output_schema() + right_tag_schema, right_packet_schema = right_stream.output_schema() # Check that overlapping columns have compatible types across all columns - left_all_typespec = types_utils.union_typespecs( - left_tag_typespec, left_packet_typespec + left_all_schema = schema_utils.union_schemas( + left_tag_schema, left_packet_schema ) - right_all_typespec = types_utils.union_typespecs( - right_tag_typespec, right_packet_typespec + right_all_schema = schema_utils.union_schemas( + right_tag_schema, right_packet_schema ) - # intersection_typespecs will raise an error if types are incompatible - types_utils.intersection_typespecs(left_all_typespec, right_all_typespec) + # intersection_schemas will raise an error if types are incompatible + common = schema_utils.intersection_schemas( + left_all_schema, right_all_schema + ) + self._validated_common_keys: tuple[str, ...] = tuple(common.keys()) except Exception as e: raise InputValidationError( f"Input streams are not compatible for semi-join: {e}" ) from e - def __repr__(self) -> str: - return "SemiJoin()" + def is_commutative(self) -> bool: + return False + + def _common_keys_from_schema(self) -> tuple[str, ...]: + """Return the common keys computed during input validation. + + Falls back to an empty tuple if validation hasn't been called + (shouldn't happen in normal pipeline execution). + """ + return getattr(self, "_validated_common_keys", ()) + + async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + **kwargs: Any, + ) -> None: + """Build-probe: collect right input, then stream left through a hash lookup. + + Phase 1 — Build: collect all rows from the right (filter) channel and + index them by the common-key values. + Phase 2 — Probe: stream left rows one at a time; for each row whose + common-key values appear in the right-side index, emit immediately. + + Falls back to barrier mode when the right input is empty (schema + cannot be inferred from data) or when there are no common keys. + """ + try: + left_ch, right_ch = inputs[0], inputs[1] + + # Phase 1: Build right-side lookup + right_rows = await right_ch.collect() + + if not right_rows: + # Empty right: determine common keys from the validated + # input schemas (set during __init__) to match sync semantics. + # Common keys exist → empty result; no common keys → pass left through. + common = self._common_keys_from_schema() + if common: + # Drain left channel (discard) — result is empty + await left_ch.collect() + return + # No common keys — pass all left rows through unchanged + async for tag, packet in left_ch: + await output.send((tag, packet)) + return + + # Determine right-side keys from first row + right_tag_keys = set(right_rows[0][0].keys()) + right_pkt_keys = set(right_rows[0][1].keys()) + right_all_keys = right_tag_keys | right_pkt_keys + + # Phase 2: Probe — stream left rows + common_keys: tuple[str, ...] | None = None + right_lookup: set[tuple] | None = None + + async for tag, packet in left_ch: + if common_keys is None: + # First left row — determine common keys and build index + left_tag_keys = set(tag.keys()) + left_pkt_keys = set(packet.keys()) + left_all_keys = left_tag_keys | left_pkt_keys + common_keys = tuple(sorted(left_all_keys & right_all_keys)) + + if not common_keys: + # No common keys — pass all left rows through + await output.send((tag, packet)) + async for t, p in left_ch: + await output.send((t, p)) + return + + # Build right-side lookup + right_lookup = set() + for rt, rp in right_rows: + rd = rt.as_dict() + rd.update(rp.as_dict()) + right_lookup.add(tuple(rd[k] for k in common_keys)) + + # Probe + ld = tag.as_dict() + ld.update(packet.as_dict()) + if tuple(ld[k] for k in common_keys) in right_lookup: # type: ignore[arg-type] + await output.send((tag, packet)) + finally: + await output.close() + + def identity_structure(self) -> Any: + return self.__class__.__name__ diff --git a/src/orcapod/core/operators/static_output_pod.py b/src/orcapod/core/operators/static_output_pod.py new file mode 100644 index 00000000..c3f32ce7 --- /dev/null +++ b/src/orcapod/core/operators/static_output_pod.py @@ -0,0 +1,392 @@ +from __future__ import annotations + +import asyncio +import logging +from abc import abstractmethod +from collections.abc import Collection, Iterator, Sequence +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any, cast + +from orcapod.channels import ReadableChannel, WritableChannel +from orcapod.config import Config +from orcapod.contexts import DataContext +from orcapod.core.base import TraceableBase +from orcapod.core.streams.base import StreamBase +from orcapod.core.tracker import DEFAULT_TRACKER_MANAGER +from orcapod.protocols.core_protocols import ( + ArgumentGroup, + PacketProtocol, + PodProtocol, + StreamProtocol, + TagProtocol, + TrackerManagerProtocol, +) +from orcapod.types import ColumnConfig, ContentHash, Schema +from orcapod.utils.lazy_module import LazyModule + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + import pyarrow as pa +else: + pa = LazyModule("pyarrow") + + +class StaticOutputOperatorPod(TraceableBase): + """Abstract base class for pods whose core logic yields a static output stream. + + The static output stream is wrapped in ``DynamicPodStream`` which re-executes + the pod as necessary to keep the output up-to-date. Pod invocations are + tracked by the tracker manager. + """ + + def __init__( + self, tracker_manager: TrackerManagerProtocol | None = None, **kwargs + ) -> None: + self.tracker_manager = tracker_manager or DEFAULT_TRACKER_MANAGER + super().__init__(**kwargs) + + def pipeline_identity_structure(self) -> Any: + """Return the pipeline identity, which defaults to content identity for operators.""" + return self.identity_structure() + + @property + def uri(self) -> tuple[str, ...]: + """Return a unique resource identifier for the pod.""" + return ( + f"{self.__class__.__name__}", + self.content_hash().to_hex(), + ) + + @abstractmethod + def validate_inputs(self, *streams: StreamProtocol) -> None: + """Validate input streams, raising exceptions if invalid. + + Args: + *streams: Input streams to validate. + + Raises: + PodInputValidationError: If inputs are invalid. + """ + ... + + @abstractmethod + def argument_symmetry(self, streams: Collection[StreamProtocol]) -> ArgumentGroup: + """Describe symmetry/ordering constraints on input arguments. + + Returns a structure encoding which arguments can be reordered: + - ``frozenset``: Arguments commute (order doesn't matter). + - ``tuple``: Arguments have fixed positions. + - Nesting expresses partial symmetry. + + Examples: + Full symmetry (Join):: + + return frozenset([a, b, c]) + + No symmetry (Concatenate):: + + return (a, b, c) + + Partial symmetry:: + + return (frozenset([a, b]), c) + """ + ... + + @abstractmethod + def output_schema( + self, + *streams: StreamProtocol, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: + """Determine output (tag, packet) schemas without triggering computation. + + Args: + *streams: Input streams to analyze. + columns: Column configuration for included column groups. + all_info: If True, include all info columns. + + Returns: + A ``(tag_schema, packet_schema)`` tuple. + + Raises: + ValidationError: If input types are incompatible. + """ + ... + + @abstractmethod + def static_process(self, *streams: StreamProtocol) -> StreamProtocol: + """Execute the pod on the input streams and return a static output stream. + + Args: + *streams: Input streams to process. + + Returns: + The resulting output stream. + """ + ... + + def process( + self, *streams: StreamProtocol, label: str | None = None + ) -> DynamicPodStream: + """Invoke the pod on input streams and return a ``DynamicPodStream``. + + Args: + *streams: Input streams to process. + label: Optional label for tracking. + + Returns: + A ``DynamicPodStream`` wrapping the computation. + """ + logger.debug(f"Invoking kernel {self} on streams: {streams}") + + # perform input stream validation + self.validate_inputs(*streams) + self.tracker_manager.record_operator_pod_invocation( + self, upstreams=streams, label=label + ) + output_stream = DynamicPodStream( + pod=self, + upstreams=streams, + label=label, + ) + return output_stream + + def to_config(self) -> dict[str, Any]: + """Serialize this operator to a JSON-compatible config dict. + + Subclasses with constructor parameters should override this to include + their specific config in the ``"config"`` key. + + Returns: + A dict with ``class_name``, ``module_path``, and ``config`` keys. + """ + return { + "class_name": self.__class__.__name__, + "module_path": self.__class__.__module__, + "config": {}, + } + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "StaticOutputOperatorPod": + """Reconstruct an operator from a config dict. + + Args: + config: Dict as returned by ``to_config()``. + + Returns: + A new instance of this operator class. + """ + return cls(**config.get("config", {})) + + def __call__(self, *streams: StreamProtocol, **kwargs) -> DynamicPodStream: + """Convenience alias for ``process``.""" + logger.debug(f"Invoking pod {self} on streams through __call__: {streams}") + # perform input stream validation + return self.process(*streams, **kwargs) + + # ------------------------------------------------------------------ + # Async channel execution (default barrier mode) + # ------------------------------------------------------------------ + + @staticmethod + def _materialize_to_stream( + rows: Sequence[tuple[TagProtocol, PacketProtocol]], + ) -> StreamProtocol: + """Materialize a list of (Tag, Packet) pairs into an ArrowTableStream. + + Used by the barrier-mode ``async_execute`` to convert collected + channel items back into a stream suitable for ``static_process``. + """ + from orcapod.core.streams.arrow_table_stream import ArrowTableStream + from orcapod.utils import arrow_utils + + if not rows: + raise ValueError("Cannot materialize an empty list of rows into a stream") + + tag_tables = [] + packet_tables = [] + + for tag, packet in rows: + tag_tables.append(tag.as_table(columns={"system_tags": True})) + packet_tables.append(packet.as_table(columns={"source": True})) + + combined_tags = pa.concat_tables(tag_tables) + combined_packets = pa.concat_tables(packet_tables) + + user_tag_keys = tuple(rows[0][0].keys()) + + # Preserve actual source_info provenance from the first row + # (all rows share the same packet columns and source tokens). + source_info = rows[0][1].source_info() + + full_table = arrow_utils.hstack_tables(combined_tags, combined_packets) + + return ArrowTableStream( + full_table, + tag_columns=user_tag_keys, + source_info=source_info, + ) + + async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + *, + input_pipeline_hashes: Sequence[ContentHash] | None = None, + ) -> None: + """Default barrier-mode async execution. + + Collects all inputs, runs ``static_process``, emits results. + Subclasses override for streaming or incremental strategies. + + Args: + inputs: Readable channels, one per upstream node. + output: Writable channel for downstream consumption. + input_pipeline_hashes: Pipeline hash for each input stream, + positionally matching ``inputs``. Multi-input operators + (e.g. Join) use these to compute canonical system-tag + column names. Ignored by single-input operators. + """ + try: + all_rows = await asyncio.gather(*(ch.collect() for ch in inputs)) + streams = [self._materialize_to_stream(rows) for rows in all_rows] + result = self.static_process(*streams) + for tag, packet in result.iter_packets(): + await output.send((tag, packet)) + finally: + await output.close() + + +class DynamicPodStream(StreamBase): + """Recomputable stream wrapping a ``StaticOutputOperatorPod`` invocation.""" + + def __init__( + self, + pod: StaticOutputOperatorPod, + upstreams: tuple[StreamProtocol, ...] = (), + label: str | None = None, + data_context: DataContext | None = None, + config: Config | None = None, + ) -> None: + self._pod = pod + self._upstreams = upstreams + + super().__init__(label=label, data_context=data_context, config=config) + self._set_modified_time(None) + self._cached_time: datetime | None = None + self._cached_stream: StreamProtocol | None = None + + def identity_structure(self) -> Any: + structure = (self._pod,) + if self._upstreams: + structure += (self._pod.argument_symmetry(self._upstreams),) + return structure + + def pipeline_identity_structure(self) -> Any: + structure = (self._pod,) + if self._upstreams: + structure += (self._pod.argument_symmetry(self._upstreams),) + return structure + + @property + def producer(self) -> PodProtocol: + return self._pod + + @property + def upstreams(self) -> tuple[StreamProtocol, ...]: + return self._upstreams + + def clear_cache(self) -> None: + """Clear the cached stream, forcing recomputation on next access.""" + self._cached_stream = None + self._cached_time = None + + def keys( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[tuple[str, ...], tuple[str, ...]]: + """Return the (tag_keys, packet_keys) column names for this stream.""" + tag_schema, packet_schema = self._pod.output_schema( + *self.upstreams, + columns=columns, + all_info=all_info, + ) + return tuple(tag_schema.keys()), tuple(packet_schema.keys()) + + def output_schema( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: + """Return the (tag_schema, packet_schema) for this stream.""" + return self._pod.output_schema( + *self.upstreams, + columns=columns, + all_info=all_info, + ) + + @property + def last_modified(self) -> datetime | None: + """Returns the last modified time of the stream.""" + self._update_cache_status() + return self._cached_time + + def _update_cache_status(self) -> None: + if self._cached_time is None: + return + + upstream_times = [stream.last_modified for stream in self.upstreams] + upstream_times.append(self._pod.last_modified) + + if any(t is None for t in upstream_times): + self._cached_results = None + self._cached_time = None + return + + # Get the maximum upstream time + max_upstream_time = max(cast(list[datetime], upstream_times)) + + # Invalidate cache if upstream is newer and update the cache time + if max_upstream_time > self._cached_time: + self._cached_results = None + self._cached_time = max_upstream_time + + def run(self, *args: Any, **kwargs: Any) -> None: + self._update_cache_status() + + # recompute if cache is invalid + if self._cached_time is None or self._cached_stream is None: + self._cached_stream = self._pod.static_process( + *self.upstreams, + ) + self._cached_time = datetime.now(timezone.utc) + + def as_table( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> "pa.Table": + self.run() + assert self._cached_stream is not None, ( + "StreamProtocol has not been updated or is empty." + ) + return self._cached_stream.as_table(columns=columns, all_info=all_info) + + def iter_packets( + self, + ) -> Iterator[tuple[TagProtocol, PacketProtocol]]: + self.run() + assert self._cached_stream is not None, ( + "StreamProtocol has not been updated or is empty." + ) + return self._cached_stream.iter_packets() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(kernel={self.producer}, upstreams={self.upstreams})" diff --git a/src/orcapod/core/packet_function.py b/src/orcapod/core/packet_function.py new file mode 100644 index 00000000..db874de2 --- /dev/null +++ b/src/orcapod/core/packet_function.py @@ -0,0 +1,946 @@ +from __future__ import annotations + +import inspect +import logging +import re +import sys +from abc import abstractmethod +from collections.abc import Callable, Iterable, Sequence +from datetime import datetime, timezone +import typing +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeVar + +from uuid_utils import uuid7 + +from orcapod.config import Config +from orcapod.contexts import DataContext +from orcapod.core.base import TraceableBase +from orcapod.core.datagrams import Packet +from orcapod.hashing.hash_utils import ( + get_function_components, + get_function_signature, +) +from orcapod.core.result_cache import ResultCache +from orcapod.protocols.core_protocols import PacketFunctionProtocol, PacketProtocol +from orcapod.protocols.core_protocols.executor import ( + PacketFunctionExecutorProtocol, + PythonFunctionExecutorProtocol, +) +from orcapod.protocols.database_protocols import ArrowDatabaseProtocol +from orcapod.system_constants import constants +from orcapod.types import DataValue, Schema, SchemaLike +from orcapod.utils import schema_utils +from orcapod.utils.git_utils import get_git_info_for_python_object +from orcapod.utils.lazy_module import LazyModule + +if TYPE_CHECKING: + import pyarrow as pa + import pyarrow.compute as pc + + from orcapod.pipeline.logging_capture import CapturedLogs +else: + pa = LazyModule("pyarrow") + pc = LazyModule("pyarrow.compute") + +logger = logging.getLogger(__name__) + +error_handling_options = Literal["raise", "ignore", "warn"] + +# --------------------------------------------------------------------------- +# Shared executor for running async functions synchronously from within +# an event loop (see PythonPacketFunction._call_async_function_sync). +# --------------------------------------------------------------------------- + +_sync_executor = None + + +def _get_sync_executor(): + """Return a shared single-thread executor for sync fallback of async fns.""" + global _sync_executor + if _sync_executor is None: + from concurrent.futures import ThreadPoolExecutor + + _sync_executor = ThreadPoolExecutor(1) + return _sync_executor + + +def parse_function_outputs( + output_keys: Sequence[str], values: Any +) -> dict[str, DataValue]: + """ + Map raw function return values to a keyed output dict. + + Rules: + - ``output_keys = []``: return value is ignored; empty dict returned. + - ``output_keys = ["result"]``: any value (including iterables) is stored as-is + under the single key. + - ``output_keys = ["a", "b", ...]``: ``values`` must be iterable and its length + must match the number of keys. + + Args: + output_keys: Ordered list of output key names. + values: Raw return value from the function. + + Returns: + Dict mapping each output key to its corresponding value. + + Raises: + ValueError: If ``values`` is not iterable when multiple keys are given, or if + the number of values does not match the number of keys. + """ + if len(output_keys) == 0: + output_values: list[Any] = [] + elif len(output_keys) == 1: + output_values = [values] + elif isinstance(values, Iterable): + output_values = list(values) + else: + raise ValueError( + "Values returned by function must be sequence-like if multiple output keys are specified" + ) + + if len(output_values) != len(output_keys): + raise ValueError( + f"Number of output keys {len(output_keys)}:{output_keys} does not match " + f"number of values returned by function {len(output_values)}" + ) + + return dict(zip(output_keys, output_values)) + + +E = TypeVar("E", bound=PacketFunctionExecutorProtocol) + + +class PacketFunctionBase(TraceableBase, Generic[E]): + """Abstract base class for PacketFunctionProtocol. + + Type-parameterized with the executor protocol ``E``. Concrete + subclasses that bind ``E`` (e.g. ``class Foo(PacketFunctionBase[SomeProto])``) + get automatic ``isinstance`` validation in ``set_executor`` at class + definition time via ``__init_subclass__``. + """ + + _resolved_executor_protocol: ClassVar[type | None] = None + + def __init_subclass__(cls, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + for base in getattr(cls, "__orig_bases__", ()): + origin = typing.get_origin(base) + if origin is PacketFunctionBase: + args = typing.get_args(base) + if args and not isinstance(args[0], TypeVar): + cls._resolved_executor_protocol = args[0] + return + + def __init__( + self, + version: str = "v0.0", + label: str | None = None, + data_context: str | DataContext | None = None, + config: Config | None = None, + executor: PacketFunctionExecutorProtocol | None = None, + ): + super().__init__(label=label, data_context=data_context, config=config) + self._active = True + self._version = version + self._executor: E | None = None + + # Parse version string to extract major and minor versions + # 0.5.2 -> 0 and 5.2, 1.3rc -> 1 and 3rc + match = re.match(r"\D*(\d+)\.(.*)", version) + if match: + self._major_version = int(match.group(1)) + self._minor_version = match.group(2) + else: + raise ValueError( + f"Version string {version} does not contain a valid version number" + ) + + self._output_packet_schema_hash = None + + # Validate and set via the property setter. This works because + # concrete subclasses define packet_function_type_id as a simple + # constant property that does not depend on instance state set + # *after* super().__init__(). + if executor is not None: + self.executor = executor + + def computed_label(self) -> str | None: + """Return the canonical function name as the label if no explicit label is given.""" + return self.canonical_function_name + + @property + def output_packet_schema_hash(self) -> str: + """Return the hash of the output packet schema as a string. + + The hash is computed lazily on first access and cached for subsequent calls. + + Returns: + The hash string of the output packet schema. + """ + if self._output_packet_schema_hash is None: + self._output_packet_schema_hash = ( + self.data_context.semantic_hasher.hash_object( + self.output_packet_schema + ).to_string() + ) + return self._output_packet_schema_hash + + @property + def uri(self) -> tuple[str, ...]: + return ( + self.canonical_function_name, + self.output_packet_schema_hash, + f"v{self.major_version}", + self.packet_function_type_id, + ) + + def identity_structure(self) -> Any: + return self.uri + + def pipeline_identity_structure(self) -> Any: + return self.uri + + @property + def major_version(self) -> int: + return self._major_version + + @property + def minor_version_string(self) -> str: + return self._minor_version + + @property + @abstractmethod + def packet_function_type_id(self) -> str: + """Unique function type identifier (e.g. ``"python.function.v1"``).""" + ... + + @property + @abstractmethod + def canonical_function_name(self) -> str: + """Human-readable function identifier.""" + ... + + @property + @abstractmethod + def input_packet_schema(self) -> Schema: + """Schema describing the input packets this function accepts.""" + ... + + @property + @abstractmethod + def output_packet_schema(self) -> Schema: + """Schema describing the output packets this function produces.""" + ... + + @abstractmethod + def get_function_variation_data(self) -> dict[str, Any]: + """Raw data defining function variation""" + ... + + @abstractmethod + def get_execution_data(self) -> dict[str, Any]: + """Raw data defining execution context""" + ... + + # ==================== Executor ==================== + + @property + def executor(self) -> E | None: + """Return the executor used to run this packet function, or ``None`` for direct execution.""" + return self._executor + + @executor.setter + def executor(self, executor: E | None) -> None: + """Set or clear the executor for this packet function. + + Delegates to ``set_executor`` for validation. + """ + self.set_executor(executor) + + def set_executor(self, executor: PacketFunctionExecutorProtocol | None) -> None: + """Set or clear the executor, validating type compatibility. + + Performs two checks: + 1. The executor supports this function's ``packet_function_type_id``. + 2. If the subclass bound ``E`` via ``Generic[E]``, the executor is an + instance of the resolved protocol (checked once at assignment time, + not in the hot path). + + Raises: + TypeError: If *executor* fails either compatibility check. + """ + if executor is not None: + if not executor.supports(self.packet_function_type_id): + raise TypeError( + f"Executor {executor.executor_type_id!r} does not support " + f"packet function type {self.packet_function_type_id!r}. " + f"Supported types: {executor.supported_function_type_ids()}" + ) + proto = getattr(type(self), "_resolved_executor_protocol", None) + if proto is not None and not isinstance(executor, proto): + raise TypeError( + f"{type(self).__name__} requires an executor implementing " + f"{proto.__name__}, got {type(executor).__name__}" + ) + self._executor = executor + + # ==================== Execution ==================== + + def call( + self, packet: PacketProtocol + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + """Process a single packet, routing through the executor if one is set. + + Subclasses should override ``direct_call`` instead of this method. + + Returns: + A ``(output_packet, captured_logs)`` tuple. + """ + if self._executor is not None: + return self._executor.execute(self, packet) + return self.direct_call(packet) + + async def async_call( + self, packet: PacketProtocol + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + """Asynchronously process a single packet, routing through the executor if set. + + Subclasses should override ``direct_async_call`` instead of this method. + + Returns: + A ``(output_packet, captured_logs)`` tuple. + """ + if self._executor is not None: + return await self._executor.async_execute(self, packet) + return await self.direct_async_call(packet) + + @abstractmethod + def direct_call( + self, packet: PacketProtocol + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + """Execute the function's native computation on *packet*. + + This is the method executors invoke. It bypasses executor routing + and runs the computation directly. On user-function failure the + exception is caught internally and ``(None, captured_failure)`` + is returned — no re-raise. Subclasses must implement this. + """ + ... + + @abstractmethod + async def direct_async_call( + self, packet: PacketProtocol + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + """Asynchronous counterpart of ``direct_call``.""" + ... + + +class PythonPacketFunction(PacketFunctionBase[PythonFunctionExecutorProtocol]): + @property + def packet_function_type_id(self) -> str: + """Unique function type identifier.""" + return "python.function.v0" + + def __init__( + self, + function: Callable[..., Any], + output_keys: str | Sequence[str] | None = None, + function_name: str | None = None, + version: str = "v0.0", + input_schema: SchemaLike | None = None, + output_schema: SchemaLike | Sequence[type] | None = None, + label: str | None = None, + **kwargs, + ) -> None: + self._function = function + self._is_async = inspect.iscoroutinefunction(function) + + # Reject functions with variadic parameters -- PythonPacketFunction maps + # packet keys to named parameters, so the full parameter set must be fixed. + _sig = inspect.signature(function) + _variadic = [ + name + for name, param in _sig.parameters.items() + if param.kind + in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ) + ] + if _variadic: + raise ValueError( + f"PythonPacketFunction does not support functions with variadic " + f"parameters (*args / **kwargs). " + f"Offending parameters: {_variadic!r}." + ) + + if output_keys is None: + output_keys = [] + if isinstance(output_keys, str): + output_keys = [output_keys] + self._output_keys = output_keys + if function_name is None: + if hasattr(self._function, "__name__"): + function_name = getattr(self._function, "__name__") + else: + raise ValueError( + "function_name must be provided if function has no __name__" + ) + + assert function_name is not None + self._function_name = function_name + + super().__init__(label=label, version=version, **kwargs) + + # extract input and output schema from the function signature + self._input_schema, self._output_schema = schema_utils.extract_function_schemas( + self._function, + self._output_keys, + input_typespec=input_schema, + output_typespec=output_schema, + ) + + # get git info for the function + # TODO: turn this into optional addition + env_info = get_git_info_for_python_object(self._function) + if env_info is None: + git_hash = "unknown" + else: + git_hash = env_info.get("git_commit_hash", "unknown") + if env_info.get("git_repo_status") == "dirty": + git_hash += "-dirty" + self._git_hash = git_hash + + semantic_hasher = self.data_context.semantic_hasher + self._function_signature_hash = semantic_hasher.hash_object( + get_function_signature(function) + ).to_string() + self._function_content_hash = semantic_hasher.hash_object( + get_function_components(self._function) + ).to_string() + self._output_schema_hash = semantic_hasher.hash_object( + self.output_packet_schema + ).to_string() + + @property + def canonical_function_name(self) -> str: + """Human-readable function identifier.""" + return self._function_name + + def get_function_variation_data(self) -> dict[str, Any]: + """Raw data defining function variation - system computes hash""" + return { + "function_name": self._function_name, + "function_signature_hash": self._function_signature_hash, + "function_content_hash": self._function_content_hash, + "git_hash": self._git_hash, + } + + def get_execution_data(self) -> dict[str, Any]: + """Raw data defining execution context - system computes hash""" + python_version_info = sys.version_info + python_version_str = f"{python_version_info.major}.{python_version_info.minor}.{python_version_info.micro}" + return {"python_version": python_version_str, "execution_context": "local"} + + @property + def input_packet_schema(self) -> Schema: + """Schema describing the input packets this function accepts.""" + return self._input_schema + + @property + def output_packet_schema(self) -> Schema: + """Schema describing the output packets this function produces.""" + return self._output_schema + + def is_active(self) -> bool: + """Return whether the function is active (will process packets).""" + return self._active + + def set_active(self, active: bool = True) -> None: + """Set the active state. If False, ``call`` returns None for every packet.""" + self._active = active + + @property + def is_async(self) -> bool: + """Return whether the wrapped function is an async coroutine function.""" + return self._is_async + + def _build_output_packet(self, values: Any) -> PacketProtocol: + """Build an output Packet from raw function return values. + + Args: + values: Raw return value from the wrapped function. + + Returns: + A Packet containing the parsed outputs with source info. + """ + output_data = parse_function_outputs(self._output_keys, values) + + def combine(*components: tuple[str, ...]) -> str: + inner_parsed = [":".join(component) for component in components] + return "::".join(inner_parsed) + + record_id = str(uuid7()) + source_info = {k: combine(self.uri, (record_id,), (k,)) for k in output_data} + + return Packet( + output_data, + source_info=source_info, + record_id=record_id, + python_schema=self.output_packet_schema, + data_context=self.data_context, + ) + + def _call_async_function_sync(self, packet: PacketProtocol) -> Any: + """Run the wrapped async function synchronously. + + Uses ``asyncio.run()`` when no event loop is running. When called + from within a running loop, offloads to a new thread to avoid + nested event loop errors. + + The coroutine is constructed inside the executor thread (not in the + caller thread) to avoid unawaited-coroutine warnings if submission + fails. + + Args: + packet: The input packet whose dict form is passed to the function. + + Returns: + The raw return value of the async function. + """ + import asyncio + + kwargs = packet.as_dict() + fn = self._function + try: + asyncio.get_running_loop() + except RuntimeError: + # No running loop — safe to use asyncio.run() + return asyncio.run(fn(**kwargs)) + else: + # Already in a loop — run in a separate thread with its own loop. + # The lambda ensures the coroutine is created inside the executor + # thread, avoiding unawaited-coroutine warnings on submission failure. + return ( + _get_sync_executor().submit(lambda: asyncio.run(fn(**kwargs))).result() + ) + + def call( + self, packet: PacketProtocol + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + """Process a single packet, routing through the executor if one is set. + + When an executor implementing ``PythonFunctionExecutorProtocol`` is + set, the raw callable and kwargs are handed to ``execute_callable`` + which returns ``(raw_result, CapturedLogs)``. The output packet is + built from ``raw_result``. + """ + from orcapod.pipeline.logging_capture import CapturedLogs + + if self._executor is not None: + if not self._active: + return None, CapturedLogs(success=True) + raw, captured = self._executor.execute_callable( + self._function, packet.as_dict() + ) + if not captured.success: + return None, captured + return self._build_output_packet(raw), captured + return self.direct_call(packet) + + async def async_call( + self, packet: PacketProtocol + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + """Async counterpart of ``call``.""" + from orcapod.pipeline.logging_capture import CapturedLogs + + if self._executor is not None: + if not self._active: + return None, CapturedLogs(success=True) + raw, captured = await self._executor.async_execute_callable( + self._function, packet.as_dict() + ) + if not captured.success: + return None, captured + return self._build_output_packet(raw), captured + return await self.direct_async_call(packet) + + def direct_call( + self, packet: PacketProtocol + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + """Execute the function on *packet* synchronously (no executor path). + + Uses :class:`~orcapod.pipeline.logging_capture.LocalCaptureContext` + for I/O capture. On user-function failure the exception is caught + internally and ``(None, captured_failure)`` is returned — no re-raise. + For async functions, the coroutine is driven to completion via + ``asyncio.run()`` (or a helper thread when already inside an event loop). + """ + import traceback as _tb + + from orcapod.pipeline.logging_capture import CapturedLogs, LocalCaptureContext + + if not self._active: + return None, CapturedLogs(success=True) + + ctx = LocalCaptureContext() + raw_result = None + with ctx: + try: + if self._is_async: + raw_result = self._call_async_function_sync(packet) + else: + raw_result = self._function(**packet.as_dict()) + except Exception: + return None, ctx.get_captured(success=False, tb=_tb.format_exc()) + return self._build_output_packet(raw_result), ctx.get_captured(success=True) + + async def direct_async_call( + self, packet: PacketProtocol + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + """Execute the function on *packet* asynchronously (no executor path). + + Async functions are ``await``-ed directly. Sync functions are + offloaded to a thread pool via ``run_in_executor``. On failure, + ``(None, captured_failure)`` is returned — no re-raise. + """ + import asyncio + import traceback as _tb + + from orcapod.pipeline.logging_capture import CapturedLogs, LocalCaptureContext + + if not self._active: + return None, CapturedLogs(success=True) + + ctx = LocalCaptureContext() + raw_result = None + with ctx: + try: + if self._is_async: + raw_result = await self._function(**packet.as_dict()) + else: + import contextvars + import functools + + loop = asyncio.get_running_loop() + task_ctx = contextvars.copy_context() + raw_result = await loop.run_in_executor( + None, + functools.partial( + task_ctx.run, + self._function, + **packet.as_dict(), + ), + ) + except Exception: + return None, ctx.get_captured(success=False, tb=_tb.format_exc()) + return self._build_output_packet(raw_result), ctx.get_captured(success=True) + + def to_config(self) -> dict[str, Any]: + """Serialize this packet function to a JSON-compatible config dict. + + Returns: + A dict with ``packet_function_type_id`` and a nested ``config`` + containing enough information to reconstruct this instance via + :meth:`from_config`. + """ + return { + "packet_function_type_id": self.packet_function_type_id, + "config": { + "module_path": self._function.__module__, + "callable_name": self._function_name, + "version": self._version, + "input_packet_schema": { + k: str(v) for k, v in self.input_packet_schema.items() + }, + "output_packet_schema": { + k: str(v) for k, v in self.output_packet_schema.items() + }, + "output_keys": list(self._output_keys) if self._output_keys else None, + }, + } + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "PythonPacketFunction": + """Reconstruct a PythonPacketFunction by importing the callable. + + Args: + config: A dict as produced by :meth:`to_config`. + + Returns: + A new ``PythonPacketFunction`` wrapping the imported callable. + + Raises: + ImportError: If the module specified in *config* cannot be imported. + AttributeError: If the callable name does not exist in the module. + """ + import importlib + + inner = config.get("config", config) + module = importlib.import_module(inner["module_path"]) + func = getattr(module, inner["callable_name"]) + return cls( + function=func, + output_keys=inner.get("output_keys"), + version=inner.get("version", "v0.0"), + ) + + +class PacketFunctionWrapper(PacketFunctionBase[E]): + """Wrapper around a PacketFunctionProtocol to modify or extend its behavior. + + Remains generic over ``E`` — the executor protocol is not bound here + so that wrappers inherit the executor type constraint of the wrapped + function. + """ + + def __init__(self, packet_function: PacketFunctionProtocol, **kwargs) -> None: + super().__init__(**kwargs) + self._packet_function = packet_function + + def computed_label(self) -> str | None: + return self._packet_function.label + + @property + def major_version(self) -> int: + return self._packet_function.major_version + + @property + def minor_version_string(self) -> str: + return self._packet_function.minor_version_string + + @property + def packet_function_type_id(self) -> str: + return self._packet_function.packet_function_type_id + + @property + def canonical_function_name(self) -> str: + return self._packet_function.canonical_function_name + + @property + def input_packet_schema(self) -> Schema: + return self._packet_function.input_packet_schema + + @property + def output_packet_schema(self) -> Schema: + return self._packet_function.output_packet_schema + + def get_function_variation_data(self) -> dict[str, Any]: + return self._packet_function.get_function_variation_data() + + def get_execution_data(self) -> dict[str, Any]: + return self._packet_function.get_execution_data() + + def to_config(self) -> dict[str, Any]: + """Delegate serialization to the wrapped packet function.""" + return self._packet_function.to_config() + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "PacketFunctionWrapper": + """Reconstruct by delegating to the wrapped function type. + + Args: + config: A dict as produced by :meth:`to_config`. + + Returns: + A new instance reconstructed from *config*. + """ + return cls._packet_function_class_for_config(config).from_config(config) + + @staticmethod + def _packet_function_class_for_config(config: dict[str, Any]) -> type: + """Return the concrete class to use when reconstructing from *config*. + + Currently only ``PythonPacketFunction`` is supported. Subclasses may + override this to handle additional types. + + Args: + config: A config dict as produced by :meth:`to_config`. + + Returns: + The class to call ``from_config`` on. + + Raises: + ValueError: If the ``packet_function_type_id`` is not recognized. + """ + type_id = config.get("packet_function_type_id") + if type_id == "python.function.v0": + return PythonPacketFunction + raise ValueError(f"Unrecognized packet_function_type_id: {type_id!r}") + + # -- Executor delegation: setting/getting the executor on a wrapper + # transparently targets the wrapped (leaf) packet function. + + @property + def executor(self) -> PacketFunctionExecutorProtocol | None: + return self._packet_function.executor + + @executor.setter + def executor(self, executor: PacketFunctionExecutorProtocol | None) -> None: + self._packet_function.executor = executor + + # -- Execution: call/async_call delegate to the wrapped function's + # call/async_call which handles executor routing. direct_call / + # direct_async_call bypass executor routing as their names imply. + + def call( + self, packet: PacketProtocol + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + return self._packet_function.call(packet) + + async def async_call( + self, packet: PacketProtocol + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + return await self._packet_function.async_call(packet) + + def direct_call( + self, packet: PacketProtocol + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + return self._packet_function.direct_call(packet) + + async def direct_async_call( + self, packet: PacketProtocol + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + return await self._packet_function.direct_async_call(packet) + + +class CachedPacketFunction(PacketFunctionWrapper): + """Wrapper around a PacketFunctionProtocol that caches results for identical input packets. + + Uses a shared ``ResultCache`` for lookup/store/conflict-resolution + logic (same mechanism as ``CachedFunctionPod``). + """ + + # Expose RESULT_COMPUTED_FLAG from the shared ResultCache + RESULT_COMPUTED_FLAG = ResultCache.RESULT_COMPUTED_FLAG + + def __init__( + self, + packet_function: PacketFunctionProtocol, + result_database: ArrowDatabaseProtocol, + record_path_prefix: tuple[str, ...] = (), + **kwargs, + ) -> None: + super().__init__(packet_function, **kwargs) + self._result_database = result_database + self._record_path_prefix = record_path_prefix + self._cache = ResultCache( + result_database=result_database, + record_path=record_path_prefix + self.uri, + auto_flush=True, + ) + + def set_auto_flush(self, on: bool = True) -> None: + """Set auto-flush behavior. If True, the database flushes after each record.""" + self._cache.set_auto_flush(on) + + @property + def record_path(self) -> tuple[str, ...]: + """Return the path to the record in the result store.""" + return self._cache.record_path + + def call( + self, + packet: PacketProtocol, + *, + skip_cache_lookup: bool = False, + skip_cache_insert: bool = False, + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + from orcapod.pipeline.logging_capture import CapturedLogs + + output_packet = None + if not skip_cache_lookup: + logger.info("Checking for cache...") + output_packet = self._cache.lookup(packet) + if output_packet is not None: + logger.info(f"Cache hit for {packet}!") + return output_packet, CapturedLogs(success=True) + output_packet, captured = self._packet_function.call(packet) + if output_packet is not None: + if not skip_cache_insert: + self._cache.store( + packet, + output_packet, + variation_data=self.get_function_variation_data(), + execution_data=self.get_execution_data(), + ) + output_packet = output_packet.with_meta_columns( + **{self.RESULT_COMPUTED_FLAG: True} + ) + return output_packet, captured + + async def async_call( + self, + packet: PacketProtocol, + *, + skip_cache_lookup: bool = False, + skip_cache_insert: bool = False, + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + """Async counterpart of ``call`` with cache check and recording.""" + from orcapod.pipeline.logging_capture import CapturedLogs + + output_packet = None + if not skip_cache_lookup: + logger.info("Checking for cache...") + output_packet = self._cache.lookup(packet) + if output_packet is not None: + logger.info(f"Cache hit for {packet}!") + return output_packet, CapturedLogs(success=True) + output_packet, captured = await self._packet_function.async_call(packet) + if output_packet is not None: + if not skip_cache_insert: + self._cache.store( + packet, + output_packet, + variation_data=self.get_function_variation_data(), + execution_data=self.get_execution_data(), + ) + output_packet = output_packet.with_meta_columns( + **{self.RESULT_COMPUTED_FLAG: True} + ) + return output_packet, captured + + def get_cached_output_for_packet( + self, input_packet: PacketProtocol + ) -> PacketProtocol | None: + """Retrieve the cached output packet for *input_packet*. + + If multiple cached entries exist, the most recent (by timestamp) wins. + + Returns: + The cached output packet, or ``None`` if no entry was found. + """ + return self._cache.lookup(input_packet) + + def record_packet( + self, + input_packet: PacketProtocol, + output_packet: PacketProtocol, + skip_duplicates: bool = False, + ) -> PacketProtocol: + """Record the output packet against the input packet in the result store.""" + self._cache.store( + input_packet, + output_packet, + variation_data=self.get_function_variation_data(), + execution_data=self.get_execution_data(), + skip_duplicates=skip_duplicates, + ) + return output_packet + + def get_all_cached_outputs( + self, include_system_columns: bool = False + ) -> "pa.Table | None": + """Return all cached records from the result store for this function. + + Args: + include_system_columns: If True, include system columns + (e.g. record_id) in the result. + + Returns: + A PyArrow table of cached results, or ``None`` if empty. + """ + return self._cache.get_all_records( + include_system_columns=include_system_columns + ) diff --git a/src/orcapod/core/pods.py b/src/orcapod/core/pods.py deleted file mode 100644 index 9e2f9ad2..00000000 --- a/src/orcapod/core/pods.py +++ /dev/null @@ -1,947 +0,0 @@ -import hashlib -import logging -from abc import abstractmethod -from collections.abc import Callable, Collection, Iterable, Sequence -from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any, Literal, Protocol, cast - -from orcapod import contexts -from orcapod.core.datagrams import ( - ArrowPacket, - DictPacket, -) -from functools import wraps - -from orcapod.utils.git_utils import get_git_info_for_python_object -from orcapod.core.kernels import KernelStream, TrackedKernelBase -from orcapod.core.operators import Join -from orcapod.core.streams import CachedPodStream, LazyPodResultStream -from orcapod.core.system_constants import constants -from orcapod.hashing.hash_utils import get_function_components, get_function_signature -from orcapod.protocols import core_protocols as cp -from orcapod.protocols import hashing_protocols as hp -from orcapod.protocols.database_protocols import ArrowDatabase -from orcapod.types import DataValue, PythonSchema, PythonSchemaLike -from orcapod.utils import types_utils -from orcapod.utils.lazy_module import LazyModule - - -# TODO: extract default char count as config -def combine_hashes( - *hashes: str, - order: bool = False, - prefix_hasher_id: bool = False, - hex_char_count: int | None = 20, -) -> str: - """Combine hashes into a single hash string.""" - - # Sort for deterministic order regardless of input order - if order: - prepared_hashes = sorted(hashes) - else: - prepared_hashes = list(hashes) - combined = "".join(prepared_hashes) - combined_hash = hashlib.sha256(combined.encode()).hexdigest() - if hex_char_count is not None: - combined_hash = combined_hash[:hex_char_count] - if prefix_hasher_id: - return "sha256@" + combined_hash - return combined_hash - - -if TYPE_CHECKING: - import pyarrow as pa - import pyarrow.compute as pc -else: - pa = LazyModule("pyarrow") - pc = LazyModule("pyarrow.compute") - -logger = logging.getLogger(__name__) - -error_handling_options = Literal["raise", "ignore", "warn"] - - -class ActivatablePodBase(TrackedKernelBase): - """ - FunctionPod is a specialized kernel that encapsulates a function to be executed on data streams. - It allows for the execution of a function with a specific label and can be tracked by the system. - """ - - @abstractmethod - def input_packet_types(self) -> PythonSchema: - """ - Return the input typespec for the pod. This is used to validate the input streams. - """ - ... - - @abstractmethod - def output_packet_types(self) -> PythonSchema: - """ - Return the output typespec for the pod. This is used to validate the output streams. - """ - ... - - @property - def version(self) -> str: - return self._version - - @abstractmethod - def get_record_id(self, packet: cp.Packet, execution_engine_hash: str) -> str: - """ - Return the record ID for the input packet. This is used to identify the pod in the system. - """ - ... - - @property - @abstractmethod - def tiered_pod_id(self) -> dict[str, str]: - """ - Return the tiered pod ID for the pod. This is used to identify the pod in a tiered architecture. - """ - ... - - def __init__( - self, - error_handling: error_handling_options = "raise", - label: str | None = None, - version: str = "v0.0", - **kwargs, - ) -> None: - super().__init__(label=label, **kwargs) - self._active = True - self.error_handling = error_handling - self._version = version - import re - - match = re.match(r"\D.*(\d+)", version) - major_version = 0 - if match: - major_version = int(match.group(1)) - else: - raise ValueError( - f"Version string {version} does not contain a valid version number" - ) - self.skip_type_checking = False - self._major_version = major_version - - @property - def major_version(self) -> int: - return self._major_version - - def kernel_output_types( - self, *streams: cp.Stream, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - """ - Return the input and output typespecs for the pod. - This is used to validate the input and output streams. - """ - tag_typespec, _ = streams[0].types(include_system_tags=include_system_tags) - return tag_typespec, self.output_packet_types() - - def is_active(self) -> bool: - """ - Check if the pod is active. If not, it will not process any packets. - """ - return self._active - - def set_active(self, active: bool) -> None: - """ - Set the active state of the pod. If set to False, the pod will not process any packets. - """ - self._active = active - - @staticmethod - def _join_streams(*streams: cp.Stream) -> cp.Stream: - if not streams: - raise ValueError("No streams provided for joining") - # Join the streams using a suitable join strategy - if len(streams) == 1: - return streams[0] - - joined_stream = streams[0] - for next_stream in streams[1:]: - joined_stream = Join()(joined_stream, next_stream) - return joined_stream - - def pre_kernel_processing(self, *streams: cp.Stream) -> tuple[cp.Stream, ...]: - """ - Prepare the incoming streams for execution in the pod. At least one stream must be present. - If more than one stream is present, the join of the provided streams will be returned. - """ - # if multiple streams are provided, join them - # otherwise, return as is - if len(streams) <= 1: - return streams - - output_stream = self._join_streams(*streams) - return (output_stream,) - - def validate_inputs(self, *streams: cp.Stream) -> None: - if len(streams) != 1: - raise ValueError( - f"{self.__class__.__name__} expects exactly one input stream, got {len(streams)}" - ) - if self.skip_type_checking: - return - input_stream = streams[0] - _, incoming_packet_types = input_stream.types() - if not types_utils.check_typespec_compatibility( - incoming_packet_types, self.input_packet_types() - ): - # TODO: use custom exception type for better error handling - raise ValueError( - f"Incoming packet data type {incoming_packet_types} from {input_stream} is not compatible with expected input typespec {self.input_packet_types()}" - ) - - def prepare_output_stream( - self, *streams: cp.Stream, label: str | None = None - ) -> KernelStream: - return KernelStream(source=self, upstreams=streams, label=label) - - def forward(self, *streams: cp.Stream) -> cp.Stream: - assert len(streams) == 1, "PodBase.forward expects exactly one input stream" - return LazyPodResultStream(pod=self, prepared_stream=streams[0]) - - @abstractmethod - def call( - self, - tag: cp.Tag, - packet: cp.Packet, - record_id: str | None = None, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> tuple[cp.Tag, cp.Packet | None]: ... - - @abstractmethod - async def async_call( - self, - tag: cp.Tag, - packet: cp.Packet, - record_id: str | None = None, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> tuple[cp.Tag, cp.Packet | None]: ... - - def track_invocation(self, *streams: cp.Stream, label: str | None = None) -> None: - if not self._skip_tracking and self._tracker_manager is not None: - self._tracker_manager.record_pod_invocation(self, streams, label=label) - - -class CallableWithPod(Protocol): - def __call__(self, *args, **kwargs) -> Any: ... - - @property - def pod(self) -> "FunctionPod": ... - - -def function_pod( - output_keys: str | Collection[str] | None = None, - function_name: str | None = None, - version: str = "v0.0", - label: str | None = None, - **kwargs, -) -> Callable[..., CallableWithPod]: - """ - Decorator that attaches FunctionPod as pod attribute. - - Args: - output_keys: Keys for the function output(s) - function_name: Name of the function pod; if None, defaults to the function name - **kwargs: Additional keyword arguments to pass to the FunctionPod constructor. Please refer to the FunctionPod documentation for details. - - Returns: - CallableWithPod: Decorated function with `pod` attribute holding the FunctionPod instance - """ - - def decorator(func: Callable) -> CallableWithPod: - - if func.__name__ == "": - raise ValueError("Lambda functions cannot be used with function_pod") - - @wraps(func) - def wrapper(*args, **kwargs): - return func(*args, **kwargs) - - # Store the original function in the module for pickling purposes - # and make sure to change the name of the function - - # Create a simple typed function pod - pod = FunctionPod( - function=func, - output_keys=output_keys, - function_name=function_name or func.__name__, - version=version, - label=label, - **kwargs, - ) - setattr(wrapper, "pod", pod) - return cast(CallableWithPod, wrapper) - return decorator - - -class FunctionPod(ActivatablePodBase): - def __init__( - self, - function: cp.PodFunction, - output_keys: str | Collection[str] | None = None, - function_name=None, - version: str = "v0.0", - input_python_schema: PythonSchemaLike | None = None, - output_python_schema: PythonSchemaLike | Sequence[type] | None = None, - label: str | None = None, - function_info_extractor: hp.FunctionInfoExtractor | None = None, - **kwargs, - ) -> None: - self.function = function - - if output_keys is None: - output_keys = [] - if isinstance(output_keys, str): - output_keys = [output_keys] - self.output_keys = output_keys - if function_name is None: - if hasattr(self.function, "__name__"): - function_name = getattr(self.function, "__name__") - else: - raise ValueError( - "function_name must be provided if function has no __name__ attribute" - ) - self.function_name = function_name - # extract the first full index (potentially with leading 0) in the version string - if not isinstance(version, str): - raise TypeError(f"Version must be a string, got {type(version)}") - - super().__init__(label=label or self.function_name, version=version, **kwargs) - - # extract input and output types from the function signature - input_packet_types, output_packet_types = ( - types_utils.extract_function_typespecs( - self.function, - self.output_keys, - input_typespec=input_python_schema, - output_typespec=output_python_schema, - ) - ) - - # get git info for the function - env_info = get_git_info_for_python_object(self.function) - if env_info is None: - git_hash = "unknown" - else: - git_hash = env_info.get("git_commit_hash", "unknown") - if env_info.get("git_repo_status") == "dirty": - git_hash += "-dirty" - self._git_hash = git_hash - - self._input_packet_schema = dict(input_packet_types) - self._output_packet_schema = dict(output_packet_types) - # TODO: add output packet converter for speed up - - self._function_info_extractor = function_info_extractor - object_hasher = self.data_context.object_hasher - # TODO: fix and replace with object_hasher protocol specific methods - self._function_signature_hash = object_hasher.hash_object( - get_function_signature(self.function) - ).to_string() - self._function_content_hash = object_hasher.hash_object( - get_function_components(self.function) - ).to_string() - - self._output_packet_type_hash = object_hasher.hash_object( - self.output_packet_types() - ).to_string() - - self._total_pod_id_hash = object_hasher.hash_object( - self.tiered_pod_id - ).to_string() - - @property - def tiered_pod_id(self) -> dict[str, str]: - return { - "version": self.version, - "signature": self._function_signature_hash, - "content": self._function_content_hash, - "git_hash": self._git_hash, - } - - @property - def reference(self) -> tuple[str, ...]: - return ( - self.function_name, - self._output_packet_type_hash, - "v" + str(self.major_version), - ) - - def get_record_id( - self, - packet: cp.Packet, - execution_engine_hash: str, - ) -> str: - return combine_hashes( - str(packet.content_hash()), - self._total_pod_id_hash, - execution_engine_hash, - prefix_hasher_id=True, - ) - - def input_packet_types(self) -> PythonSchema: - """ - Return the input typespec for the function pod. - This is used to validate the input streams. - """ - return self._input_packet_schema.copy() - - def output_packet_types(self) -> PythonSchema: - """ - Return the output typespec for the function pod. - This is used to validate the output streams. - """ - return self._output_packet_schema.copy() - - def __repr__(self) -> str: - return f"FunctionPod:{self.function_name}" - - def __str__(self) -> str: - include_module = self.function.__module__ != "__main__" - func_sig = get_function_signature( - self.function, - name_override=self.function_name, - include_module=include_module, - ) - return f"FunctionPod:{func_sig}" - - def call( - self, - tag: cp.Tag, - packet: cp.Packet, - record_id: str | None = None, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> tuple[cp.Tag, DictPacket | None]: - if not self.is_active(): - logger.info( - f"Pod is not active: skipping computation on input packet {packet}" - ) - return tag, None - - execution_engine_hash = execution_engine.name if execution_engine else "default" - - # any kernel/pod invocation happening inside the function will NOT be tracked - if not isinstance(packet, dict): - input_dict = packet.as_dict(include_source=False) - else: - input_dict = packet - - with self._tracker_manager.no_tracking(): - if execution_engine is not None: - # use the provided execution engine to run the function - values = execution_engine.submit_sync( - self.function, - fn_kwargs=input_dict, - **(execution_engine_opts or {}), - ) - else: - values = self.function(**input_dict) - - output_data = self.process_function_output(values) - - # TODO: extract out this function - def combine(*components: tuple[str, ...]) -> str: - inner_parsed = [":".join(component) for component in components] - return "::".join(inner_parsed) - - if record_id is None: - # if record_id is not provided, generate it from the packet - record_id = self.get_record_id(packet, execution_engine_hash) - source_info = { - k: combine(self.reference, (record_id,), (k,)) for k in output_data - } - - output_packet = DictPacket( - output_data, - source_info=source_info, - python_schema=self.output_packet_types(), - data_context=self.data_context, - ) - return tag, output_packet - - async def async_call( - self, - tag: cp.Tag, - packet: cp.Packet, - record_id: str | None = None, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> tuple[cp.Tag, cp.Packet | None]: - """ - Asynchronous call to the function pod. This is a placeholder for future implementation. - Currently, it behaves like the synchronous call. - """ - if not self.is_active(): - logger.info( - f"Pod is not active: skipping computation on input packet {packet}" - ) - return tag, None - - execution_engine_hash = execution_engine.name if execution_engine else "default" - - # any kernel/pod invocation happening inside the function will NOT be tracked - # with self._tracker_manager.no_tracking(): - # FIXME: figure out how to properly make context manager work with async/await - # any kernel/pod invocation happening inside the function will NOT be tracked - if not isinstance(packet, dict): - input_dict = packet.as_dict(include_source=False) - else: - input_dict = packet - if execution_engine is not None: - # use the provided execution engine to run the function - values = await execution_engine.submit_async( - self.function, - fn_kwargs=input_dict, - **(execution_engine_opts or {}) - ) - else: - values = self.function(**input_dict) - - output_data = self.process_function_output(values) - - # TODO: extract out this function - def combine(*components: tuple[str, ...]) -> str: - inner_parsed = [":".join(component) for component in components] - return "::".join(inner_parsed) - - if record_id is None: - # if record_id is not provided, generate it from the packet - record_id = self.get_record_id(packet, execution_engine_hash) - source_info = { - k: combine(self.reference, (record_id,), (k,)) for k in output_data - } - - output_packet = DictPacket( - output_data, - source_info=source_info, - python_schema=self.output_packet_types(), - data_context=self.data_context, - ) - return tag, output_packet - - def process_function_output(self, values: Any) -> dict[str, DataValue]: - output_values = [] - if len(self.output_keys) == 0: - output_values = [] - elif len(self.output_keys) == 1: - output_values = [values] # type: ignore - elif isinstance(values, Iterable): - output_values = list(values) # type: ignore - elif len(self.output_keys) > 1: - raise ValueError( - "Values returned by function must be a pathlike or a sequence of pathlikes" - ) - - if len(output_values) != len(self.output_keys): - raise ValueError( - f"Number of output keys {len(self.output_keys)}:{self.output_keys} does not match number of values returned by function {len(output_values)}" - ) - - return {k: v for k, v in zip(self.output_keys, output_values)} - - def kernel_identity_structure( - self, streams: Collection[cp.Stream] | None = None - ) -> Any: - id_struct = (self.__class__.__name__,) + self.reference - # if streams are provided, perform pre-processing step, validate, and add the - # resulting single stream to the identity structure - if streams is not None and len(streams) != 0: - id_struct += tuple(streams) - - return id_struct - - -class WrappedPod(ActivatablePodBase): - """ - A wrapper for an existing pod, allowing for additional functionality or modifications without changing the original pod. - This class is meant to serve as a base class for other pods that need to wrap existing pods. - Note that only the call logic is pass through to the wrapped pod, but the forward logic is not. - """ - - def __init__( - self, - pod: cp.Pod, - label: str | None = None, - data_context: str | contexts.DataContext | None = None, - **kwargs, - ) -> None: - # if data_context is not explicitly given, use that of the contained pod - if data_context is None: - data_context = pod.data_context_key - super().__init__( - label=label, - data_context=data_context, - **kwargs, - ) - self.pod = pod - - @property - def reference(self) -> tuple[str, ...]: - """ - Return the pod ID, which is the function name of the wrapped pod. - This is used to identify the pod in the system. - """ - return self.pod.reference - - def get_record_id(self, packet: cp.Packet, execution_engine_hash: str) -> str: - return self.pod.get_record_id(packet, execution_engine_hash) - - @property - def tiered_pod_id(self) -> dict[str, str]: - """ - Return the tiered pod ID for the wrapped pod. This is used to identify the pod in a tiered architecture. - """ - return self.pod.tiered_pod_id - - def computed_label(self) -> str | None: - return self.pod.label - - def input_packet_types(self) -> PythonSchema: - """ - Return the input typespec for the stored pod. - This is used to validate the input streams. - """ - return self.pod.input_packet_types() - - def output_packet_types(self) -> PythonSchema: - """ - Return the output typespec for the stored pod. - This is used to validate the output streams. - """ - return self.pod.output_packet_types() - - def validate_inputs(self, *streams: cp.Stream) -> None: - self.pod.validate_inputs(*streams) - - def call( - self, - tag: cp.Tag, - packet: cp.Packet, - record_id: str | None = None, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> tuple[cp.Tag, cp.Packet | None]: - return self.pod.call( - tag, - packet, - record_id=record_id, - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, - ) - - async def async_call( - self, - tag: cp.Tag, - packet: cp.Packet, - record_id: str | None = None, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> tuple[cp.Tag, cp.Packet | None]: - return await self.pod.async_call( - tag, - packet, - record_id=record_id, - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, - ) - - def kernel_identity_structure( - self, streams: Collection[cp.Stream] | None = None - ) -> Any: - return self.pod.identity_structure(streams) - - def __repr__(self) -> str: - return f"WrappedPod({self.pod!r})" - - def __str__(self) -> str: - return f"WrappedPod:{self.pod!s}" - - -class CachedPod(WrappedPod): - """ - A pod that caches the results of the wrapped pod. - This is useful for pods that are expensive to compute and can benefit from caching. - """ - - # name of the column in the tag store that contains the packet hash - DATA_RETRIEVED_FLAG = f"{constants.META_PREFIX}data_retrieved" - - def __init__( - self, - pod: cp.Pod, - result_database: ArrowDatabase, - record_path_prefix: tuple[str, ...] = (), - match_tier: str | None = None, - retrieval_mode: Literal["latest", "most_specific"] = "latest", - **kwargs, - ): - super().__init__(pod, **kwargs) - self.record_path_prefix = record_path_prefix - self.result_database = result_database - self.match_tier = match_tier - self.retrieval_mode = retrieval_mode - self.mode: Literal["production", "development"] = "production" - - def set_mode(self, mode: str) -> None: - if mode not in ("production", "development"): - raise ValueError(f"Invalid mode: {mode}") - self.mode = mode - - @property - def version(self) -> str: - return self.pod.version - - @property - def record_path(self) -> tuple[str, ...]: - """ - Return the path to the record in the result store. - This is used to store the results of the pod. - """ - return self.record_path_prefix + self.reference - - def call( - self, - tag: cp.Tag, - packet: cp.Packet, - record_id: str | None = None, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - skip_cache_lookup: bool = False, - skip_cache_insert: bool = False, - ) -> tuple[cp.Tag, cp.Packet | None]: - # TODO: consider logic for overwriting existing records - execution_engine_hash = execution_engine.name if execution_engine else "default" - if record_id is None: - record_id = self.get_record_id( - packet, execution_engine_hash=execution_engine_hash - ) - output_packet = None - if not skip_cache_lookup and self.mode == "production": - print("Checking for cache...") - output_packet = self.get_cached_output_for_packet(packet) - if output_packet is not None: - print(f"Cache hit for {packet}!") - if output_packet is None: - tag, output_packet = super().call( - tag, - packet, - record_id=record_id, - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, - ) - if ( - output_packet is not None - and not skip_cache_insert - and self.mode == "production" - ): - self.record_packet(packet, output_packet, record_id=record_id) - - return tag, output_packet - - async def async_call( - self, - tag: cp.Tag, - packet: cp.Packet, - record_id: str | None = None, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - skip_cache_lookup: bool = False, - skip_cache_insert: bool = False, - ) -> tuple[cp.Tag, cp.Packet | None]: - # TODO: consider logic for overwriting existing records - execution_engine_hash = execution_engine.name if execution_engine else "default" - - if record_id is None: - record_id = self.get_record_id( - packet, execution_engine_hash=execution_engine_hash - ) - output_packet = None - if not skip_cache_lookup: - output_packet = self.get_cached_output_for_packet(packet) - if output_packet is None: - tag, output_packet = await super().async_call( - tag, - packet, - record_id=record_id, - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, - ) - if output_packet is not None and not skip_cache_insert: - self.record_packet( - packet, - output_packet, - record_id=record_id, - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, - ) - - return tag, output_packet - - def forward(self, *streams: cp.Stream) -> cp.Stream: - assert len(streams) == 1, "PodBase.forward expects exactly one input stream" - return CachedPodStream(pod=self, input_stream=streams[0]) - - def record_packet( - self, - input_packet: cp.Packet, - output_packet: cp.Packet, - record_id: str | None = None, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - skip_duplicates: bool = False, - ) -> cp.Packet: - """ - Record the output packet against the input packet in the result store. - """ - - # TODO: consider incorporating execution_engine_opts into the record - data_table = output_packet.as_table(include_context=True, include_source=True) - - for i, (k, v) in enumerate(self.tiered_pod_id.items()): - # add the tiered pod ID to the data table - data_table = data_table.add_column( - i, - f"{constants.POD_ID_PREFIX}{k}", - pa.array([v], type=pa.large_string()), - ) - - # add the input packet hash as a column - data_table = data_table.add_column( - 0, - constants.INPUT_PACKET_HASH, - pa.array([str(input_packet.content_hash())], type=pa.large_string()), - ) - # add execution engine information - execution_engine_hash = execution_engine.name if execution_engine else "default" - data_table = data_table.append_column( - constants.EXECUTION_ENGINE, - pa.array([execution_engine_hash], type=pa.large_string()), - ) - - # add computation timestamp - timestamp = datetime.now(timezone.utc) - data_table = data_table.append_column( - constants.POD_TIMESTAMP, - pa.array([timestamp], type=pa.timestamp("us", tz="UTC")), - ) - - if record_id is None: - record_id = self.get_record_id( - input_packet, execution_engine_hash=execution_engine_hash - ) - - self.result_database.add_record( - self.record_path, - record_id, - data_table, - skip_duplicates=skip_duplicates, - ) - # if result_flag is None: - # # TODO: do more specific error handling - # raise ValueError( - # f"Failed to record packet {input_packet} in result store {self.result_store}" - # ) - # # TODO: make store return retrieved table - return output_packet - - def get_cached_output_for_packet(self, input_packet: cp.Packet) -> cp.Packet | None: - """ - Retrieve the output packet from the result store based on the input packet. - If more than one output packet is found, conflict resolution strategy - will be applied. - If the output packet is not found, return None. - """ - # result_table = self.result_store.get_record_by_id( - # self.record_path, - # self.get_entry_hash(input_packet), - # ) - - # get all records with matching the input packet hash - # TODO: add match based on match_tier if specified - constraints = {constants.INPUT_PACKET_HASH: str(input_packet.content_hash())} - if self.match_tier is not None: - constraints[f"{constants.POD_ID_PREFIX}{self.match_tier}"] = ( - self.pod.tiered_pod_id[self.match_tier] - ) - - result_table = self.result_database.get_records_with_column_value( - self.record_path, - constraints, - ) - if result_table is None or result_table.num_rows == 0: - return None - - if result_table.num_rows > 1: - logger.info( - f"Performing conflict resolution for multiple records for {input_packet.content_hash().display_name()}" - ) - if self.retrieval_mode == "latest": - result_table = result_table.sort_by( - self.DATA_RETRIEVED_FLAG, ascending=False - ).take([0]) - elif self.retrieval_mode == "most_specific": - # match by the most specific pod ID - # trying next level if not found - for k, v in reversed(self.tiered_pod_id.items()): - search_result = result_table.filter( - pc.field(f"{constants.POD_ID_PREFIX}{k}") == v - ) - if search_result.num_rows > 0: - result_table = search_result.take([0]) - break - if result_table.num_rows > 1: - logger.warning( - f"No matching record found for {input_packet.content_hash().display_name()} with tiered pod ID {self.tiered_pod_id}" - ) - result_table = result_table.sort_by( - self.DATA_RETRIEVED_FLAG, ascending=False - ).take([0]) - - else: - raise ValueError( - f"Unknown retrieval mode: {self.retrieval_mode}. Supported modes are 'latest' and 'most_specific'." - ) - - pod_id_columns = [ - f"{constants.POD_ID_PREFIX}{k}" for k in self.tiered_pod_id.keys() - ] - result_table = result_table.drop_columns(pod_id_columns) - result_table = result_table.drop_columns(constants.INPUT_PACKET_HASH) - - # note that data context will be loaded from the result store - return ArrowPacket( - result_table, - meta_info={self.DATA_RETRIEVED_FLAG: str(datetime.now(timezone.utc))}, - ) - - def get_all_cached_outputs( - self, include_system_columns: bool = False - ) -> "pa.Table | None": - """ - Get all records from the result store for this pod. - If include_system_columns is True, include system columns in the result. - """ - record_id_column = ( - constants.PACKET_RECORD_ID if include_system_columns else None - ) - result_table = self.result_database.get_all_records( - self.record_path, record_id_column=record_id_column - ) - if result_table is None or result_table.num_rows == 0: - return None - - if not include_system_columns: - # remove input packet hash and tiered pod ID columns - pod_id_columns = [ - f"{constants.POD_ID_PREFIX}{k}" for k in self.tiered_pod_id.keys() - ] - result_table = result_table.drop_columns(pod_id_columns) - result_table = result_table.drop_columns(constants.INPUT_PACKET_HASH) - - return result_table diff --git a/src/orcapod/core/result_cache.py b/src/orcapod/core/result_cache.py new file mode 100644 index 00000000..8afa8b8f --- /dev/null +++ b/src/orcapod/core/result_cache.py @@ -0,0 +1,226 @@ +"""ResultCache — shared result caching logic for CachedPacketFunction and CachedFunctionPod. + +Owns the database, record path, lookup (with match strategy), store, +conflict resolution, and auto-flush behavior. Both ``CachedPacketFunction`` +and ``CachedFunctionPod`` delegate to a ``ResultCache`` instance. +""" + +from __future__ import annotations + +import logging +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any + +from orcapod.protocols.core_protocols import PacketProtocol +from orcapod.protocols.database_protocols import ArrowDatabaseProtocol +from orcapod.system_constants import constants +from orcapod.utils.lazy_module import LazyModule + +if TYPE_CHECKING: + import pyarrow as pa +else: + pa = LazyModule("pyarrow") + +logger = logging.getLogger(__name__) + + +class ResultCache: + """Shared result caching backed by an ``ArrowDatabaseProtocol``. + + Provides lookup (by input packet hash + optional additional constraints), + store (output data + variation/execution metadata + timestamp), conflict + resolution (most-recent-timestamp wins), and auto-flush. + + The match strategy is extensible: the default lookup matches on + ``INPUT_PACKET_HASH_COL`` only, but callers can supply additional + constraints (e.g. function variation columns) to narrow the match. + This is the hook for future match-tier support (see DESIGN_ISSUES P6). + + Args: + result_database: The database to store/retrieve cached results. + record_path: The record path tuple for scoping records in the database. + auto_flush: If True, flush the database after each store operation. + """ + + # Meta column indicating whether the result was freshly computed + RESULT_COMPUTED_FLAG = f"{constants.META_PREFIX}computed" + + def __init__( + self, + result_database: ArrowDatabaseProtocol, + record_path: tuple[str, ...], + auto_flush: bool = True, + ) -> None: + self._result_database = result_database + self._record_path = record_path + self._auto_flush = auto_flush + + @property + def result_database(self) -> ArrowDatabaseProtocol: + """The underlying database.""" + return self._result_database + + @property + def record_path(self) -> tuple[str, ...]: + """The record path for scoping records in the database.""" + return self._record_path + + def set_auto_flush(self, on: bool = True) -> None: + """Set auto-flush behavior.""" + self._auto_flush = on + + def lookup( + self, + input_packet: PacketProtocol, + additional_constraints: dict[str, str] | None = None, + ) -> PacketProtocol | None: + """Look up a cached output packet for *input_packet*. + + The default match is by ``INPUT_PACKET_HASH_COL`` only. + *additional_constraints* can narrow the match further (e.g. by + function variation hash for stricter cache invalidation). + + If multiple records match, the most recent (by timestamp) wins. + + Args: + input_packet: The input packet whose content hash is the + primary lookup key. + additional_constraints: Optional extra column-value pairs to + include in the lookup query. + + Returns: + The cached output packet with ``RESULT_COMPUTED_FLAG: False`` + in its meta, or ``None`` if no match was found. + """ + from orcapod.core.datagrams import Packet + + RECORD_ID_COL = "_record_id" + + constraints: dict[str, str] = { + constants.INPUT_PACKET_HASH_COL: input_packet.content_hash().to_string(), + } + if additional_constraints: + constraints.update(additional_constraints) + + result_table = self._result_database.get_records_with_column_value( + self._record_path, + constraints, + record_id_column=RECORD_ID_COL, + ) + + if result_table is None or result_table.num_rows == 0: + return None + + if result_table.num_rows > 1: + logger.info( + "Cache conflict resolution: %d records for constraints %s, " + "taking most recent", + result_table.num_rows, + list(constraints.keys()), + ) + result_table = result_table.sort_by( + [(constants.POD_TIMESTAMP, "descending")] + ).take([0]) + + record_id = result_table.to_pylist()[0][RECORD_ID_COL] + # Drop lookup columns from the returned packet + drop_cols = [RECORD_ID_COL] + [ + c for c in constraints if c in result_table.column_names + ] + result_table = result_table.drop_columns(drop_cols) + + return Packet( + result_table, + record_id=record_id, + meta_info={self.RESULT_COMPUTED_FLAG: False}, + ) + + def store( + self, + input_packet: PacketProtocol, + output_packet: PacketProtocol, + variation_data: dict[str, Any], + execution_data: dict[str, Any], + skip_duplicates: bool = False, + ) -> None: + """Store an output packet in the cache. + + Stores the output packet data alongside function variation data, + execution data, input packet hash, and a timestamp. + + Args: + input_packet: The input packet (used for its content hash). + output_packet: The computed output packet to store. + variation_data: Function variation metadata (e.g. function name, + signature hash, content hash, git hash). + execution_data: Execution environment metadata (e.g. python + version, execution context). + skip_duplicates: If True, silently skip if a record with the + same ID already exists. + """ + data_table = output_packet.as_table(columns={"source": True, "context": True}) + + # Add function variation data columns + i = 0 + for k, v in variation_data.items(): + data_table = data_table.add_column( + i, + f"{constants.PF_VARIATION_PREFIX}{k}", + pa.array([v], type=pa.large_string()), + ) + i += 1 + + # Add execution data columns + for k, v in execution_data.items(): + data_table = data_table.add_column( + i, + f"{constants.PF_EXECUTION_PREFIX}{k}", + pa.array([v], type=pa.large_string()), + ) + i += 1 + + # Add input packet hash (position 0) + data_table = data_table.add_column( + 0, + constants.INPUT_PACKET_HASH_COL, + pa.array([input_packet.content_hash().to_string()], type=pa.large_string()), + ) + + # Append timestamp + timestamp = datetime.now(timezone.utc) + data_table = data_table.append_column( + constants.POD_TIMESTAMP, + pa.array([timestamp], type=pa.timestamp("us", tz="UTC")), + ) + + self._result_database.add_record( + self._record_path, + output_packet.datagram_id, + data_table, + skip_duplicates=skip_duplicates, + ) + + if self._auto_flush: + self._result_database.flush() + + def get_all_records( + self, include_system_columns: bool = False + ) -> "pa.Table | None": + """Return all cached records from the result store. + + Args: + include_system_columns: If True, include system columns + (e.g. record_id) in the result. + + Returns: + A PyArrow table of cached results, or ``None`` if empty. + """ + record_id_column = ( + constants.PACKET_RECORD_ID if include_system_columns else None + ) + result_table = self._result_database.get_all_records( + self._record_path, record_id_column=record_id_column + ) + if result_table is None or result_table.num_rows == 0: + return None + return result_table diff --git a/src/orcapod/core/sources/__init__.py b/src/orcapod/core/sources/__init__.py index 6bc4cf3b..3f3a7952 100644 --- a/src/orcapod/core/sources/__init__.py +++ b/src/orcapod/core/sources/__init__.py @@ -1,16 +1,26 @@ -from .base import SourceBase +from .base import RootSource from .arrow_table_source import ArrowTableSource +from .cached_source import CachedSource +from .csv_source import CSVSource +from .data_frame_source import DataFrameSource from .delta_table_source import DeltaTableSource +from .derived_source import DerivedSource from .dict_source import DictSource -from .data_frame_source import DataFrameSource -from .source_registry import SourceRegistry, GLOBAL_SOURCE_REGISTRY +from .list_source import ListSource +from .source_registry import GLOBAL_SOURCE_REGISTRY, SourceRegistry +from .source_proxy import SourceProxy __all__ = [ - "SourceBase", - "DataFrameSource", + "RootSource", "ArrowTableSource", + "CachedSource", + "CSVSource", + "DataFrameSource", "DeltaTableSource", + "DerivedSource", "DictSource", + "ListSource", "SourceRegistry", + "SourceProxy", "GLOBAL_SOURCE_REGISTRY", ] diff --git a/src/orcapod/core/sources/arrow_table_source.py b/src/orcapod/core/sources/arrow_table_source.py index 7d3c7897..42d04a63 100644 --- a/src/orcapod/core/sources/arrow_table_source.py +++ b/src/orcapod/core/sources/arrow_table_source.py @@ -1,132 +1,74 @@ +from __future__ import annotations + from collections.abc import Collection from typing import TYPE_CHECKING, Any - -from orcapod.core.streams import TableStream -from orcapod.protocols import core_protocols as cp -from orcapod.types import PythonSchema -from orcapod.utils.lazy_module import LazyModule -from orcapod.core.system_constants import constants -from orcapod.core import arrow_data_utils -from orcapod.core.sources.source_registry import GLOBAL_SOURCE_REGISTRY, SourceRegistry - -from orcapod.core.sources.base import SourceBase +from orcapod.core.sources.base import RootSource +from orcapod.core.sources.stream_builder import SourceStreamBuilder if TYPE_CHECKING: import pyarrow as pa -else: - pa = LazyModule("pyarrow") -class ArrowTableSource(SourceBase): - """Construct source from a collection of dictionaries""" +class ArrowTableSource(RootSource): + """A source backed by an in-memory PyArrow Table. - SOURCE_ID = "arrow" + Uses ``SourceStreamBuilder`` to strip system columns, add per-row + source-info provenance columns and a system tag column encoding the + schema hash, then wraps the result in an ``ArrowTableStream``. + """ def __init__( self, - arrow_table: "pa.Table", + table: "pa.Table", tag_columns: Collection[str] = (), - source_name: str | None = None, - source_registry: SourceRegistry | None = None, - auto_register: bool = True, - preserve_system_columns: bool = False, - **kwargs, - ): + system_tag_columns: Collection[str] = (), + record_id_column: str | None = None, + **kwargs: Any, + ) -> None: super().__init__(**kwargs) - # clean the table, dropping any system columns - # TODO: consider special treatment of system columns if provided - if not preserve_system_columns: - arrow_table = arrow_data_utils.drop_system_columns(arrow_table) - - non_system_columns = arrow_data_utils.drop_system_columns(arrow_table) - tag_schema = non_system_columns.select(tag_columns).schema - # FIXME: ensure tag_columns are found among non system columns - packet_schema = non_system_columns.drop(list(tag_columns)).schema - - tag_python_schema = ( - self.data_context.type_converter.arrow_schema_to_python_schema(tag_schema) - ) - packet_python_schema = ( - self.data_context.type_converter.arrow_schema_to_python_schema( - packet_schema - ) - ) - - schema_hash = self.data_context.object_hasher.hash_object( - (tag_python_schema, packet_python_schema) - ).to_hex(char_count=self.orcapod_config.schema_hash_n_char) - - self.tag_columns = [ - col for col in tag_columns if col in arrow_table.column_names - ] - - self.table_hash = self.data_context.arrow_hasher.hash_table(arrow_table) - - if source_name is None: - # TODO: determine appropriate config name - source_name = self.content_hash().to_hex( - char_count=self.orcapod_config.path_hash_n_char - ) - - self._source_name = source_name - - row_index = list(range(arrow_table.num_rows)) - - source_info = [ - f"{self.source_id}{constants.BLOCK_SEPARATOR}row_{i}" for i in row_index - ] - - # add source info - arrow_table = arrow_data_utils.add_source_info( - arrow_table, source_info, exclude_columns=tag_columns + builder = SourceStreamBuilder(self.data_context, self.orcapod_config) + result = builder.build( + table, + tag_columns=tag_columns, + source_id=self._source_id, + record_id_column=record_id_column, + system_tag_columns=system_tag_columns, ) - arrow_table = arrow_data_utils.add_system_tag_column( - arrow_table, f"source{constants.FIELD_SEPARATOR}{schema_hash}", source_info - ) - - self._table = arrow_table - - self._table_stream = TableStream( - table=self._table, - tag_columns=self.tag_columns, - source=self, - upstreams=(), + self._stream = result.stream + self._schema_hash = result.schema_hash + self._table_hash = result.table_hash + self._tag_columns = result.tag_columns + self._system_tag_columns = result.system_tag_columns + self._record_id_column = record_id_column + + if self._source_id is None: + self._source_id = result.source_id + + def to_config(self) -> dict[str, Any]: + """Serialize metadata-only config (in-memory table is not serializable).""" + return { + "source_type": "arrow_table", + "tag_columns": list(self._tag_columns), + "source_id": self.source_id, + **self._identity_config(), + } + + @classmethod + def from_config(cls, config: dict[str, Any]) -> ArrowTableSource: + """Not supported — ArrowTableSource cannot be reconstructed from config. + + Raises: + NotImplementedError: Always. + """ + raise NotImplementedError( + "ArrowTableSource cannot be reconstructed from config — " + "the in-memory Arrow table is not serializable." ) - # Auto-register with global registry - if auto_register: - registry = source_registry or GLOBAL_SOURCE_REGISTRY - registry.register(self.source_id, self) - - @property - def reference(self) -> tuple[str, ...]: - return ("arrow_table", f"source_{self._source_name}") - @property def table(self) -> "pa.Table": - return self._table - - def source_identity_structure(self) -> Any: - return (self.__class__.__name__, self.tag_columns, self.table_hash) - - def get_all_records( - self, include_system_columns: bool = False - ) -> "pa.Table | None": - return self().as_table(include_source=include_system_columns) - - def forward(self, *streams: cp.Stream) -> cp.Stream: - """ - Load data from file and return a static stream. - - This is called by forward() and creates a fresh snapshot each time. - """ - return self._table_stream - - def source_output_types( - self, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - """Return tag and packet types based on provided typespecs.""" - return self._table_stream.types(include_system_tags=include_system_tags) + """Return the enriched table (with source-info and system tags).""" + return self._stream.as_table(columns={"source": True, "system_tags": True}) diff --git a/src/orcapod/core/sources/base.py b/src/orcapod/core/sources/base.py index 8f445998..784f59b4 100644 --- a/src/orcapod/core/sources/base.py +++ b/src/orcapod/core/sources/base.py @@ -1,511 +1,216 @@ -from abc import abstractmethod -from collections.abc import Collection, Iterator -from typing import TYPE_CHECKING, Any +from __future__ import annotations +from typing import TYPE_CHECKING, Any -from orcapod.core.kernels import TrackedKernelBase -from orcapod.core.streams import ( - KernelStream, - StatefulStreamBase, -) -from orcapod.protocols import core_protocols as cp -from orcapod.types import PythonSchema -from orcapod.utils.lazy_module import LazyModule +from orcapod import contexts +from orcapod.config import Config +from orcapod.core.streams.base import StreamBase +from orcapod.protocols.core_protocols import StreamProtocol +from orcapod.types import ColumnConfig, Schema if TYPE_CHECKING: import pyarrow as pa -else: - pa = LazyModule("pyarrow") + from orcapod.core.sources.cached_source import CachedSource + from orcapod.protocols.database_protocols import ArrowDatabaseProtocol -class InvocationBase(TrackedKernelBase, StatefulStreamBase): - def __init__(self, **kwargs): - super().__init__(**kwargs) - # Cache the KernelStream for reuse across all stream method calls - self._cached_kernel_stream: KernelStream | None = None - - def computed_label(self) -> str | None: - return None - - @abstractmethod - def kernel_identity_structure( - self, streams: Collection[cp.Stream] | None = None - ) -> Any: ... - - # Redefine the reference to ensure subclass would provide a concrete implementation - @property - @abstractmethod - def reference(self) -> tuple[str, ...]: - """Return the unique identifier for the kernel.""" - ... - - # =========================== Kernel Methods =========================== - - # The following are inherited from TrackedKernelBase as abstract methods. - # @abstractmethod - # def forward(self, *streams: dp.Stream) -> dp.Stream: - # """ - # Pure computation: return a static snapshot of the data. - - # This is the core method that subclasses must implement. - # Each call should return a fresh stream representing the current state of the data. - # This is what KernelStream calls when it needs to refresh its data. - # """ - # ... - - # @abstractmethod - # def kernel_output_types(self, *streams: dp.Stream) -> tuple[TypeSpec, TypeSpec]: - # """Return the tag and packet types this source produces.""" - # ... - - # @abstractmethod - # def kernel_identity_structure( - # self, streams: Collection[dp.Stream] | None = None - # ) -> dp.Any: ... - - def prepare_output_stream( - self, *streams: cp.Stream, label: str | None = None - ) -> KernelStream: - if self._cached_kernel_stream is None: - self._cached_kernel_stream = super().prepare_output_stream( - *streams, label=label - ) - return self._cached_kernel_stream - - def track_invocation(self, *streams: cp.Stream, label: str | None = None) -> None: - raise NotImplementedError("Behavior for track invocation is not determined") - - # ==================== Stream Protocol (Delegation) ==================== - - @property - def source(self) -> cp.Kernel | None: - """Sources are their own source.""" - return self - - # @property - # def upstreams(self) -> tuple[cp.Stream, ...]: ... - - def keys( - self, include_system_tags: bool = False - ) -> tuple[tuple[str, ...], tuple[str, ...]]: - """Delegate to the cached KernelStream.""" - return self().keys(include_system_tags=include_system_tags) - def types( - self, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - """Delegate to the cached KernelStream.""" - return self().types(include_system_tags=include_system_tags) +class RootSource(StreamBase): + """Abstract base class for all root sources in Orcapod. - @property - def last_modified(self): - """Delegate to the cached KernelStream.""" - return self().last_modified + A RootSource is a pure stream — the root of a computational graph, producing + data from an external source (file, database, in-memory data, etc.) with no + upstream dependencies. - @property - def is_current(self) -> bool: - """Delegate to the cached KernelStream.""" - return self().is_current + As a StreamProtocol: + - ``source`` returns ``None`` (no upstream source pod) + - ``upstreams`` is always empty + - ``keys``, ``output_schema``, ``iter_packets``, ``as_table`` delegate to + ``self._stream`` by default; concrete subclasses may override them. - def __iter__(self) -> Iterator[tuple[cp.Tag, cp.Packet]]: - """ - Iterate over the cached KernelStream. + As a PipelineElementProtocol: + - ``pipeline_identity_structure()`` returns ``(tag_schema, packet_schema)`` + — schema-only, no data content — forming the base case of the pipeline + identity Merkle chain. - This allows direct iteration over the source as if it were a stream. - """ - return self().iter_packets() + Source identity: + Every source has a ``source_id`` — a canonical name that determines the + source's content identity and is used in the ``SourceRegistry`` so that + provenance tokens embedded in downstream data can be resolved back to the + originating source object. - def iter_packets( - self, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> Iterator[tuple[cp.Tag, cp.Packet]]: - """Delegate to the cached KernelStream.""" - return self().iter_packets( - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, - ) + Concrete subclasses must ensure ``_source_id`` is set by the end of + ``__init__``. File-backed sources (DeltaTableSource, CSVSource) default + to the file path; ``ArrowTableSource`` defaults to the table's data hash. - def as_table( - self, - include_data_context: bool = False, - include_source: bool = False, - include_system_tags: bool = False, - include_content_hash: bool | str = False, - sort_by_tags: bool = True, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> "pa.Table": - """Delegate to the cached KernelStream.""" - return self().as_table( - include_data_context=include_data_context, - include_source=include_source, - include_system_tags=include_system_tags, - include_content_hash=include_content_hash, - sort_by_tags=sort_by_tags, - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, - ) + Field resolution: + All sources expose ``resolve_field(record_id, field_name)``. The default + implementation raises ``NotImplementedError``; concrete subclasses + that back addressable data should override it. - def flow( - self, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> Collection[tuple[cp.Tag, cp.Packet]]: - """Delegate to the cached KernelStream.""" - return self().flow( - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, - ) + Concrete subclasses must set ``self._stream`` in their ``__init__`` to get + the default stream delegation behavior. They may also override + ``identity_structure()``. + """ - def run( + def __init__( self, - *args: Any, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - **kwargs: Any, + source_id: str | None = None, + label: str | None = None, + data_context: str | contexts.DataContext | None = None, + config: Config | None = None, ) -> None: - """ - Run the source node, executing the contained source. - - This is a no-op for sources since they are not executed like pods. - """ - self().run( - *args, - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, - **kwargs, + super().__init__( + label=label, + data_context=data_context, + config=config, ) + self._source_id = source_id - async def run_async( - self, - *args: Any, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - **kwargs: Any, - ) -> None: - """ - Run the source node asynchronously, executing the contained source. + # ------------------------------------------------------------------------- + # Source identity + # ------------------------------------------------------------------------- - This is a no-op for sources since they are not executed like pods. + @property + def source_id(self) -> str: + """Canonical name for this source used in the registry and provenance + strings. If not set, raises ``ValueError``. """ - await self().run_async( - *args, - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, - **kwargs, - ) + if self._source_id is None: + raise ValueError("source_id is not set") + return self._source_id - # ==================== LiveStream Protocol (Delegation) ==================== + # ------------------------------------------------------------------------- + # Content identity + # ------------------------------------------------------------------------- - def refresh(self, force: bool = False) -> bool: - """Delegate to the cached KernelStream.""" - return self().refresh(force=force) + def identity_structure(self) -> Any: + """Default identity based on class name, output schema, and source_id.""" + return (self.__class__.__name__, self.output_schema(), self.source_id) - def invalidate(self) -> None: - """Delegate to the cached KernelStream.""" - return self().invalidate() + # ------------------------------------------------------------------------- + # Field resolution + # ------------------------------------------------------------------------- + def resolve_field(self, record_id: str, field_name: str) -> Any: + """Resolve a field value for a record. -class SourceBase(TrackedKernelBase, StatefulStreamBase): - """ - Base class for sources that act as both Kernels and LiveStreams. + Not implemented by default. Subclasses that back addressable data + should override this method. - Design Philosophy: - 1. Source is fundamentally a Kernel (data loader) - 2. forward() returns static snapshots as a stream (pure computation) - 3. __call__() returns a cached KernelStream (live, tracked) - 4. All stream methods delegate to the cached KernelStream + Raises: + NotImplementedError: Always, by default. + """ + raise NotImplementedError( + f"{self.__class__.__name__} does not implement resolve_field. " + f"Cannot resolve field {field_name!r} for record {record_id!r}." + ) - This ensures that direct source iteration and source() iteration - are identical and both benefit from KernelStream's lifecycle management. - """ + # ------------------------------------------------------------------------- + # PipelineElementProtocol — schema-only identity (base case of Merkle chain) + # ------------------------------------------------------------------------- - def __init__(self, **kwargs): - super().__init__(**kwargs) - # Cache the KernelStream for reuse across all stream method calls - self._cached_kernel_stream: KernelStream | None = None - self._schema_hash: str | None = None + def _identity_config(self) -> dict[str, Any]: + """Return identity fields for inclusion in ``to_config()`` output. - # reset, so that computed label won't be used from StatefulStreamBase - def computed_label(self) -> str | None: - return None + These fields allow ``SourceProxy`` to be constructed when the source + cannot be reconstructed from config, preserving identity hashes and + schemas for downstream consumers. + """ + from orcapod.pipeline.serialization import serialize_schema - def schema_hash(self) -> str: - if self._schema_hash is None: - self._schema_hash = self.data_context.object_hasher.hash_object( - (self.tag_types(), self.packet_types()) - ).to_hex(self.orcapod_config.schema_hash_n_char) - return self._schema_hash - - def kernel_identity_structure( - self, streams: Collection[cp.Stream] | None = None - ) -> Any: - if streams is not None: - # when checked for invocation id, act as a source - # and just return the output packet types - # _, packet_types = self.stream.types() - # return packet_types - return self.schema_hash() - # otherwise, return the identity structure of the stream - return self.source_identity_structure() + tag_schema, packet_schema = self.output_schema() + type_converter = self.data_context.type_converter + return { + "content_hash": self.content_hash().to_string(), + "pipeline_hash": self.pipeline_hash().to_string(), + "tag_schema": serialize_schema(tag_schema, type_converter), + "packet_schema": serialize_schema(packet_schema, type_converter), + } - @property - def source_id(self) -> str: - return ":".join(self.reference) + def pipeline_identity_structure(self) -> Any: + """Return (tag_schema, packet_schema) as the pipeline identity for this + source. Schema-only: no data content is included, so sources with + identical schemas share the same pipeline hash and therefore the same + pipeline database table. + """ + tag_schema, packet_schema = self.output_schema() + return (tag_schema, packet_schema) - # Redefine the reference to ensure subclass would provide a concrete implementation - @property - @abstractmethod - def reference(self) -> tuple[str, ...]: - """Return the unique identifier for the kernel.""" - ... - - def kernel_output_types( - self, *streams: cp.Stream, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - return self.source_output_types(include_system_tags=include_system_tags) - - @abstractmethod - def source_identity_structure(self) -> Any: ... - - @abstractmethod - def source_output_types(self, include_system_tags: bool = False) -> Any: ... - - # =========================== Kernel Methods =========================== - - # The following are inherited from TrackedKernelBase as abstract methods. - # @abstractmethod - # def forward(self, *streams: dp.Stream) -> dp.Stream: - # """ - # Pure computation: return a static snapshot of the data. - - # This is the core method that subclasses must implement. - # Each call should return a fresh stream representing the current state of the data. - # This is what KernelStream calls when it needs to refresh its data. - # """ - # ... - - # @abstractmethod - # def kernel_output_types(self, *streams: dp.Stream) -> tuple[TypeSpec, TypeSpec]: - # """Return the tag and packet types this source produces.""" - # ... - - # @abstractmethod - # def kernel_identity_structure( - # self, streams: Collection[dp.Stream] | None = None - # ) -> dp.Any: ... - - def validate_inputs(self, *streams: cp.Stream) -> None: - """Sources take no input streams.""" - if len(streams) > 0: - raise ValueError( - f"{self.__class__.__name__} is a source and takes no input streams" - ) - - def prepare_output_stream( - self, *streams: cp.Stream, label: str | None = None - ) -> KernelStream: - if self._cached_kernel_stream is None: - self._cached_kernel_stream = super().prepare_output_stream( - *streams, label=label - ) - return self._cached_kernel_stream - - def track_invocation(self, *streams: cp.Stream, label: str | None = None) -> None: - if not self._skip_tracking and self._tracker_manager is not None: - self._tracker_manager.record_source_invocation(self, label=label) - - # ==================== Stream Protocol (Delegation) ==================== + # ------------------------------------------------------------------------- + # StreamProtocol protocol + # ------------------------------------------------------------------------- @property - def source(self) -> cp.Kernel | None: - """Sources are their own source.""" - return self + def producer(self) -> None: + """Root sources have no upstream source pod.""" + return None @property - def upstreams(self) -> tuple[cp.Stream, ...]: + def upstreams(self) -> tuple[StreamProtocol, ...]: """Sources have no upstream dependencies.""" return () - def keys( - self, include_system_tags: bool = False - ) -> tuple[tuple[str, ...], tuple[str, ...]]: - """Delegate to the cached KernelStream.""" - return self().keys(include_system_tags=include_system_tags) - - def types( - self, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - """Delegate to the cached KernelStream.""" - return self().types(include_system_tags=include_system_tags) - - @property - def last_modified(self): - """Delegate to the cached KernelStream.""" - return self().last_modified + # ------------------------------------------------------------------------- + # Stream delegation defaults + # ------------------------------------------------------------------------- - @property - def is_current(self) -> bool: - """Delegate to the cached KernelStream.""" - return self().is_current - - def __iter__(self) -> Iterator[tuple[cp.Tag, cp.Packet]]: - """ - Iterate over the cached KernelStream. - - This allows direct iteration over the source as if it were a stream. - """ - return self().iter_packets() + def output_schema( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: + """Delegate to the underlying stream's output_schema.""" + return self._stream.output_schema(columns=columns, all_info=all_info) - def iter_packets( + def keys( self, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> Iterator[tuple[cp.Tag, cp.Packet]]: - """Delegate to the cached KernelStream.""" - return self().iter_packets( - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, - ) + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[tuple[str, ...], tuple[str, ...]]: + """Delegate to the underlying stream's keys.""" + return self._stream.keys(columns=columns, all_info=all_info) + + def iter_packets(self): + """Delegate to the underlying stream's iter_packets.""" + return self._stream.iter_packets() def as_table( self, - include_data_context: bool = False, - include_source: bool = False, - include_system_tags: bool = False, - include_content_hash: bool | str = False, - sort_by_tags: bool = True, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, ) -> "pa.Table": - """Delegate to the cached KernelStream.""" - return self().as_table( - include_data_context=include_data_context, - include_source=include_source, - include_system_tags=include_system_tags, - include_content_hash=include_content_hash, - sort_by_tags=sort_by_tags, - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, - ) + """Delegate to the underlying stream's as_table.""" + return self._stream.as_table(columns=columns, all_info=all_info) - def flow( - self, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> Collection[tuple[cp.Tag, cp.Packet]]: - """Delegate to the cached KernelStream.""" - return self().flow( - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, - ) + # ------------------------------------------------------------------------- + # Convenience — caching + # ------------------------------------------------------------------------- - def run( + def cached( self, - *args: Any, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, + cache_database: ArrowDatabaseProtocol, + cache_path_prefix: tuple[str, ...] = (), **kwargs: Any, - ) -> None: - """ - Run the source node, executing the contained source. + ) -> CachedSource: + """Return a ``CachedSource`` wrapping this source. - This is a no-op for sources since they are not executed like pods. - """ - self().run( - *args, - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, - **kwargs, - ) + Args: + cache_database: Database to store cached records in. + cache_path_prefix: Path prefix for the cache table. + **kwargs: Additional keyword arguments passed to ``CachedSource``. - async def run_async( - self, - *args: Any, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - **kwargs: Any, - ) -> None: + Returns: + A ``CachedSource`` that caches this source's output. """ - Run the source node asynchronously, executing the contained source. + from orcapod.core.sources.cached_source import CachedSource - This is a no-op for sources since they are not executed like pods. - """ - await self().run_async( - *args, - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, + return CachedSource( + source=self, + cache_database=cache_database, + cache_path_prefix=cache_path_prefix, **kwargs, ) - - # ==================== LiveStream Protocol (Delegation) ==================== - - def refresh(self, force: bool = False) -> bool: - """Delegate to the cached KernelStream.""" - return self().refresh(force=force) - - def invalidate(self) -> None: - """Delegate to the cached KernelStream.""" - return self().invalidate() - - # ==================== Source Protocol ==================== - - def reset_cache(self) -> None: - """ - Clear the cached KernelStream, forcing a fresh one on next access. - - Useful when the underlying data source has fundamentally changed - (e.g., file path changed, database connection reset). - """ - if self._cached_kernel_stream is not None: - self._cached_kernel_stream.invalidate() - self._cached_kernel_stream = None - - -class StreamSource(SourceBase): - def __init__(self, stream: cp.Stream, label: str | None = None, **kwargs) -> None: - """ - A placeholder source based on stream - This is used to represent a kernel that has no computation. - """ - label = label or stream.label - self.stream = stream - super().__init__(label=label, **kwargs) - - def source_output_types( - self, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - """ - Returns the types of the tag and packet columns in the stream. - This is useful for accessing the types of the columns in the stream. - """ - return self.stream.types(include_system_tags=include_system_tags) - - @property - def reference(self) -> tuple[str, ...]: - return ("stream", self.stream.content_hash().to_string()) - - def forward(self, *args: Any, **kwargs: Any) -> cp.Stream: - """ - Forward the stream through the stub kernel. - This is a no-op and simply returns the stream. - """ - return self.stream - - def source_identity_structure(self) -> Any: - return self.stream.identity_structure() - - # def __hash__(self) -> int: - # # TODO: resolve the logic around identity structure on a stream / stub kernel - # """ - # Hash the StubKernel based on its label and stream. - # This is used to uniquely identify the StubKernel in the tracker. - # """ - # identity_structure = self.identity_structure() - # if identity_structure is None: - # return hash(self.stream) - # return identity_structure - - -# ==================== Example Implementation ==================== diff --git a/src/orcapod/core/sources/cached_source.py b/src/orcapod/core/sources/cached_source.py new file mode 100644 index 00000000..c9141fd4 --- /dev/null +++ b/src/orcapod/core/sources/cached_source.py @@ -0,0 +1,285 @@ +from __future__ import annotations + +import logging +from collections.abc import Iterator +from typing import TYPE_CHECKING, Any + +from orcapod import contexts +from orcapod.config import Config +from orcapod.core.sources.base import RootSource +from orcapod.core.streams.arrow_table_stream import ArrowTableStream +from orcapod.protocols.core_protocols import PacketProtocol, SourceProtocol, TagProtocol +from orcapod.protocols.database_protocols import ArrowDatabaseProtocol +from orcapod.types import ColumnConfig, Schema +from orcapod.utils.lazy_module import LazyModule + +if TYPE_CHECKING: + import pyarrow as pa +else: + pa = LazyModule("pyarrow") + +logger = logging.getLogger(__name__) + + +class CachedSource(RootSource): + """DB-backed wrapper around a source that caches every packet. + + Accepts any ``SourceProtocol`` implementation as the inner source. + Implements ``StreamProtocol`` transparently so downstream consumers + are unaware of caching. Cache table is scoped to the source's + ``content_hash()`` — each unique source gets its own table. + + Behavior: + - Cache is **always on**. + - On first access, live source data is stored in the cache table + (deduped by per-row content hash). + - Returns the union of all cached data (cumulative across runs). + + Semantic guarantee: + The cache is a correct cumulative record. The union of cache + live + packets is the full set of data ever available from that source. + + Example:: + + source = ArrowTableSource(table, tag_columns=["id"]) + cached = CachedSource(source, cache_database=db) + # or equivalently: + cached = source.cached(cache_database=db) + """ + + HASH_COLUMN_NAME = "_record_hash" + + def __init__( + self, + source: SourceProtocol, + cache_database: ArrowDatabaseProtocol, + cache_path_prefix: tuple[str, ...] = (), + cache_path: tuple[str, ...] | None = None, + source_id: str | None = None, + label: str | None = None, + data_context: str | contexts.DataContext | None = None, + config: Config | None = None, + ) -> None: + if data_context is None: + data_context = source.data_context_key + if source_id is None: + source_id = source.source_id + super().__init__( + source_id=source_id, + label=label, + data_context=data_context, + config=config, + ) + self._source: SourceProtocol = source + self._cache_database = cache_database + self._cache_path_prefix = cache_path_prefix + self._explicit_cache_path = cache_path + self._cached_stream: ArrowTableStream | None = None + + # ------------------------------------------------------------------------- + # Serialization + # ------------------------------------------------------------------------- + + def to_config(self) -> dict[str, Any]: + """Serialize this CachedSource configuration to a JSON-compatible dict. + + Returns: + Dict containing the inner source config, cache database config, + cache path prefix, and resolved cache path (for cache-only loading). + """ + return { + "source_type": "cached", + "inner_source": self._source.to_config(), + "cache_database": self._cache_database.to_config(), + "cache_path_prefix": list(self._cache_path_prefix), + "cache_path": list(self.cache_path), + "source_id": self.source_id, + **self._identity_config(), + } + + @classmethod + def from_config(cls, config: dict[str, Any]) -> CachedSource: + """Reconstruct a CachedSource from a config dict. + + If the inner source cannot be resolved (e.g. it requires live data + that is unavailable), ``resolve_source_from_config`` returns a + ``SourceProxy`` preserving the original source's identity. The + CachedSource can still serve data from its cache database. + + Args: + config: Dict as produced by :meth:`to_config`. + + Returns: + A new CachedSource constructed from the config. + """ + from orcapod.pipeline.serialization import ( + resolve_database_from_config, + resolve_source_from_config, + ) + + cache_db = resolve_database_from_config(config["cache_database"]) + inner_source = resolve_source_from_config( + config["inner_source"], fallback_to_proxy=True + ) + + return cls( + source=inner_source, + cache_database=cache_db, + cache_path_prefix=tuple(config.get("cache_path_prefix", ())), + cache_path=tuple(config["cache_path"]) if "cache_path" in config else None, + source_id=config.get("source_id"), + ) + + # ------------------------------------------------------------------------- + # Identity — delegate to wrapped source + # ------------------------------------------------------------------------- + + def identity_structure(self) -> Any: + return self._source.identity_structure() + + # ------------------------------------------------------------------------- + # Cache path — scoped to source's content hash + # ------------------------------------------------------------------------- + + @property + def cache_path(self) -> tuple[str, ...]: + """Cache table path, scoped to the source's content hash.""" + if self._explicit_cache_path is not None: + return self._explicit_cache_path + return self._cache_path_prefix + ( + "source", + f"node:{self._source.content_hash().to_string()}", + ) + + # ------------------------------------------------------------------------- + # Stream interface — delegate schema, materialize with cache + # ------------------------------------------------------------------------- + + def output_schema( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: + return self._source.output_schema(columns=columns, all_info=all_info) + + def keys( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[tuple[str, ...], tuple[str, ...]]: + return self._source.keys(columns=columns, all_info=all_info) + + def _ingest_live_data(self) -> None: + """Fetch live data from the source and store new rows in the cache. + + Raises if the source cannot provide data (e.g. an unbound + ``SourceProxy``). + """ + live_table = self._source.as_table( + columns={"source": True, "system_tags": True} + ) + + # Compute per-row record hashes for dedup: hash(full row) + arrow_hasher = self.data_context.arrow_hasher + record_hashes: list[str] = [] + for batch in live_table.to_batches(): + for i in range(len(batch)): + record_hashes.append( + arrow_hasher.hash_table(batch.slice(i, 1)).to_hex() + ) + + # Store in DB with hash as record ID (skip_duplicates deduplicates) + live_with_hash = live_table.add_column( + 0, + self.HASH_COLUMN_NAME, + pa.array(record_hashes, type=pa.large_string()), + ) + self._cache_database.add_records( + self.cache_path, + live_with_hash, + record_id_column=self.HASH_COLUMN_NAME, + skip_duplicates=True, + ) + self._cache_database.flush() + + def _build_merged_stream(self) -> ArrowTableStream: + """Ingest live data (if available), then return all cached records. + + If the inner source cannot provide data (e.g. an unbound + ``SourceProxy``), the method falls back to returning whatever is + already stored in the cache database. If the cache is empty, an + empty stream is returned. + """ + try: + self._ingest_live_data() + except NotImplementedError: + logger.info( + "Inner source %r cannot provide data; serving from cache only.", + self._source.source_id, + ) + + all_records = self._cache_database.get_all_records(self.cache_path) + if all_records is None: + all_records = self._empty_table() + + tag_keys = self._source.keys()[0] + return ArrowTableStream(all_records, tag_columns=tag_keys) + + def _empty_table(self) -> pa.Table: + """Build an empty Arrow table matching the source's output schema.""" + tag_schema, packet_schema = self._source.output_schema() + merged = dict(tag_schema) + merged.update(packet_schema) + type_converter = self.data_context.type_converter + arrow_schema = type_converter.python_schema_to_arrow_schema(merged) + return pa.Table.from_pylist([], schema=arrow_schema) + + @property + def is_stale(self) -> bool: + """True if the wrapped source has been modified since the last build. + + Overrides ``StreamBase.is_stale`` because CachedSource is a RootSource + (no upstreams/producer) yet still depends on the wrapped ``_source``. + """ + own_time = self.last_modified + if own_time is None: + return True + src_time = self._source.last_modified + return src_time is None or src_time > own_time + + def _ensure_stream(self) -> None: + """Build the merged stream on first access, or rebuild if source is stale.""" + if self._cached_stream is not None and self.is_stale: + self._cached_stream = None + if self._cached_stream is None: + self._cached_stream = self._build_merged_stream() + self._update_modified_time() + + def clear_cache(self) -> None: + """Discard in-memory cached stream (forces rebuild on next access).""" + self._cached_stream = None + + def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: + self._ensure_stream() + assert self._cached_stream is not None + return self._cached_stream.iter_packets() + + def as_table( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> "pa.Table": + self._ensure_stream() + assert self._cached_stream is not None + return self._cached_stream.as_table(columns=columns, all_info=all_info) + + def get_all_records( + self, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> "pa.Table | None": + """Retrieve all stored records from the cache database.""" + return self._cache_database.get_all_records(self.cache_path) diff --git a/src/orcapod/core/sources/csv_source.py b/src/orcapod/core/sources/csv_source.py index cafc6c76..632abaf9 100644 --- a/src/orcapod/core/sources/csv_source.py +++ b/src/orcapod/core/sources/csv_source.py @@ -1,66 +1,79 @@ -from typing import TYPE_CHECKING, Any +from __future__ import annotations +from collections.abc import Collection +from typing import TYPE_CHECKING, Any -from orcapod.core.streams import ( - TableStream, -) -from orcapod.protocols import core_protocols as cp -from orcapod.types import PythonSchema +from orcapod.core.sources.base import RootSource +from orcapod.core.sources.stream_builder import SourceStreamBuilder from orcapod.utils.lazy_module import LazyModule if TYPE_CHECKING: - import pandas as pd - import polars as pl import pyarrow as pa else: - pl = LazyModule("polars") - pd = LazyModule("pandas") pa = LazyModule("pyarrow") -from orcapod.core.sources.base import SourceBase +class CSVSource(RootSource): + """A source backed by a CSV file. -class CSVSource(SourceBase): - """Loads data from a CSV file.""" + The file is read once at construction time using PyArrow's CSV reader, + converted to an Arrow table, and enriched by ``SourceStreamBuilder`` + (source-info, schema-hash, system tags). + """ def __init__( self, file_path: str, - tag_columns: list[str] | None = None, + tag_columns: Collection[str] = (), + system_tag_columns: Collection[str] = (), + record_id_column: str | None = None, source_id: str | None = None, - **kwargs, - ): - super().__init__(**kwargs) - self.file_path = file_path - self.tag_columns = tag_columns or [] + **kwargs: Any, + ) -> None: + import pyarrow.csv as pa_csv + if source_id is None: - source_id = self.file_path + source_id = file_path + super().__init__(source_id=source_id, **kwargs) - def source_identity_structure(self) -> Any: - return (self.__class__.__name__, self.source_id, tuple(self.tag_columns)) + self._file_path = file_path + table: pa.Table = pa_csv.read_csv(file_path) - def forward(self, *streams: cp.Stream) -> cp.Stream: - """ - Load data from file and return a static stream. + builder = SourceStreamBuilder(self.data_context, self.orcapod_config) + result = builder.build( + table, + tag_columns=tag_columns, + source_id=self._source_id, + record_id_column=record_id_column, + system_tag_columns=system_tag_columns, + ) - This is called by forward() and creates a fresh snapshot each time. - """ - import pyarrow.csv as csv + self._stream = result.stream + self._tag_columns = result.tag_columns + self._system_tag_columns = result.system_tag_columns + self._record_id_column = record_id_column + if self._source_id is None: + self._source_id = result.source_id - # Load current state of the file - table = csv.read_csv(self.file_path) + def to_config(self) -> dict[str, Any]: + """Serialize this source's configuration to a JSON-compatible dict.""" + return { + "source_type": "csv", + "file_path": self._file_path, + "tag_columns": list(self._tag_columns), + "system_tag_columns": list(self._system_tag_columns), + "record_id_column": self._record_id_column, + "source_id": self.source_id, + **self._identity_config(), + } - return TableStream( - table=table, - tag_columns=self.tag_columns, - source=self, - upstreams=(), + @classmethod + def from_config(cls, config: dict[str, Any]) -> "CSVSource": + """Reconstruct a CSVSource from a config dict.""" + return cls( + file_path=config["file_path"], + tag_columns=config.get("tag_columns", ()), + system_tag_columns=config.get("system_tag_columns", ()), + record_id_column=config.get("record_id_column"), + source_id=config.get("source_id"), ) - - def source_output_types( - self, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - """Infer types from the file (could be cached).""" - # For demonstration - in practice you might cache this - sample_stream = self.forward() - return sample_stream.types(include_system_tags=include_system_tags) diff --git a/src/orcapod/core/sources/data_frame_source.py b/src/orcapod/core/sources/data_frame_source.py index 2fb4a78a..e9148429 100644 --- a/src/orcapod/core/sources/data_frame_source.py +++ b/src/orcapod/core/sources/data_frame_source.py @@ -1,54 +1,47 @@ +from __future__ import annotations + +import logging from collections.abc import Collection from typing import TYPE_CHECKING, Any -from orcapod.core.streams import TableStream -from orcapod.protocols import core_protocols as cp -from orcapod.types import PythonSchema +from orcapod.core.sources.base import RootSource +from orcapod.core.sources.stream_builder import SourceStreamBuilder +from orcapod.utils import polars_data_utils from orcapod.utils.lazy_module import LazyModule -from orcapod.core.system_constants import constants -from orcapod.core import polars_data_utils -from orcapod.core.sources.source_registry import GLOBAL_SOURCE_REGISTRY, SourceRegistry -import logging -from orcapod.core.sources.base import SourceBase if TYPE_CHECKING: - import pyarrow as pa import polars as pl from polars._typing import FrameInitTypes else: - pa = LazyModule("pyarrow") pl = LazyModule("polars") - logger = logging.getLogger(__name__) -class DataFrameSource(SourceBase): - """Construct source from a dataframe and any Polars dataframe compatible data structure""" +class DataFrameSource(RootSource): + """A source backed by a Polars DataFrame (or any Polars-compatible data). - SOURCE_ID = "polars" + The DataFrame is converted to an Arrow table and enriched by + ``SourceStreamBuilder`` (source-info, schema-hash, system tags). + """ def __init__( self, data: "FrameInitTypes", tag_columns: str | Collection[str] = (), - source_name: str | None = None, - source_registry: SourceRegistry | None = None, - auto_register: bool = True, - preserve_system_columns: bool = False, - **kwargs, - ): - super().__init__(**kwargs) - - # clean the table, dropping any system columns - # Initialize polars dataframe - # TODO: work with LazyFrame + system_tag_columns: Collection[str] = (), + source_id: str | None = None, + **kwargs: Any, + ) -> None: + super().__init__(source_id=source_id, **kwargs) + df = pl.DataFrame(data) + # Convert any Object-dtype columns to Arrow-compatible types. object_columns = [c for c in df.columns if df[c].dtype == pl.Object] - if len(object_columns) > 0: + if object_columns: logger.info( - f"Converting {len(object_columns)}object columns to Arrow format" + f"Converting {len(object_columns)} object column(s) to Arrow format" ) sub_table = self.data_context.type_converter.python_dicts_to_arrow_table( df.select(object_columns).to_dicts() @@ -57,97 +50,44 @@ def __init__( if isinstance(tag_columns, str): tag_columns = [tag_columns] + tag_columns = list(tag_columns) - if not preserve_system_columns: - df = polars_data_utils.drop_system_columns(df) - - non_system_columns = polars_data_utils.drop_system_columns(df) - missing_columns = set(tag_columns) - set(non_system_columns.columns) - if missing_columns: - raise ValueError( - f"Following tag columns not found in data: {missing_columns}" - ) - tag_schema = non_system_columns.select(tag_columns).to_arrow().schema - packet_schema = non_system_columns.drop(list(tag_columns)).to_arrow().schema - self.tag_columns = tag_columns - - tag_python_schema = ( - self.data_context.type_converter.arrow_schema_to_python_schema(tag_schema) - ) - packet_python_schema = ( - self.data_context.type_converter.arrow_schema_to_python_schema( - packet_schema - ) - ) - schema_hash = self.data_context.object_hasher.hash_object( - (tag_python_schema, packet_python_schema) - ).to_hex(char_count=self.orcapod_config.schema_hash_n_char) - - self.table_hash = self.data_context.arrow_hasher.hash_table(df.to_arrow()) - - if source_name is None: - # TODO: determine appropriate config name - source_name = self.content_hash().to_hex( - char_count=self.orcapod_config.path_hash_n_char - ) - - self._source_name = source_name + df = polars_data_utils.drop_system_columns(df) - row_index = list(range(df.height)) + missing = set(tag_columns) - set(df.columns) + if missing: + raise ValueError(f"TagProtocol column(s) not found in data: {missing}") - source_info = [ - f"{self.source_id}{constants.BLOCK_SEPARATOR}row_{i}" for i in row_index - ] - - # add source info - df = polars_data_utils.add_source_info( - df, source_info, exclude_columns=tag_columns - ) - - df = polars_data_utils.add_system_tag_column( - df, f"source{constants.FIELD_SEPARATOR}{schema_hash}", source_info - ) - - self._df = df - - self._table_stream = TableStream( - table=self._df.to_arrow(), - tag_columns=self.tag_columns, - source=self, - upstreams=(), + builder = SourceStreamBuilder(self.data_context, self.orcapod_config) + result = builder.build( + df.to_arrow(), + tag_columns=tag_columns, + source_id=self._source_id, + system_tag_columns=system_tag_columns, ) - # Auto-register with global registry - if auto_register: - registry = source_registry or GLOBAL_SOURCE_REGISTRY - registry.register(self.source_id, self) - - @property - def reference(self) -> tuple[str, ...]: - return ("data_frame", f"source_{self._source_name}") - - @property - def df(self) -> "pl.DataFrame": - return self._df - - def source_identity_structure(self) -> Any: - return (self.__class__.__name__, self.tag_columns, self.table_hash) - - def get_all_records( - self, include_system_columns: bool = False - ) -> "pa.Table | None": - return self().as_table(include_source=include_system_columns) - - def forward(self, *streams: cp.Stream) -> cp.Stream: + self._stream = result.stream + self._tag_columns = result.tag_columns + if self._source_id is None: + self._source_id = result.source_id + + def to_config(self) -> dict[str, Any]: + """Serialize metadata-only config (DataFrame is not serializable).""" + return { + "source_type": "data_frame", + "tag_columns": list(self._tag_columns), + "source_id": self.source_id, + **self._identity_config(), + } + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "DataFrameSource": + """Not supported — DataFrameSource cannot be reconstructed from config. + + Raises: + NotImplementedError: Always. """ - Load data from file and return a static stream. - - This is called by forward() and creates a fresh snapshot each time. - """ - return self._table_stream - - def source_output_types( - self, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - """Return tag and packet types based on provided typespecs.""" - return self._table_stream.types(include_system_tags=include_system_tags) + raise NotImplementedError( + "DataFrameSource cannot be reconstructed from config — " + "the original DataFrame is not serializable." + ) diff --git a/src/orcapod/core/sources/delta_table_source.py b/src/orcapod/core/sources/delta_table_source.py index b5c82d77..9deef416 100644 --- a/src/orcapod/core/sources/delta_table_source.py +++ b/src/orcapod/core/sources/delta_table_source.py @@ -1,18 +1,13 @@ +from __future__ import annotations + from collections.abc import Collection +from pathlib import Path from typing import TYPE_CHECKING, Any - -from orcapod.core.streams import TableStream -from orcapod.protocols import core_protocols as cp -from orcapod.types import PathLike, PythonSchema +from orcapod.core.sources.base import RootSource +from orcapod.core.sources.stream_builder import SourceStreamBuilder +from orcapod.types import PathLike from orcapod.utils.lazy_module import LazyModule -from pathlib import Path - - -from orcapod.core.sources.base import SourceBase -from orcapod.core.sources.source_registry import GLOBAL_SOURCE_REGISTRY, SourceRegistry -from deltalake import DeltaTable -from deltalake.exceptions import TableNotFoundError if TYPE_CHECKING: import pyarrow as pa @@ -20,181 +15,76 @@ pa = LazyModule("pyarrow") -class DeltaTableSource(SourceBase): - """Source that generates streams from a Delta table.""" +class DeltaTableSource(RootSource): + """A source backed by a Delta Lake table. + + The table is read once at construction time using ``deltalake``'s + PyArrow integration. The resulting Arrow table is enriched by + ``SourceStreamBuilder`` (source-info, schema-hash, system tags). + """ def __init__( self, delta_table_path: PathLike, tag_columns: Collection[str] = (), - source_name: str | None = None, - source_registry: SourceRegistry | None = None, - auto_register: bool = True, - **kwargs, - ): - """ - Initialize DeltaTableSource with a Delta table. - - Args: - delta_table_path: Path to the Delta table - source_name: Name for this source (auto-generated if None) - tag_columns: Column names to use as tags vs packet data - source_registry: Registry to register with (uses global if None) - auto_register: Whether to auto-register this source - """ - super().__init__(**kwargs) - - # Normalize path - self._delta_table_path = Path(delta_table_path).resolve() - - # Try to open the Delta table - try: - self._delta_table = DeltaTable(str(self._delta_table_path)) - except TableNotFoundError: - raise ValueError(f"Delta table not found at {self._delta_table_path}") - - # Generate source name if not provided - if source_name is None: - source_name = self._delta_table_path.name - - self._source_name = source_name - self._tag_columns = tuple(tag_columns) - self._cached_table_stream: TableStream | None = None - - # Auto-register with global registry - if auto_register: - registry = source_registry or GLOBAL_SOURCE_REGISTRY - registry.register(self.source_id, self) - - @property - def reference(self) -> tuple[str, ...]: - """Reference tuple for this source.""" - return ("delta_table", self._source_name) - - def source_identity_structure(self) -> Any: - """ - Identity structure for this source - includes path and modification info. - This changes when the underlying Delta table changes. - """ - # Get Delta table version for change detection - table_version = self._delta_table.version() + system_tag_columns: Collection[str] = (), + record_id_column: str | None = None, + source_id: str | None = None, + **kwargs: Any, + ) -> None: + from deltalake import DeltaTable + from deltalake.exceptions import TableNotFoundError - return { - "class": self.__class__.__name__, - "path": str(self._delta_table_path), - "version": table_version, - "tag_columns": self._tag_columns, - } + resolved = Path(delta_table_path).resolve() + + if source_id is None: + source_id = resolved.name + super().__init__(source_id=source_id, **kwargs) + + self._delta_table_path = resolved - def validate_inputs(self, *streams: cp.Stream) -> None: - """Delta table sources don't take input streams.""" - if len(streams) > 0: - raise ValueError( - f"DeltaTableSource doesn't accept input streams, got {len(streams)}" - ) - - def source_output_types( - self, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - """Return tag and packet types based on Delta table schema.""" - # Create a sample stream to get types - return self.forward().types(include_system_tags=include_system_tags) - - def forward(self, *streams: cp.Stream) -> cp.Stream: - """ - Generate stream from Delta table data. - - Returns: - TableStream containing all data from the Delta table - """ - if self._cached_table_stream is None: - # Refresh table to get latest data - self._refresh_table() - - # Load table data - table_data = self._delta_table.to_pyarrow_dataset( - as_large_types=True - ).to_table() - - self._cached_table_stream = TableStream( - table=table_data, - tag_columns=self._tag_columns, - source=self, - ) - return self._cached_table_stream - - def _refresh_table(self) -> None: - """Refresh the Delta table to get latest version.""" try: - # Create fresh Delta table instance to get latest data - self._delta_table = DeltaTable(str(self._delta_table_path)) - except Exception as e: - # If refresh fails, log but continue with existing table - import logging + delta_table = DeltaTable(str(resolved)) + except TableNotFoundError: + raise ValueError(f"Delta table not found at {resolved}") - logger = logging.getLogger(__name__) - logger.warning( - f"Failed to refresh Delta table {self._delta_table_path}: {e}" - ) + table: pa.Table = delta_table.to_pyarrow_dataset(as_large_types=True).to_table() - def get_table_info(self) -> dict[str, Any]: - """Get metadata about the Delta table.""" - self._refresh_table() + builder = SourceStreamBuilder(self.data_context, self.orcapod_config) + result = builder.build( + table, + tag_columns=tag_columns, + source_id=self._source_id, + record_id_column=record_id_column, + system_tag_columns=system_tag_columns, + ) - schema = self._delta_table.schema() - history = self._delta_table.history() + self._stream = result.stream + self._tag_columns = result.tag_columns + self._system_tag_columns = result.system_tag_columns + self._record_id_column = record_id_column + if self._source_id is None: + self._source_id = result.source_id + def to_config(self) -> dict[str, Any]: + """Serialize this source's configuration to a JSON-compatible dict.""" return { - "path": str(self._delta_table_path), - "version": self._delta_table.version(), - "schema": schema, - "num_files": len(self._delta_table.files()), - "tag_columns": self._tag_columns, - "latest_commit": history[0] if history else None, + "source_type": "delta_table", + "delta_table_path": str(self._delta_table_path), + "tag_columns": list(self._tag_columns), + "system_tag_columns": list(self._system_tag_columns), + "record_id_column": self._record_id_column, + "source_id": self.source_id, + **self._identity_config(), } - def resolve_field(self, collection_id: str, record_id: str, field_name: str) -> Any: - """ - Resolve a specific field value from source field reference. - - For Delta table sources: - - collection_id: Not used (single table) - - record_id: Row identifier (implementation dependent) - - field_name: Column name - """ - # This is a basic implementation - you might want to add more sophisticated - # record identification based on your needs - - # For now, assume record_id is a row index - try: - row_index = int(record_id) - table_data = self._delta_table.to_pyarrow_dataset( - as_large_types=True - ).to_table() - - if row_index >= table_data.num_rows: - raise ValueError( - f"Record ID {record_id} out of range (table has {table_data.num_rows} rows)" - ) - - if field_name not in table_data.column_names: - raise ValueError( - f"Field '{field_name}' not found in table columns: {table_data.column_names}" - ) - - return table_data[field_name][row_index].as_py() - - except ValueError as e: - if "invalid literal for int()" in str(e): - raise ValueError( - f"Record ID must be numeric for DeltaTableSource, got: {record_id}" - ) - raise - - def __repr__(self) -> str: - return ( - f"DeltaTableSource(path={self._delta_table_path}, name={self._source_name})" + @classmethod + def from_config(cls, config: dict[str, Any]) -> "DeltaTableSource": + """Reconstruct a DeltaTableSource from a config dict.""" + return cls( + delta_table_path=config["delta_table_path"], + tag_columns=config.get("tag_columns", ()), + system_tag_columns=config.get("system_tag_columns", ()), + record_id_column=config.get("record_id_column"), + source_id=config.get("source_id"), ) - - def __str__(self) -> str: - return f"DeltaTableSource:{self._source_name}" diff --git a/src/orcapod/core/sources/derived_source.py b/src/orcapod/core/sources/derived_source.py new file mode 100644 index 00000000..a152c364 --- /dev/null +++ b/src/orcapod/core/sources/derived_source.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from orcapod.core.sources.base import RootSource +from orcapod.core.streams.arrow_table_stream import ArrowTableStream +from orcapod.types import ColumnConfig, Schema +from orcapod.utils.lazy_module import LazyModule + +if TYPE_CHECKING: + import pyarrow as pa + + from orcapod.core.nodes import FunctionNode, OperatorNode +else: + pa = LazyModule("pyarrow") + + +class DerivedSource(RootSource): + """ + A static stream backed by the computed records of a DB-backed stream node. + + Created by ``FunctionNode.as_source()`` or ``OperatorNode.as_source()``, + this source reads from the pipeline database, presenting the computed + results as an immutable stream usable as input to downstream processing. + + The origin must implement ``get_all_records()``, ``output_schema()``, + ``keys()``, and ``content_hash()``. + + Identity + -------- + - ``content_hash``: tied to the specific origin node's content hash — + unique to this exact computation. + - ``pipeline_hash``: inherited from RootSource — schema-only, so multiple + DerivedSources with identical schemas share the same pipeline DB table. + + Usage + ----- + If the origin has not been run yet, the DerivedSource will present an + empty stream (zero rows) with the correct schema. After ``origin.run()``, + it reflects the computed records. + """ + + def __init__( + self, + origin: "FunctionNode | OperatorNode", + source_id: str | None = None, + **kwargs: Any, + ) -> None: + if source_id is None: + origin_hash = origin.content_hash().to_string()[:16] + source_id = f"derived:{origin_hash}" + super().__init__(source_id=source_id, **kwargs) + self._origin = origin + self._cached_table: pa.Table | None = None + + def identity_structure(self) -> Any: + # Tied precisely to the specific node's data identity + return (self._origin.content_hash(),) + + def output_schema( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: + return self._origin.output_schema(columns=columns, all_info=all_info) + + def keys( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[tuple[str, ...], tuple[str, ...]]: + return self._origin.keys(columns=columns, all_info=all_info) + + def _get_stream(self) -> ArrowTableStream: + if self._cached_table is None: + records = self._origin.get_all_records() + if records is None: + # Build empty table with correct schema + tag_schema, packet_schema = self._origin.output_schema() + tag_keys = self._origin.keys()[0] + tc = self.data_context.type_converter + fields = [ + pa.field(k, tc.python_type_to_arrow_type(tag_schema[k])) + for k in tag_keys + ] + fields += [ + pa.field(k, tc.python_type_to_arrow_type(v)) + for k, v in packet_schema.items() + ] + arrow_schema = pa.schema(fields) + self._cached_table = pa.table( + {f.name: pa.array([], type=f.type) for f in arrow_schema}, + schema=arrow_schema, + ) + else: + self._cached_table = records + tag_keys = self._origin.keys()[0] + return ArrowTableStream(self._cached_table, tag_columns=tag_keys) + + def iter_packets(self): + return self._get_stream().iter_packets() + + def as_table( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> "pa.Table": + return self._get_stream().as_table(columns=columns, all_info=all_info) diff --git a/src/orcapod/core/sources/dict_source.py b/src/orcapod/core/sources/dict_source.py index d291b3ff..77f28673 100644 --- a/src/orcapod/core/sources/dict_source.py +++ b/src/orcapod/core/sources/dict_source.py @@ -1,113 +1,67 @@ -from collections.abc import Collection, Mapping -from typing import TYPE_CHECKING, Any - - -from orcapod.protocols import core_protocols as cp -from orcapod.types import DataValue, PythonSchema, PythonSchemaLike -from orcapod.utils.lazy_module import LazyModule -from orcapod.core.system_constants import constants -from orcapod.core.sources.arrow_table_source import ArrowTableSource - -if TYPE_CHECKING: - import pyarrow as pa -else: - pa = LazyModule("pyarrow") - -from orcapod.core.sources.base import SourceBase +from __future__ import annotations +from collections.abc import Collection, Mapping +from typing import Any -def add_source_field( - record: dict[str, DataValue], source_info: str -) -> dict[str, DataValue]: - """Add source information to a record.""" - # for all "regular" fields, add source info - for key in record.keys(): - if not key.startswith(constants.META_PREFIX) and not key.startswith( - constants.DATAGRAM_PREFIX - ): - record[f"{constants.SOURCE_PREFIX}{key}"] = f"{source_info}:{key}" - return record - - -def split_fields_with_prefixes( - record, prefixes: Collection[str] -) -> tuple[dict[str, DataValue], dict[str, DataValue]]: - """Split fields in a record into two dictionaries based on prefixes.""" - matching = {} - non_matching = {} - for key, value in record.items(): - if any(key.startswith(prefix) for prefix in prefixes): - matching[key] = value - else: - non_matching[key] = value - return matching, non_matching - +from orcapod.core.sources.base import RootSource +from orcapod.core.sources.stream_builder import SourceStreamBuilder +from orcapod.types import DataValue, SchemaLike -def split_system_columns( - data: list[dict[str, DataValue]], -) -> tuple[list[dict[str, DataValue]], list[dict[str, DataValue]]]: - system_columns: list[dict[str, DataValue]] = [] - non_system_columns: list[dict[str, DataValue]] = [] - for record in data: - sys_cols, non_sys_cols = split_fields_with_prefixes( - record, [constants.META_PREFIX, constants.DATAGRAM_PREFIX] - ) - system_columns.append(sys_cols) - non_system_columns.append(non_sys_cols) - return system_columns, non_system_columns +class DictSource(RootSource): + """A source backed by a collection of Python dictionaries. -class DictSource(SourceBase): - """Construct source from a collection of dictionaries""" + Each dict becomes one (tag, packet) pair in the stream. The dicts are + converted to an Arrow table via the data-context type converter, then + enriched by ``SourceStreamBuilder`` (source-info, schema-hash, system tags). + """ def __init__( self, data: Collection[Mapping[str, DataValue]], tag_columns: Collection[str] = (), system_tag_columns: Collection[str] = (), - source_name: str | None = None, - data_schema: PythonSchemaLike | None = None, - **kwargs, - ): - super().__init__(**kwargs) + data_schema: SchemaLike | None = None, + source_id: str | None = None, + **kwargs: Any, + ) -> None: + super().__init__(source_id=source_id, **kwargs) + arrow_table = self.data_context.type_converter.python_dicts_to_arrow_table( - [dict(e) for e in data], python_schema=data_schema + [dict(row) for row in data], + python_schema=data_schema, ) - self._table_source = ArrowTableSource( + + builder = SourceStreamBuilder(self.data_context, self.orcapod_config) + result = builder.build( arrow_table, tag_columns=tag_columns, - source_name=source_name, + source_id=self._source_id, system_tag_columns=system_tag_columns, ) - @property - def reference(self) -> tuple[str, ...]: - # TODO: provide more thorough implementation - return ("dict",) + self._table_source.reference[1:] - - def source_identity_structure(self) -> Any: - return self._table_source.source_identity_structure() - - def get_all_records( - self, include_system_columns: bool = False - ) -> "pa.Table | None": - return self._table_source.get_all_records( - include_system_columns=include_system_columns - ) - - def forward(self, *streams: cp.Stream) -> cp.Stream: - """ - Load data from file and return a static stream. - - This is called by forward() and creates a fresh snapshot each time. + self._stream = result.stream + self._tag_columns = result.tag_columns + if self._source_id is None: + self._source_id = result.source_id + + def to_config(self) -> dict[str, Any]: + """Serialize metadata-only config (data is not serializable).""" + return { + "source_type": "dict", + "tag_columns": list(self._tag_columns), + "source_id": self.source_id, + **self._identity_config(), + } + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "DictSource": + """Not supported — DictSource data cannot be reconstructed from config. + + Raises: + NotImplementedError: Always. """ - return self._table_source.forward(*streams) - - def source_output_types( - self, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - """Return tag and packet types based on provided typespecs.""" - # TODO: add system tag - return self._table_source.source_output_types( - include_system_tags=include_system_tags + raise NotImplementedError( + "DictSource cannot be reconstructed from config — " + "original data is not serializable." ) diff --git a/src/orcapod/core/sources/list_source.py b/src/orcapod/core/sources/list_source.py index fdc7ffa0..0350f434 100644 --- a/src/orcapod/core/sources/list_source.py +++ b/src/orcapod/core/sources/list_source.py @@ -1,187 +1,133 @@ -from collections.abc import Callable, Collection, Iterator -from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, cast - -from deltalake import DeltaTable, write_deltalake -from pyarrow.lib import Table - -from orcapod.core.datagrams import DictTag -from orcapod.core.kernels import TrackedKernelBase -from orcapod.core.streams import ( - TableStream, - KernelStream, - StatefulStreamBase, -) -from orcapod.errors import DuplicateTagError -from orcapod.protocols import core_protocols as cp -from orcapod.types import DataValue, PythonSchema -from orcapod.utils import arrow_utils -from orcapod.utils.lazy_module import LazyModule -from orcapod.core.system_constants import constants -from orcapod.semantic_types import infer_python_schema_from_pylist_data +from __future__ import annotations + +from collections.abc import Callable, Collection +from typing import TYPE_CHECKING, Any, Literal + +from orcapod.core.sources.base import RootSource +from orcapod.core.sources.stream_builder import SourceStreamBuilder +from orcapod.protocols.core_protocols import TagProtocol if TYPE_CHECKING: - import pandas as pd - import polars as pl - import pyarrow as pa -else: - pl = LazyModule("polars") - pd = LazyModule("pandas") - pa = LazyModule("pyarrow") + pass -from orcapod.core.sources.base import SourceBase +class ListSource(RootSource): + """A source backed by a Python list. -class ListSource(SourceBase): - """ - A stream source that sources data from a list of elements. - For each element in the list, yields a tuple containing: - - A tag generated either by the provided tag_function or defaulting to the element index - - A packet containing the element under the provided name key - Parameters - ---------- - name : str - The key name under which each list element will be stored in the packet - data : list[Any] - The list of elements to source data from - tag_function : Callable[[Any, int], Tag] | None, default=None - Optional function to generate a tag from a list element and its index. - The function receives the element and the index as arguments. - If None, uses the element index in a dict with key 'element_index' - tag_function_hash_mode : Literal["content", "signature", "name"], default="name" - How to hash the tag function for identity purposes - expected_tag_keys : Collection[str] | None, default=None - Expected tag keys for the stream - label : str | None, default=None - Optional label for the source - Examples - -------- - >>> # Simple list of file names - >>> file_list = ['/path/to/file1.txt', '/path/to/file2.txt', '/path/to/file3.txt'] - >>> source = ListSource('file_path', file_list) - >>> - >>> # Custom tag function using filename stems - >>> from pathlib import Path - >>> source = ListSource( - ... 'file_path', - ... file_list, - ... tag_function=lambda elem, idx: {'file_name': Path(elem).stem} - ... ) - >>> - >>> # List of sample IDs - >>> samples = ['sample_001', 'sample_002', 'sample_003'] - >>> source = ListSource( - ... 'sample_id', - ... samples, - ... tag_function=lambda elem, idx: {'sample': elem} - ... ) + Each element in the list becomes one (tag, packet) pair. The element is + stored as the packet under ``name``; the tag is either the element's index + (default) or the dict returned by ``tag_function(element, index)``. """ @staticmethod - def default_tag_function(element: Any, idx: int) -> cp.Tag: - return DictTag({"element_index": idx}) + def _default_tag(element: Any, idx: int) -> dict[str, Any]: + return {"element_index": idx} def __init__( self, name: str, data: list[Any], - tag_function: Callable[[Any, int], cp.Tag] | None = None, - label: str | None = None, - tag_function_hash_mode: Literal["content", "signature", "name"] = "name", + tag_function: Callable[[Any, int], dict[str, Any] | TagProtocol] | None = None, expected_tag_keys: Collection[str] | None = None, - **kwargs, + tag_function_hash_mode: Literal["content", "signature", "name"] = "name", + source_id: str | None = None, + **kwargs: Any, ) -> None: - super().__init__(label=label, **kwargs) + super().__init__(source_id=source_id, **kwargs) + self.name = name - self.elements = list(data) # Create a copy to avoid external modifications + self._elements = list(data) + self._tag_function_hash_mode = tag_function_hash_mode if tag_function is None: - tag_function = self.__class__.default_tag_function - # If using default tag function and no explicit expected_tag_keys, set to default + tag_function = self.__class__._default_tag if expected_tag_keys is None: expected_tag_keys = ["element_index"] - self.expected_tag_keys = expected_tag_keys - self.tag_function = tag_function - self.tag_function_hash_mode = tag_function_hash_mode - - def forward(self, *streams: SyncStream) -> SyncStream: - if len(streams) != 0: - raise ValueError( - "ListSource does not support forwarding streams. " - "It generates its own stream from the list elements." - ) - - def generator() -> Iterator[tuple[Tag, Packet]]: - for idx, element in enumerate(self.elements): - tag = self.tag_function(element, idx) - packet = {self.name: element} - yield tag, packet - - return SyncStreamFromGenerator(generator) - - def __repr__(self) -> str: - return f"ListSource({self.name}, {len(self.elements)} elements)" - - def identity_structure(self, *streams: SyncStream) -> Any: - hash_function_kwargs = {} - if self.tag_function_hash_mode == "content": - # if using content hash, exclude few - hash_function_kwargs = { - "include_name": False, - "include_module": False, - "include_declaration": False, - } - - tag_function_hash = hash_function( - self.tag_function, - function_hash_mode=self.tag_function_hash_mode, - hash_kwargs=hash_function_kwargs, + self._tag_function = tag_function + self._expected_tag_keys = ( + tuple(expected_tag_keys) if expected_tag_keys is not None else None + ) + + # Hash the tag function for identity purposes. + self._tag_function_hash = self._hash_tag_function() + + # Build rows: each row is tag_fields merged with {name: element}. + rows = [] + for idx, element in enumerate(self._elements): + tag_fields = tag_function(element, idx) + if hasattr(tag_fields, "as_dict"): + tag_fields = tag_fields.as_dict() + row = dict(tag_fields) + row[name] = element + rows.append(row) + + tag_columns = ( + list(self._expected_tag_keys) + if self._expected_tag_keys is not None + else [k for k in (rows[0].keys() if rows else []) if k != name] + ) + + arrow_table = self.data_context.type_converter.python_dicts_to_arrow_table(rows) + + builder = SourceStreamBuilder(self.data_context, self.orcapod_config) + result = builder.build( + arrow_table, + tag_columns=tag_columns, + source_id=self._source_id, ) - # Convert list to hashable representation - # Handle potentially unhashable elements by converting to string + self._stream = result.stream + if self._source_id is None: + self._source_id = result.source_id + + def _hash_tag_function(self) -> str: + """Produce a stable hash string for the tag function.""" + if self._tag_function_hash_mode == "name": + fn = self._tag_function + return f"{fn.__module__}.{fn.__qualname__}" + elif self._tag_function_hash_mode == "signature": + import inspect + + return str(inspect.signature(self._tag_function)) + else: # "content" + import inspect + + src = inspect.getsource(self._tag_function) + return self.data_context.semantic_hasher.hash_object(src).to_hex() + + def identity_structure(self) -> Any: + """Return identity including class name, field name, elements, and tag + function hash. + """ try: - elements_hashable = tuple(self.elements) + elements_repr: Any = tuple(self._elements) except TypeError: - # If elements are not hashable, convert to string representation - elements_hashable = tuple(str(elem) for elem in self.elements) - + elements_repr = tuple(str(e) for e in self._elements) return ( self.__class__.__name__, self.name, - elements_hashable, - tag_function_hash, - ) + tuple(streams) + elements_repr, + self._tag_function_hash, + ) - def keys( - self, *streams: SyncStream, trigger_run: bool = False - ) -> tuple[Collection[str] | None, Collection[str] | None]: + def to_config(self) -> dict[str, Any]: + """Serialize metadata-only config (data is not serializable).""" + return { + "source_type": "list", + "name": self.name, + "source_id": self.source_id, + **self._identity_config(), + } + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "ListSource": + """Not supported — ListSource data cannot be reconstructed from config. + + Raises: + NotImplementedError: Always. """ - Returns the keys of the stream. The keys are the names of the packets - in the stream. The keys are used to identify the packets in the stream. - If expected_keys are provided, they will be used instead of the default keys. - """ - if len(streams) != 0: - raise ValueError( - "ListSource does not support forwarding streams. " - "It generates its own stream from the list elements." - ) - - if self.expected_tag_keys is not None: - return tuple(self.expected_tag_keys), (self.name,) - return super().keys(trigger_run=trigger_run) - - def claims_unique_tags( - self, *streams: "SyncStream", trigger_run: bool = True - ) -> bool | None: - if len(streams) != 0: - raise ValueError( - "ListSource does not support forwarding streams. " - "It generates its own stream from the list elements." - ) - # Claim uniqueness only if the default tag function is used - if self.tag_function == self.__class__.default_tag_function: - return True - # Otherwise, delegate to the base class - return super().claims_unique_tags(trigger_run=trigger_run) + raise NotImplementedError( + "ListSource cannot be reconstructed from config — " + "original list data is not serializable." + ) diff --git a/src/orcapod/core/sources/manual_table_source.py b/src/orcapod/core/sources/manual_table_source.py deleted file mode 100644 index ba365ecc..00000000 --- a/src/orcapod/core/sources/manual_table_source.py +++ /dev/null @@ -1,367 +0,0 @@ -from collections.abc import Collection -from pathlib import Path -from typing import TYPE_CHECKING, Any, cast - -from deltalake import DeltaTable, write_deltalake -from deltalake.exceptions import TableNotFoundError - -from orcapod.core.sources.source_registry import SourceRegistry -from orcapod.core.streams import TableStream -from orcapod.errors import DuplicateTagError -from orcapod.protocols import core_protocols as cp -from orcapod.types import PythonSchema, PythonSchemaLike -from orcapod.utils.lazy_module import LazyModule - -if TYPE_CHECKING: - import pandas as pd - import polars as pl - import pyarrow as pa -else: - pl = LazyModule("polars") - pd = LazyModule("pandas") - pa = LazyModule("pyarrow") - -from orcapod.core.sources.base import SourceBase - - -class ManualDeltaTableSource(SourceBase): - """ - A source that allows manual delta updates to a table. - This is useful for testing and debugging purposes. - - Supports duplicate tag handling: - - skip_duplicates=True: Use merge operation to only insert new tag combinations - - skip_duplicates=False: Raise error if duplicate tags would be created - """ - - def __init__( - self, - table_path: str | Path, - python_schema: PythonSchemaLike | None = None, - tag_columns: Collection[str] | None = None, - source_name: str | None = None, - source_registry: SourceRegistry | None = None, - **kwargs, - ) -> None: - """ - Initialize the ManualDeltaTableSource with a label and optional data context. - """ - super().__init__(**kwargs) - - if source_name is None: - source_name = Path(table_path).name - - self._source_name = source_name - - self.table_path = Path(table_path) - self._delta_table: DeltaTable | None = None - self.load_delta_table() - - if self._delta_table is None: - if python_schema is None: - raise ValueError( - "Delta table not found and no schema provided. " - "Please provide a valid Delta table path or a schema to create a new table." - ) - if tag_columns is None: - raise ValueError( - "At least one tag column must be provided when creating a new Delta table." - ) - arrow_schema = ( - self.data_context.type_converter.python_schema_to_arrow_schema( - python_schema - ) - ) - - fields = [] - for field in arrow_schema: - if field.name in tag_columns: - field = field.with_metadata({b"tag": b"True"}) - fields.append(field) - arrow_schema = pa.schema(fields) - - else: - arrow_schema = pa.schema(self._delta_table.schema().to_arrow()) - python_schema = ( - self.data_context.type_converter.arrow_schema_to_python_schema( - arrow_schema - ) - ) - - inferred_tag_columns = [] - for field in arrow_schema: - if ( - field.metadata is not None - and field.metadata.get(b"tag", b"False").decode().lower() == "true" - ): - inferred_tag_columns.append(field.name) - tag_columns = tag_columns or inferred_tag_columns - - self.python_schema = python_schema - self.arrow_schema = arrow_schema - self.tag_columns = list(tag_columns) if tag_columns else [] - - @property - def reference(self) -> tuple[str, ...]: - return ("manual_delta", self._source_name) - - @property - def delta_table_version(self) -> int | None: - """ - Return the version of the delta table. - If the table does not exist, return None. - """ - if self._delta_table is not None: - return self._delta_table.version() - return None - - def forward(self, *streams: cp.Stream) -> cp.Stream: - """Load current delta table data as a stream.""" - if len(streams) > 0: - raise ValueError("ManualDeltaTableSource takes no input streams") - - if self._delta_table is None: - arrow_data = pa.Table.from_pylist([], schema=self.arrow_schema) - else: - arrow_data = self._delta_table.to_pyarrow_dataset( - as_large_types=True - ).to_table() - - return TableStream( - arrow_data, tag_columns=self.tag_columns, source=self, upstreams=() - ) - - def source_identity_structure(self) -> Any: - """ - Return the identity structure of the kernel. - This is a unique identifier for the kernel based on its class name and table path. - """ - return (self.__class__.__name__, str(self.table_path)) - - def source_output_types( - self, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - """Return tag and packet types based on schema and tag columns.""" - # TODO: auto add system entry tag - tag_types: PythonSchema = {} - packet_types: PythonSchema = {} - for field, field_type in self.python_schema.items(): - if field in self.tag_columns: - tag_types[field] = field_type - else: - packet_types[field] = field_type - return tag_types, packet_types - - def get_all_records(self, include_system_columns: bool = False) -> pa.Table | None: - """Get all records from the delta table.""" - if self._delta_table is None: - return None - - arrow_data = self._delta_table.to_pyarrow_dataset( - as_large_types=True - ).to_table() - - if not include_system_columns: - arrow_data = arrow_data.drop( - [col for col in arrow_data.column_names if col.startswith("_")] - ) - return arrow_data - - def _normalize_data_to_table( - self, data: "dict | pa.Table | pl.DataFrame | pd.DataFrame" - ) -> pa.Table: - """Convert input data to PyArrow Table with correct schema.""" - if isinstance(data, dict): - return pa.Table.from_pylist([data], schema=self.arrow_schema) - elif isinstance(data, pa.Table): - return data - else: - # Handle polars/pandas DataFrames - if hasattr(data, "to_arrow"): # Polars DataFrame - return data.to_arrow() # type: ignore - elif hasattr(data, "to_pandas"): # Polars to pandas fallback - return pa.Table.from_pandas(data.to_pandas(), schema=self.arrow_schema) # type: ignore - else: # Assume pandas DataFrame - return pa.Table.from_pandas( - cast(pd.DataFrame, data), schema=self.arrow_schema - ) - - def _check_for_duplicates(self, new_data: pa.Table) -> None: - """ - Check if new data contains tag combinations that already exist. - Raises DuplicateTagError if duplicates found. - """ - if self._delta_table is None or not self.tag_columns: - return # No existing data or no tag columns to check - - # Get existing tag combinations - existing_data = self._delta_table.to_pyarrow_dataset( - as_large_types=True - ).to_table() - if len(existing_data) == 0: - return # No existing data - - # Extract tag combinations from existing data - existing_tags = existing_data.select(self.tag_columns) - new_tags = new_data.select(self.tag_columns) - - # Convert to sets of tuples for comparison - existing_tag_tuples = set() - for i in range(len(existing_tags)): - tag_tuple = tuple( - existing_tags.column(col)[i].as_py() for col in self.tag_columns - ) - existing_tag_tuples.add(tag_tuple) - - # Check for duplicates in new data - duplicate_tags = [] - for i in range(len(new_tags)): - tag_tuple = tuple( - new_tags.column(col)[i].as_py() for col in self.tag_columns - ) - if tag_tuple in existing_tag_tuples: - duplicate_tags.append(tag_tuple) - - if duplicate_tags: - tag_names = ", ".join(self.tag_columns) - duplicate_strs = [str(tags) for tags in duplicate_tags] - raise DuplicateTagError( - f"Duplicate tag combinations found for columns [{tag_names}]: " - f"{duplicate_strs}. Use skip_duplicates=True to merge instead." - ) - - def _merge_data(self, new_data: pa.Table) -> None: - """ - Merge new data using Delta Lake merge operation. - Only inserts rows where tag combinations don't already exist. - """ - if self._delta_table is None: - # No existing table, just write the data - write_deltalake( - self.table_path, - new_data, - mode="overwrite", - ) - else: - # Use merge operation - only insert if tag combination doesn't exist - # Build merge condition based on tag columns - # Format: "target.col1 = source.col1 AND target.col2 = source.col2" - merge_conditions = " AND ".join( - f"target.{col} = source.{col}" for col in self.tag_columns - ) - - try: - # Use Delta Lake's merge functionality - ( - self._delta_table.merge( - source=new_data, - predicate=merge_conditions, - source_alias="source", - target_alias="target", - ) - .when_not_matched_insert_all() # Insert when no match found - .execute() - ) - except Exception: - # Fallback: manual duplicate filtering if merge fails - self._manual_merge_fallback(new_data) - - def _manual_merge_fallback(self, new_data: pa.Table) -> None: - """ - Fallback merge implementation that manually filters duplicates. - """ - if self._delta_table is None or not self.tag_columns: - write_deltalake(self.table_path, new_data, mode="append") - return - - # Get existing tag combinations - existing_data = self._delta_table.to_pyarrow_dataset( - as_large_types=True - ).to_table() - existing_tags = existing_data.select(self.tag_columns) - - # Create set of existing tag tuples - existing_tag_tuples = set() - for i in range(len(existing_tags)): - tag_tuple = tuple( - existing_tags.column(col)[i].as_py() for col in self.tag_columns - ) - existing_tag_tuples.add(tag_tuple) - - # Filter new data to only include non-duplicate rows - filtered_rows = [] - new_tags = new_data.select(self.tag_columns) - - for i in range(len(new_data)): - tag_tuple = tuple( - new_tags.column(col)[i].as_py() for col in self.tag_columns - ) - if tag_tuple not in existing_tag_tuples: - # Extract this row - row_dict = {} - for col_name in new_data.column_names: - row_dict[col_name] = new_data.column(col_name)[i].as_py() - filtered_rows.append(row_dict) - - # Only append if there are new rows to add - if filtered_rows: - filtered_table = pa.Table.from_pylist( - filtered_rows, schema=self.arrow_schema - ) - write_deltalake(self.table_path, filtered_table, mode="append") - - def insert( - self, - data: "dict | pa.Table | pl.DataFrame | pd.DataFrame", - skip_duplicates: bool = False, - ) -> None: - """ - Insert data into the delta table. - - Args: - data: Data to insert (dict, PyArrow Table, Polars DataFrame, or Pandas DataFrame) - skip_duplicates: If True, use merge operation to skip duplicate tag combinations. - If False, raise error if duplicate tag combinations are found. - - Raises: - DuplicateTagError: If skip_duplicates=False and duplicate tag combinations are found. - """ - # Normalize data to PyArrow Table - new_data_table = self._normalize_data_to_table(data) - - if skip_duplicates: - # Use merge operation to only insert new tag combinations - self._merge_data(new_data_table) - else: - # Check for duplicates first, raise error if found - self._check_for_duplicates(new_data_table) - - # No duplicates found, safe to append - write_deltalake(self.table_path, new_data_table, mode="append") - - # Update our delta table reference and mark as modified - self._set_modified_time() - self._delta_table = DeltaTable(self.table_path) - - # Invalidate any cached streams - self.invalidate() - - def load_delta_table(self) -> None: - """ - Try loading the delta table from the file system. - """ - current_version = self.delta_table_version - try: - delta_table = DeltaTable(self.table_path) - except TableNotFoundError: - delta_table = None - - if delta_table is not None: - new_version = delta_table.version() - if (current_version is None) or ( - current_version is not None and new_version > current_version - ): - # Delta table has been updated - self._set_modified_time() - - self._delta_table = delta_table diff --git a/src/orcapod/core/sources/source_proxy.py b/src/orcapod/core/sources/source_proxy.py new file mode 100644 index 00000000..f541db70 --- /dev/null +++ b/src/orcapod/core/sources/source_proxy.py @@ -0,0 +1,223 @@ +"""Proxy source that preserves identity hashes without requiring live data. + +A ``SourceProxy`` stands in for a source that cannot be reconstructed from +its serialized config (e.g. an in-memory ``ArrowTableSource`` or ``DictSource``). +It returns the same ``content_hash``, ``pipeline_hash``, ``source_id``, and +``output_schema`` as the original, so downstream hash chains remain +consistent. Any attempt to iterate or materialize data raises an error +unless a live source has been bound via :meth:`bind`. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from orcapod.core.sources.base import RootSource +from orcapod.protocols.core_protocols import SourceProtocol +from orcapod.types import ContentHash, Schema + +if TYPE_CHECKING: + from collections.abc import Iterator + + import pyarrow as pa + + from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol + from orcapod.types import ColumnConfig + + +class SourceProxy(RootSource): + """A proxy source that preserves identity and optionally delegates. + + When created without a bound source, ``SourceProxy`` returns stored hashes + and schemas but raises on data access. A live source can be substituted + in later via :meth:`bind` — if the source's identity matches, all data + methods delegate to it. + + Args: + source_id: The original source's canonical ID. + content_hash_str: The original source's content hash string. + pipeline_hash_str: The original source's pipeline hash string. + tag_schema: The original source's tag schema. + packet_schema: The original source's packet schema. + expected_class_name: Class name of the original source (e.g. + ``"ArrowTableSource"``). Informational and used for validation + in :meth:`bind`. + source_config: The original source's serialized config (for to_config). + label: Optional label. + """ + + def __init__( + self, + source_id: str, + content_hash_str: str, + pipeline_hash_str: str, + tag_schema: Schema, + packet_schema: Schema, + expected_class_name: str | None = None, + source_config: dict[str, Any] | None = None, + label: str | None = None, + ) -> None: + super().__init__(source_id=source_id, label=label) + self._content_hash_str = content_hash_str + self._pipeline_hash_str = pipeline_hash_str + self._tag_schema = tag_schema + self._packet_schema = packet_schema + self._expected_class_name = expected_class_name + self._source_config = source_config or {} + self._delegate: SourceProtocol | None = None + + # ------------------------------------------------------------------------- + # Binding a live source + # ------------------------------------------------------------------------- + + @property + def expected_class_name(self) -> str | None: + """The class name of the original source this proxy stands in for.""" + return self._expected_class_name + + @property + def delegate(self) -> SourceProtocol | None: + """The bound live source, or ``None`` if no source has been bound.""" + return self._delegate + + @property + def is_bound(self) -> bool: + """``True`` if a live source has been bound via :meth:`bind`.""" + return self._delegate is not None + + def bind(self, source: SourceProtocol) -> None: + """Bind a live source to this proxy, enabling data access. + + The source must match this proxy's identity — same ``source_id``, + ``content_hash``, ``pipeline_hash``, tag schema keys, and packet + schema keys. If ``expected_class_name`` is set, the source's class + name must also match. + + Args: + source: The live source to bind. + + Raises: + ValueError: If the source's identity does not match. + """ + errors: list[str] = [] + + if source.source_id != self.source_id: + errors.append( + f"source_id mismatch: expected {self.source_id!r}, " + f"got {source.source_id!r}" + ) + + if source.content_hash().to_string() != self._content_hash_str: + errors.append( + f"content_hash mismatch: expected {self._content_hash_str!r}, " + f"got {source.content_hash().to_string()!r}" + ) + + if source.pipeline_hash().to_string() != self._pipeline_hash_str: + errors.append( + f"pipeline_hash mismatch: expected {self._pipeline_hash_str!r}, " + f"got {source.pipeline_hash().to_string()!r}" + ) + + if ( + self._expected_class_name is not None + and source.__class__.__name__ != self._expected_class_name + ): + errors.append( + f"class mismatch: expected {self._expected_class_name!r}, " + f"got {source.__class__.__name__!r}" + ) + + if errors: + raise ValueError(f"Cannot bind source to SourceProxy: {'; '.join(errors)}") + + self._delegate = source + + def unbind(self) -> SourceProtocol | None: + """Remove and return the bound source, reverting to proxy behavior. + + Returns: + The previously bound source, or ``None`` if none was bound. + """ + delegate = self._delegate + self._delegate = None + return delegate + + # ------------------------------------------------------------------------- + # Identity — return stored hashes (always, regardless of delegate) + # ------------------------------------------------------------------------- + + def identity_structure(self) -> Any: + return ContentHash.from_string(self._content_hash_str) + + def pipeline_identity_structure(self) -> Any: + return ContentHash.from_string(self._pipeline_hash_str) + + # ------------------------------------------------------------------------- + # Schema — return stored schemas (always, regardless of delegate) + # ------------------------------------------------------------------------- + + def output_schema( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: + return (self._tag_schema, self._packet_schema) + + def keys( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[tuple[str, ...], tuple[str, ...]]: + return ( + tuple(self._tag_schema.keys()), + tuple(self._packet_schema.keys()), + ) + + # ------------------------------------------------------------------------- + # Data access — delegate if bound, raise otherwise + # ------------------------------------------------------------------------- + + def _require_delegate(self) -> SourceProtocol: + """Return the delegate or raise if not bound.""" + if self._delegate is None: + raise NotImplementedError( + f"SourceProxy({self.source_id!r}) cannot provide data — " + f"the original source ({self._expected_class_name or 'unknown'}) " + f"was not reconstructable from config. " + f"Use bind() to attach a live source." + ) + return self._delegate + + def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: + return self._require_delegate().iter_packets() + + def as_table( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> pa.Table: + return self._require_delegate().as_table(columns=columns, all_info=all_info) + + def resolve_field(self, record_id: str, field_name: str) -> Any: + """Delegate to bound source, or raise.""" + return self._require_delegate().resolve_field(record_id, field_name) + + # ------------------------------------------------------------------------- + # Serialization + # ------------------------------------------------------------------------- + + def to_config(self) -> dict[str, Any]: + """Return the original source's config (preserves source_type).""" + return dict(self._source_config) + + @classmethod + def from_config(cls, config: dict[str, Any]) -> SourceProxy: + """Not supported — SourceProxy is created by the deserialization pipeline.""" + raise NotImplementedError( + "SourceProxy cannot be reconstructed via from_config. " + "It is created automatically when a source cannot be loaded." + ) diff --git a/src/orcapod/core/sources/source_registry.py b/src/orcapod/core/sources/source_registry.py index 66f9bf73..309429f2 100644 --- a/src/orcapod/core/sources/source_registry.py +++ b/src/orcapod/core/sources/source_registry.py @@ -1,232 +1,151 @@ +from __future__ import annotations + import logging from collections.abc import Iterator -from orcapod.protocols.core_protocols import Source +from typing import TYPE_CHECKING, Any +if TYPE_CHECKING: + from orcapod.core.sources.base import RootSource logger = logging.getLogger(__name__) -class SourceCollisionError(Exception): - """Raised when attempting to register a source ID that already exists.""" - - pass - - -class SourceNotFoundError(Exception): - """Raised when attempting to access a source that doesn't exist.""" - - pass - - class SourceRegistry: """ - Registry for managing data sources. - - Provides collision detection, source lookup, and management of source lifecycles. + Registry mapping canonical source IDs to live ``RootSource`` objects. + + A source ID is a stable, human-readable name (e.g. ``"delta_table:sales"``) + that is independent of physical location. The registry lets downstream + code resolve a ``source_id`` token embedded in a provenance string back to + the concrete source object that produced it, enabling ``resolve_field`` + calls without requiring a direct reference to the source object. + + Registration behaviour + ---------------------- + - Registering the **same object** under the same ID is idempotent. + - Registering a **different object** under an already-taken ID logs a + warning and skips (rather than raising), so that sources constructed in + different contexts don't crash each other via the global singleton. + - Use ``replace`` when you explicitly want to overwrite an entry. + + The module-level ``GLOBAL_SOURCE_REGISTRY`` is the default registry used + when no explicit registry is provided. """ - def __init__(self): - self._sources: dict[str, Source] = {} + def __init__(self) -> None: + self._sources: dict[str, "RootSource"] = {} - def register(self, source_id: str, source: Source) -> None: - """ - Register a source with the given ID. + # ------------------------------------------------------------------ + # Registration + # ------------------------------------------------------------------ - Args: - source_id: Unique identifier for the source - source: Source instance to register + def register(self, source_id: str, source: "RootSource") -> None: + """ + Register *source* under *source_id*. - Raises: - SourceCollisionError: If source_id already exists - ValueError: If source_id or source is invalid + If *source_id* is already taken by the same object, the call is a + no-op. If it is taken by a *different* object, a warning is emitted + and the existing entry is left unchanged. """ if not source_id: - raise ValueError("Source ID cannot be empty") - - if not isinstance(source_id, str): - raise ValueError(f"Source ID must be a string, got {type(source_id)}") - + raise ValueError("source_id cannot be empty") if source is None: - raise ValueError("Source cannot be None") + raise ValueError("source cannot be None") - if source_id in self._sources: - existing_source = self._sources[source_id] - if existing_source == source: - # Idempotent - same source already registered + existing = self._sources.get(source_id) + if existing is not None: + if existing is source: logger.debug( - f"Source ID '{source_id}' already registered with the same source instance." + "Source '%s' already registered with the same object — skipping.", + source_id, ) return - raise SourceCollisionError( - f"Source ID '{source_id}' already registered with {type(existing_source).__name__}. " - f"Cannot register {type(source).__name__}. " - f"Choose a different source_id or unregister the existing source first." + logger.warning( + "Source ID '%s' is already registered with a different %s object; " + "keeping the existing registration. Use replace() to overwrite.", + source_id, + type(existing).__name__, ) + return self._sources[source_id] = source - logger.info(f"Registered source: '{source_id}' -> {type(source).__name__}") + logger.debug("Registered source '%s' -> %s", source_id, type(source).__name__) - def get(self, source_id: str) -> Source: + def replace(self, source_id: str, source: "RootSource") -> "RootSource | None": """ - Get a source by ID. - - Args: - source_id: Source identifier - - Returns: - Source instance - - Raises: - SourceNotFoundError: If source doesn't exist + Register *source* under *source_id*, unconditionally replacing any + existing entry. Returns the previous source if one existed. """ - if source_id not in self._sources: - available_ids = list(self._sources.keys()) - raise SourceNotFoundError( - f"Source '{source_id}' not found. Available sources: {available_ids}" + if not source_id: + raise ValueError("source_id cannot be empty") + old = self._sources.get(source_id) + self._sources[source_id] = source + if old is not None and old is not source: + logger.info( + "Replaced source '%s': %s -> %s", + source_id, + type(old).__name__, + type(source).__name__, ) + return old - return self._sources[source_id] - - def get_optional(self, source_id: str) -> Source | None: - """ - Get a source by ID, returning None if not found. - - Args: - source_id: Source identifier - - Returns: - Source instance or None if not found - """ - return self._sources.get(source_id) - - def unregister(self, source_id: str) -> Source: - """ - Unregister a source by ID. - - Args: - source_id: Source identifier - - Returns: - The unregistered source instance - - Raises: - SourceNotFoundError: If source doesn't exist - """ + def unregister(self, source_id: str) -> "RootSource": + """Remove and return the source registered under *source_id*.""" if source_id not in self._sources: - raise SourceNotFoundError(f"Source '{source_id}' not found") - + raise KeyError(f"No source registered under '{source_id}'") source = self._sources.pop(source_id) - logger.info(f"Unregistered source: '{source_id}'") + logger.debug("Unregistered source '%s'", source_id) return source - # TODO: consider just using __contains__ - def contains(self, source_id: str) -> bool: - """Check if a source ID is registered.""" - return source_id in self._sources + # ------------------------------------------------------------------ + # Lookup + # ------------------------------------------------------------------ - def list_sources(self) -> list[str]: - """Get list of all registered source IDs.""" - return list(self._sources.keys()) + def get(self, source_id: str) -> "RootSource": + """Return the source for *source_id*, raising ``KeyError`` if absent.""" + if source_id not in self._sources: + raise KeyError( + f"No source registered under '{source_id}'. " + f"Available: {list(self._sources)}" + ) + return self._sources[source_id] - # TODO: consider removing this - def list_sources_by_type(self, source_type: type) -> list[str]: - """ - Get list of source IDs filtered by source type. + def get_optional(self, source_id: str) -> "RootSource | None": + """Return the source for *source_id*, or ``None`` if not registered.""" + return self._sources.get(source_id) - Args: - source_type: Class type to filter by + # ------------------------------------------------------------------ + # Introspection + # ------------------------------------------------------------------ - Returns: - List of source IDs that match the type - """ - return [ - source_id - for source_id, source in self._sources.items() - if isinstance(source, source_type) - ] + def list_ids(self) -> list[str]: + return list(self._sources) def clear(self) -> None: - """Remove all registered sources.""" count = len(self._sources) self._sources.clear() - logger.info(f"Cleared {count} sources from registry") - - def replace(self, source_id: str, source: Source) -> Source | None: - """ - Replace an existing source or register a new one. - - Args: - source_id: Source identifier - source: New source instance + logger.debug("Cleared %d source(s) from registry", count) - Returns: - Previous source if it existed, None otherwise - """ - old_source = self._sources.get(source_id) - self._sources[source_id] = source + # ------------------------------------------------------------------ + # Dunder helpers + # ------------------------------------------------------------------ - if old_source: - logger.info(f"Replaced source: '{source_id}' -> {type(source).__name__}") - else: - logger.info( - f"Registered new source: '{source_id}' -> {type(source).__name__}" - ) - - return old_source - - def get_source_info(self, source_id: str) -> dict: - """ - Get information about a registered source. - - Args: - source_id: Source identifier - - Returns: - Dictionary with source information - - Raises: - SourceNotFoundError: If source doesn't exist - """ - source = self.get(source_id) # This handles the not found case - - info = { - "source_id": source_id, - "type": type(source).__name__, - "reference": source.reference if hasattr(source, "reference") else None, - } - info["identity"] = source.identity_structure() - - return info + def __contains__(self, source_id: Any) -> bool: + return source_id in self._sources def __len__(self) -> int: - """Return number of registered sources.""" return len(self._sources) - def __contains__(self, source_id: str) -> bool: - """Support 'in' operator for checking source existence.""" - return source_id in self._sources - def __iter__(self) -> Iterator[str]: - """Iterate over source IDs.""" return iter(self._sources) - def items(self) -> Iterator[tuple[str, Source]]: - """Iterate over (source_id, source) pairs.""" + def items(self) -> Iterator[tuple[str, "RootSource"]]: yield from self._sources.items() def __repr__(self) -> str: - return f"SourceRegistry({len(self._sources)} sources)" - - def __str__(self) -> str: - if not self._sources: - return "SourceRegistry(empty)" - - source_summary = [] - for source_id, source in self._sources.items(): - source_summary.append(f" {source_id}: {type(source).__name__}") - - return "SourceRegistry:\n" + "\n".join(source_summary) + return f"SourceRegistry({len(self._sources)} source(s): {list(self._sources)})" -# Global source registry instance -GLOBAL_SOURCE_REGISTRY = SourceRegistry() +# Module-level global singleton — used as the default when no explicit +# registry is passed to DataContextMixin or RootSource. +GLOBAL_SOURCE_REGISTRY: SourceRegistry = SourceRegistry() diff --git a/src/orcapod/core/sources/stream_builder.py b/src/orcapod/core/sources/stream_builder.py new file mode 100644 index 00000000..797ed4ef --- /dev/null +++ b/src/orcapod/core/sources/stream_builder.py @@ -0,0 +1,170 @@ +"""Compositional builder for enriching raw Arrow tables into source streams. + +Extracts the enrichment pipeline that was previously embedded in +``ArrowTableSource.__init__``: dropping system columns, validating tags, +computing schema/table hashes, adding source-info provenance, adding system +tag columns, and wrapping the result in an ``ArrowTableStream``. +""" + +from __future__ import annotations + +from collections.abc import Collection +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from orcapod.core.streams.arrow_table_stream import ArrowTableStream +from orcapod.system_constants import constants +from orcapod.types import ContentHash +from orcapod.utils import arrow_data_utils +from orcapod.utils.lazy_module import LazyModule + +if TYPE_CHECKING: + import pyarrow as pa + + from orcapod.config import Config + from orcapod.contexts import DataContext +else: + pa = LazyModule("pyarrow") + + +def _make_record_id(record_id_column: str | None, row_index: int, row: dict) -> str: + """Build the record-ID token for a single row. + + When *record_id_column* is given the token is ``"{column}={value}"``, + giving a stable, human-readable key that survives row reordering. + When no column is specified the fallback is ``"row_{index}"``. + """ + if record_id_column is not None: + return f"{record_id_column}={row[record_id_column]}" + return f"row_{row_index}" + + +@dataclass(frozen=True) +class SourceStreamResult: + """Artifacts produced by ``SourceStreamBuilder.build()``.""" + + stream: ArrowTableStream + schema_hash: str + table_hash: ContentHash + source_id: str + tag_columns: tuple[str, ...] + system_tag_columns: tuple[str, ...] + + +class SourceStreamBuilder: + """Builds an enriched ``ArrowTableStream`` from a raw Arrow table. + + Args: + data_context: Provides type_converter, semantic_hasher, arrow_hasher. + config: Orcapod config (controls hash character counts). + """ + + def __init__(self, data_context: DataContext, config: Config) -> None: + self._data_context = data_context + self._config = config + + def build( + self, + table: pa.Table, + tag_columns: Collection[str], + source_id: str | None = None, + record_id_column: str | None = None, + system_tag_columns: Collection[str] = (), + ) -> SourceStreamResult: + """Run the full enrichment pipeline. + + Args: + table: Raw Arrow table (system columns will be stripped). + tag_columns: Column names forming the tag for each row. + source_id: Canonical source name. Defaults to table hash. + record_id_column: Column for stable record IDs in provenance. + system_tag_columns: Additional system-level tag columns. + + Returns: + SourceStreamResult with enriched stream and metadata. + + Raises: + ValueError: If tag_columns or record_id_column are not in table. + """ + tag_columns_tuple = tuple(tag_columns) + system_tag_columns_tuple = tuple(system_tag_columns) + + # 1. Drop system columns from raw input. + table = arrow_data_utils.drop_system_columns(table) + + # 2. Validate tag_columns. + missing_tags = set(tag_columns_tuple) - set(table.column_names) + if missing_tags: + raise ValueError( + f"tag_columns not found in table: {missing_tags}. " + f"Available columns: {list(table.column_names)}" + ) + + # 3. Validate record_id_column. + if record_id_column is not None and record_id_column not in table.column_names: + raise ValueError( + f"record_id_column {record_id_column!r} not found in table columns: " + f"{table.column_names}" + ) + + # 4. Compute schema hash from tag/packet python schemas. + non_sys = arrow_data_utils.drop_system_columns(table) + tag_schema = non_sys.select(list(tag_columns_tuple)).schema + packet_schema = non_sys.drop(list(tag_columns_tuple)).schema + tag_python = self._data_context.type_converter.arrow_schema_to_python_schema( + tag_schema + ) + packet_python = self._data_context.type_converter.arrow_schema_to_python_schema( + packet_schema + ) + schema_hash = self._data_context.semantic_hasher.hash_object( + (tag_python, packet_python) + ).to_hex(char_count=self._config.schema_hash_n_char) + + # 5. Compute table hash for data identity. + table_hash = self._data_context.arrow_hasher.hash_table(table) + + # 6. Default source_id to table hash. + if source_id is None: + source_id = table_hash.to_hex(char_count=self._config.path_hash_n_char) + + # 7. Build per-row source-info strings. + rows_as_dicts = table.to_pylist() + source_info = [ + f"{source_id}{constants.BLOCK_SEPARATOR}" + f"{_make_record_id(record_id_column, i, row)}" + for i, row in enumerate(rows_as_dicts) + ] + + # 8. Add source-info provenance columns. + table = arrow_data_utils.add_source_info( + table, source_info, exclude_columns=tag_columns_tuple + ) + + # 9. Add system tag columns. + record_id_values = [ + _make_record_id(record_id_column, i, row) + for i, row in enumerate(rows_as_dicts) + ] + table = arrow_data_utils.add_system_tag_columns( + table, + schema_hash, + source_id, + record_id_values, + ) + + # 10. Wrap in ArrowTableStream. + stream = ArrowTableStream( + table=table, + tag_columns=tag_columns_tuple, + system_tag_columns=system_tag_columns_tuple, + ) + + return SourceStreamResult( + stream=stream, + schema_hash=schema_hash, + table_hash=table_hash, + source_id=source_id, + tag_columns=tag_columns_tuple, + system_tag_columns=system_tag_columns_tuple, + ) diff --git a/src/orcapod/core/streams/__init__.py b/src/orcapod/core/streams/__init__.py index 9f1d6258..752c876b 100644 --- a/src/orcapod/core/streams/__init__.py +++ b/src/orcapod/core/streams/__init__.py @@ -1,18 +1,7 @@ -from .base import StatefulStreamBase -from .kernel_stream import KernelStream -from .table_stream import TableStream -from .lazy_pod_stream import LazyPodResultStream -from .cached_pod_stream import CachedPodStream -from .wrapped_stream import WrappedStream -from .pod_node_stream import PodNodeStream - +from orcapod.core.streams.arrow_table_stream import ArrowTableStream +from orcapod.core.streams.base import StreamBase __all__ = [ - "StatefulStreamBase", - "KernelStream", - "TableStream", - "LazyPodResultStream", - "CachedPodStream", - "WrappedStream", - "PodNodeStream", + "ArrowTableStream", + "StreamBase", ] diff --git a/src/orcapod/core/streams/table_stream.py b/src/orcapod/core/streams/arrow_table_stream.py similarity index 68% rename from src/orcapod/core/streams/table_stream.py rename to src/orcapod/core/streams/arrow_table_stream.py index 9df62894..361b0035 100644 --- a/src/orcapod/core/streams/table_stream.py +++ b/src/orcapod/core/streams/arrow_table_stream.py @@ -4,33 +4,25 @@ from typing import TYPE_CHECKING, Any, cast from orcapod import contexts -from orcapod.core.datagrams import ( - ArrowPacket, - ArrowTag, - DictTag, -) -from orcapod.core.system_constants import constants -from orcapod.protocols import core_protocols as cp -from orcapod.types import PythonSchema +from orcapod.core.datagrams import Packet, Tag +from orcapod.core.streams.base import StreamBase +from orcapod.protocols.core_protocols import PodProtocol, StreamProtocol, TagProtocol +from orcapod.protocols.hashing_protocols import PipelineElementProtocol +from orcapod.system_constants import constants +from orcapod.types import ColumnConfig, Schema from orcapod.utils import arrow_utils from orcapod.utils.lazy_module import LazyModule -from orcapod.core.streams.base import ImmutableStream if TYPE_CHECKING: import pyarrow as pa - import pyarrow.compute as pc - import polars as pl - import pandas as pd else: pa = LazyModule("pyarrow") - pc = LazyModule("pyarrow.compute") - pl = LazyModule("polars") - pd = LazyModule("pandas") + logger = logging.getLogger(__name__) -class TableStream(ImmutableStream): +class ArrowTableStream(StreamBase): """ An immutable stream based on a PyArrow Table. This stream is designed to be used with data that is already in a tabular format, @@ -48,11 +40,14 @@ def __init__( tag_columns: Collection[str] = (), system_tag_columns: Collection[str] = (), source_info: dict[str, str | None] | None = None, - source: cp.Kernel | None = None, - upstreams: tuple[cp.Stream, ...] = (), + producer: PodProtocol | None = None, + upstreams: tuple[StreamProtocol, ...] = (), **kwargs, ) -> None: - super().__init__(source=source, upstreams=upstreams, **kwargs) + super().__init__(**kwargs) + + self._producer = producer + self._upstreams = upstreams data_table, data_context_table = arrow_utils.split_by_column_groups( table, [constants.CONTEXT_KEY] @@ -132,58 +127,67 @@ def __init__( self._system_tag_schema = system_tag_schema self._all_tag_schema = all_tag_schema self._packet_schema = packet_schema - # self._tag_converter = SemanticConverter.from_semantic_schema( - # schemas.SemanticSchema.from_arrow_schema( - # tag_schema, self._data_context.semantic_type_registry - # ) - # ) - # self._packet_converter = SemanticConverter.from_semantic_schema( - # schemas.SemanticSchema.from_arrow_schema( - # packet_schema, self._data_context.semantic_type_registry - # ) - # ) - - self._cached_elements: list[tuple[cp.Tag, ArrowPacket]] | None = None - self._set_modified_time() # set modified time to now - - def data_content_identity_structure(self) -> Any: - """ - Returns a hash of the content of the stream. - This is used to identify the content of the stream. - """ - table_hash = self.data_context.arrow_hasher.hash_table( - self.as_table( - include_data_context=True, include_source=True, include_system_tags=True - ), - ) + + self._cached_elements: list[tuple[TagProtocol, Packet]] | None = None + self._update_modified_time() # set modified time to now + + def identity_structure(self) -> Any: + if self._producer is not None: + return super().identity_structure() return ( self.__class__.__name__, - table_hash, + self.as_table(all_info=True), self._tag_columns, ) + def pipeline_identity_structure(self) -> Any: + if self._producer is None or not isinstance( + self._producer, PipelineElementProtocol + ): + tag_schema, packet_schema = self.output_schema() + return (tag_schema, packet_schema) + return super().pipeline_identity_structure() + + @property + def producer(self) -> PodProtocol | None: + return self._producer + + @property + def upstreams(self) -> tuple[StreamProtocol, ...]: + return self._upstreams + def keys( - self, include_system_tags: bool = False + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, ) -> tuple[tuple[str, ...], tuple[str, ...]]: """ Returns the keys of the tag and packet columns in the stream. This is useful for accessing the columns in the stream. """ tag_columns = self._tag_columns - if include_system_tags: + columns_config = ColumnConfig.handle_config(columns, all_info=all_info) + # TODO: add standard parsing of columns + if columns_config.system_tags: tag_columns += self._system_tag_columns return tag_columns, self._packet_columns - def types( - self, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: + def output_schema( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: """ Returns the types of the tag and packet columns in the stream. This is useful for accessing the types of the columns in the stream. """ + # normalize column config + columns_config = ColumnConfig.handle_config(columns, all_info=all_info) # TODO: consider using MappingProxyType to avoid copying the dicts converter = self.data_context.type_converter - if include_system_tags: + if columns_config.system_tags: tag_schema = self._all_tag_schema else: tag_schema = self._tag_schema @@ -194,24 +198,21 @@ def types( def as_table( self, - include_data_context: bool = False, - include_source: bool = False, - include_system_tags: bool = False, - include_content_hash: bool | str = False, - sort_by_tags: bool = True, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, ) -> "pa.Table": """ Returns the underlying table representation of the stream. This is useful for converting the stream to a table format. """ + columns_config = ColumnConfig.handle_config(columns, all_info=all_info) output_table = self._table - if include_content_hash: + if columns_config.content_hash: hash_column_name = ( "_content_hash" - if include_content_hash is True - else include_content_hash + if columns_config.content_hash is True + else columns_config.content_hash ) content_hashes = [ str(packet.content_hash()) for _, packet in self.iter_packets() @@ -219,22 +220,24 @@ def as_table( output_table = output_table.append_column( hash_column_name, pa.array(content_hashes, type=pa.large_string()) ) - if not include_system_tags: + if not columns_config.system_tags: # Check in original implementation output_table = output_table.drop_columns(list(self._system_tag_columns)) table_stack = (output_table,) - if include_data_context: + if columns_config.context: table_stack += (self._data_context_table,) - if include_source: + if columns_config.source: table_stack += (self._source_info_table,) table = arrow_utils.hstack_tables(*table_stack) - if sort_by_tags: + if columns_config.sort_by_tags: # TODO: cleanup the sorting tag selection logic try: target_tags = ( - self._all_tag_columns if include_system_tags else self._tag_columns + self._all_tag_columns + if columns_config.system_tags + else self._tag_columns ) return table.sort_by([(column, "ascending") for column in target_tags]) except pa.ArrowTypeError: @@ -250,14 +253,10 @@ def clear_cache(self) -> None: """ self._cached_elements = None - def iter_packets( - self, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> Iterator[tuple[cp.Tag, ArrowPacket]]: + def iter_packets(self) -> Iterator[tuple[TagProtocol, Packet]]: """ Iterates over the packets in the stream. - Each packet is represented as a tuple of (Tag, Packet). + Each packet is represented as a tuple of (TagProtocol, PacketProtocol). """ # TODO: make it work with table batch stream if self._cached_elements is None: @@ -267,7 +266,7 @@ def iter_packets( tags = self._table.select(self._all_tag_columns) tag_batches = tags.to_batches() else: - tag_batches = repeat(DictTag({})) + tag_batches = repeat(Tag({})) # TODO: come back and clean up this logic @@ -276,15 +275,15 @@ def iter_packets( for tag_batch, packet_batch in zip(tag_batches, packets.to_batches()): for i in range(len(packet_batch)): if tag_present: - tag = ArrowTag( + tag = Tag( tag_batch.slice(i, 1), # type: ignore data_context=self.data_context, ) else: - tag = cast(DictTag, tag_batch) + tag = cast(Tag, tag_batch) - packet = ArrowPacket( + packet = Packet( packet_batch.slice(i, 1), source_info=self._source_info_table.slice(i, 1).to_pylist()[0], data_context=self.data_context, @@ -297,34 +296,6 @@ def iter_packets( else: yield from self._cached_elements - def run( - self, - *args: Any, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - **kwargs: Any, - ) -> None: - """ - Runs the stream, which in this case is a no-op since the stream is immutable. - This is typically used to trigger any upstream computation of the stream. - """ - # No-op for immutable streams - pass - - async def run_async( - self, - *args: Any, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - **kwargs: Any, - ) -> None: - """ - Runs the stream asynchronously, which in this case is a no-op since the stream is immutable. - This is typically used to trigger any upstream computation of the stream. - """ - # No-op for immutable streams - pass - def __repr__(self) -> str: return ( f"{self.__class__.__name__}(table={self._table.column_names}, " diff --git a/src/orcapod/core/streams/base.py b/src/orcapod/core/streams/base.py index 2959cf3a..acf8d90c 100644 --- a/src/orcapod/core/streams/base.py +++ b/src/orcapod/core/streams/base.py @@ -1,25 +1,27 @@ -from calendar import c +from __future__ import annotations + import logging from abc import abstractmethod -from collections.abc import Collection, Iterator, Mapping -from datetime import datetime, timezone +from collections.abc import AsyncIterator, Collection, Iterator, Mapping +from datetime import datetime from typing import TYPE_CHECKING, Any -from orcapod import contexts -from orcapod.core.base import LabeledContentIdentifiableBase -from orcapod.protocols import core_protocols as cp -from orcapod.types import PythonSchema +from orcapod.core.base import TraceableBase +from orcapod.protocols.core_protocols import ( + PacketProtocol, + PodProtocol, + StreamProtocol, + TagProtocol, +) +from orcapod.types import ColumnConfig, Schema from orcapod.utils.lazy_module import LazyModule - if TYPE_CHECKING: - import pyarrow as pa - import pyarrow.compute as pc - import polars as pl import pandas as pd + import polars as pl + import pyarrow as pa else: pa = LazyModule("pyarrow") - pc = LazyModule("pyarrow.compute") pl = LazyModule("polars") pd = LazyModule("pandas") @@ -30,63 +32,112 @@ logger = logging.getLogger(__name__) -class OperatorStreamBaseMixin: - def join(self, other_stream: cp.Stream, label: str | None = None) -> cp.Stream: +class StreamBase(TraceableBase): + @property + @abstractmethod + def producer(self) -> PodProtocol | None: ... + + @property + @abstractmethod + def upstreams(self) -> tuple[StreamProtocol, ...]: ... + + def identity_structure(self) -> Any: + if self.producer is not None: + return (self.producer, self.producer.argument_symmetry(self.upstreams)) + + raise NotImplementedError("StreamBase.identity_structure") + + def pipeline_identity_structure(self) -> Any: + return self.identity_structure() + + @property + def is_stale(self) -> bool: + """ + True if any upstream stream or the source pod has a ``last_modified`` + timestamp strictly newer than this stream's own ``last_modified``, + indicating that any in-memory cached content should be discarded and + repopulated. + + Semantics: + - A ``None`` timestamp on *this* stream means "content not yet + established" → always stale. + - A ``None`` timestamp on an upstream or source means "modification + time unknown" → conservatively treat as stale. + - Immutable streams with no upstreams and no source (e.g. + ``ArrowTableStream``) always return ``False``. + """ + own_time: datetime | None = self.last_modified + if own_time is None: + return True + candidates: list[datetime | None] = [s.last_modified for s in self.upstreams] + if self.producer is not None: + candidates.append(self.producer.last_modified) + return any(t is None or t > own_time for t in candidates) + + def computed_label(self) -> str | None: + if self.producer is not None: + # use the invocation operation label + return self.producer.label + return None + + def join( + self, other_stream: StreamProtocol, label: str | None = None + ) -> StreamBase: """ Joins this stream with another stream, returning a new stream that contains the combined data from both streams. """ from orcapod.core.operators import Join - return Join()(self, other_stream, label=label) # type: ignore + return Join()(self, other_stream, label=label) def semi_join( self, - other_stream: cp.Stream, + other_stream: StreamProtocol, label: str | None = None, - ) -> cp.Stream: + ) -> StreamBase: """ Performs a semi-join with another stream, returning a new stream that contains only the packets from this stream that have matching tags in the other stream. """ from orcapod.core.operators import SemiJoin - return SemiJoin()(self, other_stream, label=label) # type: ignore + return SemiJoin()(self, other_stream, label=label) def map_tags( self, name_map: Mapping[str, str], drop_unmapped: bool = True, label: str | None = None, - ) -> cp.Stream: + ) -> StreamBase: """ Maps the tags in this stream according to the provided name_map. If drop_unmapped is True, any tags that are not in the name_map will be dropped. """ from orcapod.core.operators import MapTags - return MapTags(name_map, drop_unmapped)(self, label=label) # type: ignore + return MapTags(name_map, drop_unmapped)(self, label=label) def map_packets( self, name_map: Mapping[str, str], drop_unmapped: bool = True, label: str | None = None, - ) -> cp.Stream: + ) -> StreamBase: """ Maps the packets in this stream according to the provided packet_map. If drop_unmapped is True, any packets that are not in the packet_map will be dropped. """ from orcapod.core.operators import MapPackets - return MapPackets(name_map, drop_unmapped)(self, label=label) # type: ignore + return MapPackets(name_map, drop_unmapped)(self, label=label) def batch( - self: cp.Stream, + self, batch_size: int = 0, drop_partial_batch: bool = False, label: str | None = None, - ) -> cp.Stream: + ) -> StreamBase: """ Batch stream into fixed-size chunks, each of size batch_size. If drop_last is True, any remaining elements that don't fit into a full batch will be dropped. @@ -95,15 +146,15 @@ def batch( return Batch(batch_size=batch_size, drop_partial_batch=drop_partial_batch)( self, label=label - ) # type: ignore + ) def polars_filter( - self: cp.Stream, + self, *predicates: Any, constraint_map: Mapping[str, Any] | None = None, label: str | None = None, **constraints: Any, - ) -> cp.Stream: + ) -> StreamBase: from orcapod.core.operators import PolarsFilter total_constraints = dict(constraint_map) if constraint_map is not None else {} @@ -115,11 +166,11 @@ def polars_filter( ) def select_tag_columns( - self: cp.Stream, + self, tag_columns: str | Collection[str], strict: bool = True, label: str | None = None, - ) -> cp.Stream: + ) -> StreamBase: """ Select the specified tag columns from the stream. A ValueError is raised if one or more specified tag columns do not exist in the stream unless strict = False. @@ -129,11 +180,11 @@ def select_tag_columns( return SelectTagColumns(tag_columns, strict=strict)(self, label=label) def select_packet_columns( - self: cp.Stream, + self, packet_columns: str | Collection[str], strict: bool = True, label: str | None = None, - ) -> cp.Stream: + ) -> StreamBase: """ Select the specified packet columns from the stream. A ValueError is raised if one or more specified packet columns do not exist in the stream unless strict = False. @@ -143,313 +194,127 @@ def select_packet_columns( return SelectPacketColumns(packet_columns, strict=strict)(self, label=label) def drop_tag_columns( - self: cp.Stream, + self, tag_columns: str | Collection[str], strict: bool = True, label: str | None = None, - ) -> cp.Stream: + ) -> StreamBase: from orcapod.core.operators import DropTagColumns return DropTagColumns(tag_columns, strict=strict)(self, label=label) def drop_packet_columns( - self: cp.Stream, + self, packet_columns: str | Collection[str], strict: bool = True, label: str | None = None, - ) -> cp.Stream: + ) -> StreamBase: from orcapod.core.operators import DropPacketColumns return DropPacketColumns(packet_columns, strict=strict)(self, label=label) - -class StatefulStreamBase(OperatorStreamBaseMixin, LabeledContentIdentifiableBase): - """ - A stream that has a unique identity within the pipeline. - """ - - def pop(self) -> cp.Stream: - return self - - def __init__( - self, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self._last_modified: datetime | None = None - self._set_modified_time() - # note that this is not necessary for Stream protocol, but is provided - # for convenience to resolve semantic types and other context-specific information - self._execution_engine = execution_engine - self._execution_engine_opts = execution_engine_opts - - @property - def substream_identities(self) -> tuple[str, ...]: - """ - Returns the identities of the substreams that this stream is composed of. - This is used to identify the substreams in the computational graph. - """ - return (self.content_hash().to_hex(),) - - @property - def execution_engine(self) -> cp.ExecutionEngine | None: - """ - Returns the execution engine that is used to execute this stream. - This is typically used to track the execution context of the stream. - """ - return self._execution_engine - - @execution_engine.setter - def execution_engine(self, engine: cp.ExecutionEngine | None) -> None: - """ - Sets the execution engine for the stream. - This is typically used to track the execution context of the stream. - """ - self._execution_engine = engine - - # TODO: add getter/setter for execution engine opts - - def get_substream(self, substream_id: str) -> cp.Stream: - """ - Returns the substream with the given substream_id. - This is used to retrieve a specific substream from the stream. - """ - if substream_id == self.substream_identities[0]: - return self - else: - raise ValueError(f"Substream with ID {substream_id} not found.") - - @property - @abstractmethod - def source(self) -> cp.Kernel | None: - """ - The source of the stream, which is the kernel that generated the stream. - This is typically used to track the origin of the stream in the computational graph. - """ - ... - - @property - @abstractmethod - def upstreams(self) -> tuple[cp.Stream, ...]: - """ - The upstream streams that are used to generate this stream. - This is typically used to track the origin of the stream in the computational graph. - """ - ... - - def computed_label(self) -> str | None: - if self.source is not None: - # use the invocation operation label - return self.source.label - return None - @abstractmethod def keys( - self, include_system_tags: bool = False + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, ) -> tuple[tuple[str, ...], tuple[str, ...]]: ... - def tag_keys(self, include_system_tags: bool = False) -> tuple[str, ...]: - return self.keys(include_system_tags=include_system_tags)[0] - - def packet_keys(self) -> tuple[str, ...]: - return self.keys()[1] - @abstractmethod - def types( - self, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: ... - - def tag_types(self, include_system_tags: bool = False) -> PythonSchema: - return self.types(include_system_tags=include_system_tags)[0] - - def packet_types(self) -> PythonSchema: - return self.types()[1] - - @property - def last_modified(self) -> datetime | None: - """ - Returns when the stream's content was last modified. - This is used to track the time when the stream was last accessed. - Returns None if the stream has not been accessed yet. - """ - return self._last_modified - - @property - def is_current(self) -> bool: - """ - Returns whether the stream is current. - A stream is current if the content is up-to-date with respect to its source. - This can be used to determine if a stream with non-None last_modified is up-to-date. - Note that for asynchronous streams, this status is not applicable and always returns False. - """ - if self.last_modified is None: - # If there is no last_modified timestamp, we cannot determine if the stream is current - return False - - # check if the source kernel has been modified - if self.source is not None and ( - self.source.last_modified is None - or self.source.last_modified > self.last_modified - ): - return False - - # check if all upstreams are current - for upstream in self.upstreams: - if ( - not upstream.is_current - or upstream.last_modified is None - or upstream.last_modified > self.last_modified - ): - return False - return True - - def _set_modified_time( - self, timestamp: datetime | None = None, invalidate: bool = False - ) -> None: - if invalidate: - self._last_modified = None - return - - if timestamp is not None: - self._last_modified = timestamp - else: - self._last_modified = datetime.now(timezone.utc) + def output_schema( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: ... def __iter__( self, - ) -> Iterator[tuple[cp.Tag, cp.Packet]]: + ) -> Iterator[tuple[TagProtocol, PacketProtocol]]: return self.iter_packets() @abstractmethod def iter_packets( self, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> Iterator[tuple[cp.Tag, cp.Packet]]: ... + ) -> Iterator[tuple[TagProtocol, PacketProtocol]]: ... - @abstractmethod - def run( + async def async_iter_packets( self, - *args: Any, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - **kwargs: Any, - ) -> None: ... + ) -> AsyncIterator[tuple[TagProtocol, PacketProtocol]]: + """Async iterator over (tag, packet) pairs. - @abstractmethod - async def run_async( - self, - *args: Any, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - **kwargs: Any, - ) -> None: ... + Subclasses should override this to provide true async iteration. + """ + raise NotImplementedError( + f"{type(self).__name__} does not implement async_iter_packets" + ) + # Make this an async generator so the return type is correct + yield # pragma: no cover @abstractmethod def as_table( self, - include_data_context: bool = False, - include_source: bool = False, - include_system_tags: bool = False, - include_content_hash: bool | str = False, - sort_by_tags: bool = True, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, ) -> "pa.Table": ... def as_polars_df( self, - include_data_context: bool = False, - include_source: bool = False, - include_system_tags: bool = False, - include_content_hash: bool | str = False, - sort_by_tags: bool = True, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, ) -> "pl.DataFrame": """ Convert the entire stream to a Polars DataFrame. """ return pl.DataFrame( self.as_table( - include_data_context=include_data_context, - include_source=include_source, - include_system_tags=include_system_tags, - include_content_hash=include_content_hash, - sort_by_tags=sort_by_tags, - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, + columns=columns, + all_info=all_info, ) ) def as_df( self, - include_data_context: bool = False, - include_source: bool = False, - include_system_tags: bool = False, - include_content_hash: bool | str = False, - sort_by_tags: bool = True, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, ) -> "pl.DataFrame": """ Convert the entire stream to a Polars DataFrame. """ return self.as_polars_df( - include_data_context=include_data_context, - include_source=include_source, - include_system_tags=include_system_tags, - include_content_hash=include_content_hash, - sort_by_tags=sort_by_tags, - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, + columns=columns, + all_info=all_info, ) def as_lazy_frame( self, - include_data_context: bool = False, - include_source: bool = False, - include_system_tags: bool = False, - include_content_hash: bool | str = False, - sort_by_tags: bool = True, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, ) -> "pl.LazyFrame": """ Convert the entire stream to a Polars LazyFrame. """ df = self.as_polars_df( - include_data_context=include_data_context, - include_source=include_source, - include_system_tags=include_system_tags, - include_content_hash=include_content_hash, - sort_by_tags=sort_by_tags, - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, + columns=columns, + all_info=all_info, ) return df.lazy() def as_pandas_df( self, - include_data_context: bool = False, - include_source: bool = False, - include_system_tags: bool = False, - include_content_hash: bool | str = False, - sort_by_tags: bool = True, - index_by_tags: bool = True, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + index_by_tags: bool = False, + all_info: bool = False, ) -> "pd.DataFrame": df = self.as_polars_df( - include_data_context=include_data_context, - include_source=include_source, - include_system_tags=include_system_tags, - include_content_hash=include_content_hash, - sort_by_tags=sort_by_tags, - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, + columns=columns, + all_info=all_info, ) tag_keys, _ = self.keys() pdf = df.to_pandas() @@ -459,58 +324,43 @@ def as_pandas_df( def flow( self, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> Collection[tuple[cp.Tag, cp.Packet]]: + ) -> Collection[tuple[TagProtocol, PacketProtocol]]: """ Flow everything through the stream, returning the entire collection of - (Tag, Packet) as a collection. This will tigger any upstream computation of the stream. + (TagProtocol, PacketProtocol) as a collection. This will tigger any upstream computation of the stream. """ - return [ - e - for e in self.iter_packets( - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, - ) - ] + return [e for e in self.iter_packets()] def _repr_html_(self) -> str: df = self.as_polars_df() # reorder columns - new_column_order = [c for c in df.columns if c in self.tag_keys()] + [c for c in df.columns if c not in self.tag_keys()] + new_column_order = [c for c in df.columns if c in self.keys()[0]] + [ + c for c in df.columns if c not in self.keys()[0] + ] df = df[new_column_order] - tag_map = {t: f"*{t}" for t in self.tag_keys()} + tag_map = {t: f"*{t}" for t in self.keys()[0]} # TODO: construct repr html better df = df.rename(tag_map) return f"{self.__class__.__name__}[{self.label}]\n" + df._repr_html_() def view( self, - include_data_context: bool = False, - include_source: bool = False, - include_system_tags: bool = False, - include_content_hash: bool | str = False, - sort_by_tags: bool = True, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, ) -> "StreamView": df = self.as_polars_df( - include_data_context=include_data_context, - include_source=include_source, - include_system_tags=include_system_tags, - include_content_hash=include_content_hash, - sort_by_tags=sort_by_tags, - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, + columns=columns, + all_info=all_info, ) - tag_map = {t: f"*{t}" for t in self.tag_keys()} + tag_map = {t: f"*{t}" for t in self.keys()[0]} # TODO: construct repr html better df = df.rename(tag_map) return StreamView(self, df) class StreamView: - def __init__(self, stream: StatefulStreamBase, view_df: "pl.DataFrame") -> None: + def __init__(self, stream: StreamBase, view_df: "pl.DataFrame") -> None: self._stream = stream self._view_df = view_df @@ -519,130 +369,3 @@ def _repr_html_(self) -> str: f"{self._stream.__class__.__name__}[{self._stream.label}]\n" + self._view_df._repr_html_() ) - - # def identity_structure(self) -> Any: - # """ - # Identity structure of a stream is deferred to the identity structure - # of the associated invocation, if present. - # A bare stream without invocation has no well-defined identity structure. - # Specialized stream subclasses should override this method to provide more meaningful identity structure - # """ - # ... - - -class StreamBase(StatefulStreamBase): - """ - A stream is a collection of tagged-packets that are generated by an operation. - The stream is iterable and can be used to access the packets in the stream. - - A stream has property `invocation` that is an instance of Invocation that generated the stream. - This may be None if the stream is not generated by a kernel (i.e. directly instantiated by a user). - """ - - def __init__( - self, - source: cp.Kernel | None = None, - upstreams: tuple[cp.Stream, ...] = (), - data_context: str | contexts.DataContext | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self._source = source - self._upstreams = upstreams - - # if data context is not provided, use that of the source kernel - if data_context is None and source is not None: - # if source is provided, use its data context - data_context = source.data_context_key - super().__init__(data_context=data_context, **kwargs) - - @property - def source(self) -> cp.Kernel | None: - """ - The source of the stream, which is the kernel that generated the stream. - This is typically used to track the origin of the stream in the computational graph. - """ - return self._source - - @property - def upstreams(self) -> tuple[cp.Stream, ...]: - """ - The upstream streams that are used to generate this stream. - This is typically used to track the origin of the stream in the computational graph. - """ - return self._upstreams - - def computed_label(self) -> str | None: - if self.source is not None: - # use the invocation operation label - return self.source.label - return None - - # @abstractmethod - # def iter_packets( - # self, - # execution_engine: dp.ExecutionEngine | None = None, - # ) -> Iterator[tuple[dp.Tag, dp.Packet]]: ... - - # @abstractmethod - # def run( - # self, - # execution_engine: dp.ExecutionEngine | None = None, - # ) -> None: ... - - # @abstractmethod - # async def run_async( - # self, - # execution_engine: dp.ExecutionEngine | None = None, - # ) -> None: ... - - # @abstractmethod - # def as_table( - # self, - # include_data_context: bool = False, - # include_source: bool = False, - # include_system_tags: bool = False, - # include_content_hash: bool | str = False, - # sort_by_tags: bool = True, - # execution_engine: dp.ExecutionEngine | None = None, - # ) -> "pa.Table": ... - - def identity_structure(self) -> Any: - """ - Identity structure of a stream is deferred to the identity structure - of the associated invocation, if present. - A bare stream without invocation has no well-defined identity structure. - Specialized stream subclasses should override this method to provide more meaningful identity structure - """ - if self.source is not None: - # if the stream is generated by an operation, use the identity structure from the invocation - return self.source.identity_structure(self.upstreams) - return super().identity_structure() - - -class ImmutableStream(StreamBase): - """ - A class of stream that is constructed from immutable/constant data and does not change over time. - Consequently, the identity of an unsourced stream should be based on the content of the stream itself. - """ - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self._data_content_identity = None - - @abstractmethod - def data_content_identity_structure(self) -> Any: - """ - Returns a hash of the content of the stream. - This is used to identify the content of the stream. - """ - ... - - def identity_structure(self) -> Any: - if self.source is not None: - # if the stream is generated by an operation, use the identity structure from the invocation - return self.source.identity_structure(self.upstreams) - # otherwise, use the content of the stream as the identity structure - if self._data_content_identity is None: - self._data_content_identity = self.data_content_identity_structure() - return self._data_content_identity diff --git a/src/orcapod/core/streams/cached_pod_stream.py b/src/orcapod/core/streams/cached_pod_stream.py deleted file mode 100644 index 541af520..00000000 --- a/src/orcapod/core/streams/cached_pod_stream.py +++ /dev/null @@ -1,479 +0,0 @@ -import logging -from collections.abc import Iterator -from typing import TYPE_CHECKING, Any - -from orcapod.core.system_constants import constants -from orcapod.protocols import core_protocols as cp -from orcapod.types import PythonSchema -from orcapod.utils import arrow_utils -from orcapod.utils.lazy_module import LazyModule -from orcapod.core.streams.base import StreamBase -from orcapod.core.streams.table_stream import TableStream - - -if TYPE_CHECKING: - import pyarrow as pa - import pyarrow.compute as pc - import polars as pl - -else: - pa = LazyModule("pyarrow") - pc = LazyModule("pyarrow.compute") - pl = LazyModule("polars") - - -# TODO: consider using this instead of making copy of dicts -# from types import MappingProxyType - -logger = logging.getLogger(__name__) - - -class CachedPodStream(StreamBase): - """ - A fixed stream that lazily processes packets from a prepared input stream. - This is what Pod.process() returns - it's static/fixed but efficient. - """ - - # TODO: define interface for storage or pod storage - def __init__(self, pod: cp.CachedPod, input_stream: cp.Stream, **kwargs): - super().__init__(source=pod, upstreams=(input_stream,), **kwargs) - self.pod = pod - self.input_stream = input_stream - self._set_modified_time() # set modified time to when we obtain the iterator - # capture the immutable iterator from the input stream - - self._prepared_stream_iterator = input_stream.iter_packets() - - # Packet-level caching (from your PodStream) - self._cached_output_packets: list[tuple[cp.Tag, cp.Packet | None]] | None = None - self._cached_output_table: pa.Table | None = None - self._cached_content_hash_column: pa.Array | None = None - - def set_mode(self, mode: str) -> None: - return self.pod.set_mode(mode) - - @property - def mode(self) -> str: - return self.pod.mode - - def test(self) -> cp.Stream: - return self - - async def run_async( - self, - *args: Any, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - **kwargs: Any, - ) -> None: - """ - Runs the stream, processing the input stream and preparing the output stream. - This is typically called before iterating over the packets. - """ - if self._cached_output_packets is None: - cached_results = [] - - # identify all entries in the input stream for which we still have not computed packets - target_entries = self.input_stream.as_table( - include_content_hash=constants.INPUT_PACKET_HASH, - include_source=True, - include_system_tags=True, - ) - existing_entries = self.pod.get_all_cached_outputs( - include_system_columns=True - ) - if existing_entries is None or existing_entries.num_rows == 0: - missing = target_entries.drop_columns([constants.INPUT_PACKET_HASH]) - existing = None - else: - all_results = target_entries.join( - existing_entries.append_column( - "_exists", pa.array([True] * len(existing_entries)) - ), - keys=[constants.INPUT_PACKET_HASH], - join_type="left outer", - right_suffix="_right", - ) - # grab all columns from target_entries first - missing = ( - all_results.filter(pc.is_null(pc.field("_exists"))) - .select(target_entries.column_names) - .drop_columns([constants.INPUT_PACKET_HASH]) - ) - - existing = ( - all_results.filter(pc.is_valid(pc.field("_exists"))) - .drop_columns(target_entries.column_names) - .drop_columns(["_exists"]) - ) - renamed = [ - c.removesuffix("_right") if c.endswith("_right") else c - for c in existing.column_names - ] - existing = existing.rename_columns(renamed) - - tag_keys = self.input_stream.keys()[0] - - if existing is not None and existing.num_rows > 0: - # If there are existing entries, we can cache them - existing_stream = TableStream(existing, tag_columns=tag_keys) - for tag, packet in existing_stream.iter_packets(): - cached_results.append((tag, packet)) - - pending_calls = [] - if missing is not None and missing.num_rows > 0: - for tag, packet in TableStream(missing, tag_columns=tag_keys): - # Since these packets are known to be missing, skip the cache lookup - pending = self.pod.async_call( - tag, - packet, - skip_cache_lookup=True, - execution_engine=execution_engine or self.execution_engine, - execution_engine_opts=execution_engine_opts - or self._execution_engine_opts, - ) - pending_calls.append(pending) - import asyncio - - completed_calls = await asyncio.gather(*pending_calls) - for result in completed_calls: - cached_results.append(result) - - self._cached_output_packets = cached_results - self._set_modified_time() - - def run( - self, - *args: Any, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - **kwargs: Any, - ) -> None: - cached_results = [] - - # identify all entries in the input stream for which we still have not computed packets - target_entries = self.input_stream.as_table( - include_system_tags=True, - include_source=True, - include_content_hash=constants.INPUT_PACKET_HASH, - execution_engine=execution_engine, - ) - existing_entries = self.pod.get_all_cached_outputs(include_system_columns=True) - if ( - existing_entries is None - or existing_entries.num_rows == 0 - or self.mode == "development" - ): - missing = target_entries.drop_columns([constants.INPUT_PACKET_HASH]) - existing = None - else: - # TODO: do more proper replacement operation - target_df = pl.DataFrame(target_entries) - existing_df = pl.DataFrame( - existing_entries.append_column( - "_exists", pa.array([True] * len(existing_entries)) - ) - ) - all_results_df = target_df.join( - existing_df, - on=constants.INPUT_PACKET_HASH, - how="left", - suffix="_right", - ) - all_results = all_results_df.to_arrow() - - missing = ( - all_results.filter(pc.is_null(pc.field("_exists"))) - .select(target_entries.column_names) - .drop_columns([constants.INPUT_PACKET_HASH]) - ) - - existing = all_results.filter( - pc.is_valid(pc.field("_exists")) - ).drop_columns( - [ - "_exists", - constants.INPUT_PACKET_HASH, - constants.PACKET_RECORD_ID, - *self.input_stream.keys()[1], # remove the input packet keys - ] - # TODO: look into NOT fetching back the record ID - ) - renamed = [ - c.removesuffix("_right") if c.endswith("_right") else c - for c in existing.column_names - ] - existing = existing.rename_columns(renamed) - - tag_keys = self.input_stream.keys()[0] - - if existing is not None and existing.num_rows > 0: - # If there are existing entries, we can cache them - existing_stream = TableStream(existing, tag_columns=tag_keys) - for tag, packet in existing_stream.iter_packets(): - cached_results.append((tag, packet)) - - if missing is not None and missing.num_rows > 0: - hash_to_output_lut: dict[str, cp.Packet | None] = {} - for tag, packet in TableStream(missing, tag_columns=tag_keys): - # Since these packets are known to be missing, skip the cache lookup - packet_hash = packet.content_hash().to_string() - if packet_hash in hash_to_output_lut: - output_packet = hash_to_output_lut[packet_hash] - else: - tag, output_packet = self.pod.call( - tag, - packet, - skip_cache_lookup=True, - execution_engine=execution_engine or self.execution_engine, - execution_engine_opts=execution_engine_opts - or self._execution_engine_opts, - ) - # TODO: use getter for execution engine opts - hash_to_output_lut[packet_hash] = output_packet - cached_results.append((tag, output_packet)) - - self._cached_output_packets = cached_results - self._set_modified_time() - - def iter_packets( - self, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> Iterator[tuple[cp.Tag, cp.Packet]]: - """ - Processes the input stream and prepares the output stream. - This is typically called before iterating over the packets. - """ - if self._cached_output_packets is None: - cached_results = [] - - # identify all entries in the input stream for which we still have not computed packets - target_entries = self.input_stream.as_table( - include_system_tags=True, - include_source=True, - include_content_hash=constants.INPUT_PACKET_HASH, - execution_engine=execution_engine or self.execution_engine, - execution_engine_opts=execution_engine_opts - or self._execution_engine_opts, - ) - existing_entries = self.pod.get_all_cached_outputs( - include_system_columns=True - ) - if existing_entries is None or existing_entries.num_rows == 0: - missing = target_entries.drop_columns([constants.INPUT_PACKET_HASH]) - existing = None - else: - # missing = target_entries.join( - # existing_entries, - # keys=[constants.INPUT_PACKET_HASH], - # join_type="left anti", - # ) - # Single join that gives you both missing and existing - # More efficient - only bring the key column from existing_entries - # .select([constants.INPUT_PACKET_HASH]).append_column( - # "_exists", pa.array([True] * len(existing_entries)) - # ), - - # TODO: do more proper replacement operation - target_df = pl.DataFrame(target_entries) - existing_df = pl.DataFrame( - existing_entries.append_column( - "_exists", pa.array([True] * len(existing_entries)) - ) - ) - all_results_df = target_df.join( - existing_df, - on=constants.INPUT_PACKET_HASH, - how="left", - suffix="_right", - ) - all_results = all_results_df.to_arrow() - # all_results = target_entries.join( - # existing_entries.append_column( - # "_exists", pa.array([True] * len(existing_entries)) - # ), - # keys=[constants.INPUT_PACKET_HASH], - # join_type="left outer", - # right_suffix="_right", # rename the existing records in case of collision of output packet keys with input packet keys - # ) - # grab all columns from target_entries first - missing = ( - all_results.filter(pc.is_null(pc.field("_exists"))) - .select(target_entries.column_names) - .drop_columns([constants.INPUT_PACKET_HASH]) - ) - - existing = all_results.filter( - pc.is_valid(pc.field("_exists")) - ).drop_columns( - [ - "_exists", - constants.INPUT_PACKET_HASH, - constants.PACKET_RECORD_ID, - *self.input_stream.keys()[1], # remove the input packet keys - ] - # TODO: look into NOT fetching back the record ID - ) - renamed = [ - c.removesuffix("_right") if c.endswith("_right") else c - for c in existing.column_names - ] - existing = existing.rename_columns(renamed) - - tag_keys = self.input_stream.keys()[0] - - if existing is not None and existing.num_rows > 0: - # If there are existing entries, we can cache them - existing_stream = TableStream(existing, tag_columns=tag_keys) - for tag, packet in existing_stream.iter_packets(): - cached_results.append((tag, packet)) - yield tag, packet - - if missing is not None and missing.num_rows > 0: - hash_to_output_lut: dict[str, cp.Packet | None] = {} - for tag, packet in TableStream(missing, tag_columns=tag_keys): - # Since these packets are known to be missing, skip the cache lookup - packet_hash = packet.content_hash().to_string() - if packet_hash in hash_to_output_lut: - output_packet = hash_to_output_lut[packet_hash] - else: - tag, output_packet = self.pod.call( - tag, - packet, - skip_cache_lookup=True, - execution_engine=execution_engine or self.execution_engine, - execution_engine_opts=execution_engine_opts - or self._execution_engine_opts, - ) - hash_to_output_lut[packet_hash] = output_packet - cached_results.append((tag, output_packet)) - if output_packet is not None: - yield tag, output_packet - - self._cached_output_packets = cached_results - self._set_modified_time() - else: - for tag, packet in self._cached_output_packets: - if packet is not None: - yield tag, packet - - def keys( - self, include_system_tags: bool = False - ) -> tuple[tuple[str, ...], tuple[str, ...]]: - """ - Returns the keys of the tag and packet columns in the stream. - This is useful for accessing the columns in the stream. - """ - - tag_keys, _ = self.input_stream.keys(include_system_tags=include_system_tags) - packet_keys = tuple(self.pod.output_packet_types().keys()) - return tag_keys, packet_keys - - def types( - self, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - tag_typespec, _ = self.input_stream.types( - include_system_tags=include_system_tags - ) - # TODO: check if copying can be avoided - packet_typespec = dict(self.pod.output_packet_types()) - return tag_typespec, packet_typespec - - def as_table( - self, - include_data_context: bool = False, - include_source: bool = False, - include_system_tags: bool = False, - include_content_hash: bool | str = False, - sort_by_tags: bool = True, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> "pa.Table": - if self._cached_output_table is None: - all_tags = [] - all_packets = [] - tag_schema, packet_schema = None, None - for tag, packet in self.iter_packets( - execution_engine=execution_engine or self.execution_engine, - execution_engine_opts=execution_engine_opts - or self._execution_engine_opts, - ): - if tag_schema is None: - tag_schema = tag.arrow_schema(include_system_tags=True) - if packet_schema is None: - packet_schema = packet.arrow_schema( - include_context=True, - include_source=True, - ) - all_tags.append(tag.as_dict(include_system_tags=True)) - # FIXME: using in the pinch conversion to str from path - # replace with an appropriate semantic converter-based approach! - dict_patcket = packet.as_dict(include_context=True, include_source=True) - all_packets.append(dict_patcket) - - converter = self.data_context.type_converter - - struct_packets = converter.python_dicts_to_struct_dicts(all_packets) - all_tags_as_tables: pa.Table = pa.Table.from_pylist( - all_tags, schema=tag_schema - ) - all_packets_as_tables: pa.Table = pa.Table.from_pylist( - struct_packets, schema=packet_schema - ) - - self._cached_output_table = arrow_utils.hstack_tables( - all_tags_as_tables, all_packets_as_tables - ) - assert self._cached_output_table is not None, ( - "_cached_output_table should not be None here." - ) - - drop_columns = [] - if not include_source: - drop_columns.extend(f"{constants.SOURCE_PREFIX}{c}" for c in self.keys()[1]) - if not include_data_context: - drop_columns.append(constants.CONTEXT_KEY) - if not include_system_tags: - # TODO: come up with a more efficient approach - drop_columns.extend( - [ - c - for c in self._cached_output_table.column_names - if c.startswith(constants.SYSTEM_TAG_PREFIX) - ] - ) - - output_table = self._cached_output_table.drop_columns(drop_columns) - - # lazily prepare content hash column if requested - if include_content_hash: - if self._cached_content_hash_column is None: - content_hashes = [] - for tag, packet in self.iter_packets(execution_engine=execution_engine): - content_hashes.append(packet.content_hash().to_string()) - self._cached_content_hash_column = pa.array( - content_hashes, type=pa.large_string() - ) - assert self._cached_content_hash_column is not None, ( - "_cached_content_hash_column should not be None here." - ) - hash_column_name = ( - "_content_hash" - if include_content_hash is True - else include_content_hash - ) - output_table = output_table.append_column( - hash_column_name, self._cached_content_hash_column - ) - - if sort_by_tags: - try: - # TODO: consider having explicit tag/packet properties? - output_table = output_table.sort_by( - [(column, "ascending") for column in self.keys()[0]] - ) - except pa.ArrowTypeError: - pass - - return output_table diff --git a/src/orcapod/core/streams/kernel_stream.py b/src/orcapod/core/streams/kernel_stream.py deleted file mode 100644 index e5f60e34..00000000 --- a/src/orcapod/core/streams/kernel_stream.py +++ /dev/null @@ -1,215 +0,0 @@ -import logging -from collections.abc import Iterator -from datetime import datetime -from typing import TYPE_CHECKING, Any - -from orcapod.protocols import core_protocols as cp -from orcapod.types import PythonSchema -from orcapod.utils.lazy_module import LazyModule -from orcapod.core.streams.base import StreamBase - - -if TYPE_CHECKING: - import pyarrow as pa - import pyarrow.compute as pc - import polars as pl - import pandas as pd - import asyncio -else: - pa = LazyModule("pyarrow") - pc = LazyModule("pyarrow.compute") - pl = LazyModule("polars") - pd = LazyModule("pandas") - asyncio = LazyModule("asyncio") - - -# TODO: consider using this instead of making copy of dicts -# from types import MappingProxyType - -logger = logging.getLogger(__name__) - - -class KernelStream(StreamBase): - """ - Recomputable stream that wraps a stream produced by a kernel to provide - an abstraction over the stream, taking the stream's source and upstreams as the basis of - recomputing the stream. - - This stream is used to represent the output of a kernel invocation. - """ - - def __init__( - self, - output_stream: cp.Stream | None = None, - source: cp.Kernel | None = None, - upstreams: tuple[ - cp.Stream, ... - ] = (), # if provided, this will override the upstreams of the output_stream - **kwargs, - ) -> None: - if (output_stream is None or output_stream.source is None) and source is None: - raise ValueError( - "Either output_stream must have a kernel assigned to it or source must be provided in order to be recomputable." - ) - if source is None: - if output_stream is None or output_stream.source is None: - raise ValueError( - "Either output_stream must have a kernel assigned to it or source must be provided in order to be recomputable." - ) - source = output_stream.source - upstreams = upstreams or output_stream.upstreams - - super().__init__(source=source, upstreams=upstreams, **kwargs) - self.kernel = source - self._cached_stream = output_stream - - def clear_cache(self) -> None: - """ - Clears the cached stream. - This is useful for re-processing the stream with the same kernel. - """ - self._cached_stream = None - self._set_modified_time(invalidate=True) - - def keys( - self, include_system_tags: bool = False - ) -> tuple[tuple[str, ...], tuple[str, ...]]: - """ - Returns the keys of the tag and packet columns in the stream. - This is useful for accessing the columns in the stream. - """ - tag_types, packet_types = self.kernel.output_types( - *self.upstreams, include_system_tags=include_system_tags - ) - return tuple(tag_types.keys()), tuple(packet_types.keys()) - - def types( - self, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - """ - Returns the types of the tag and packet columns in the stream. - This is useful for accessing the types of the columns in the stream. - """ - return self.kernel.output_types( - *self.upstreams, include_system_tags=include_system_tags - ) - - @property - def is_current(self) -> bool: - if self._cached_stream is None or not super().is_current: - status = self.refresh() - if not status: # if it failed to update for whatever reason - return False - return True - - def refresh(self, force: bool = False) -> bool: - updated = False - if force or (self._cached_stream is not None and not super().is_current): - self.clear_cache() - - if self._cached_stream is None: - assert self.source is not None, ( - "Stream source must be set to recompute the stream." - ) - self._cached_stream = self.source.forward(*self.upstreams) - self._set_modified_time() - updated = True - - if self._cached_stream is None: - # TODO: use beter error type - raise ValueError( - "Stream could not be updated. Ensure that the source is valid and upstreams are correct." - ) - - return updated - - def invalidate(self) -> None: - """ - Invalidate the stream, marking it as needing recomputation. - This will clear the cached stream and set the last modified time to None. - """ - self.clear_cache() - self._set_modified_time(invalidate=True) - - @property - def last_modified(self) -> datetime | None: - if self._cached_stream is None: - return None - return self._cached_stream.last_modified - - def run( - self, - *args: Any, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - **kwargs: Any, - ) -> None: - self.refresh() - assert self._cached_stream is not None, ( - "Stream has not been updated or is empty." - ) - self._cached_stream.run( - *args, - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, - **kwargs, - ) - - async def run_async( - self, - *args: Any, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - **kwargs: Any, - ) -> None: - self.refresh() - assert self._cached_stream is not None, ( - "Stream has not been updated or is empty." - ) - await self._cached_stream.run_async( - *args, - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, - **kwargs, - ) - - def as_table( - self, - include_data_context: bool = False, - include_source: bool = False, - include_system_tags: bool = False, - include_content_hash: bool | str = False, - sort_by_tags: bool = True, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> "pa.Table": - self.refresh() - assert self._cached_stream is not None, ( - "Stream has not been updated or is empty." - ) - return self._cached_stream.as_table( - include_data_context=include_data_context, - include_source=include_source, - include_system_tags=include_system_tags, - include_content_hash=include_content_hash, - sort_by_tags=sort_by_tags, - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, - ) - - def iter_packets( - self, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> Iterator[tuple[cp.Tag, cp.Packet]]: - self.refresh() - assert self._cached_stream is not None, ( - "Stream has not been updated or is empty." - ) - return self._cached_stream.iter_packets( - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, - ) - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(kernel={self.source}, upstreams={self.upstreams})" diff --git a/src/orcapod/core/streams/lazy_pod_stream.py b/src/orcapod/core/streams/lazy_pod_stream.py deleted file mode 100644 index 23f146ac..00000000 --- a/src/orcapod/core/streams/lazy_pod_stream.py +++ /dev/null @@ -1,256 +0,0 @@ -import logging -from collections.abc import Iterator -from pathlib import Path -from typing import TYPE_CHECKING, Any - -from orcapod.core.system_constants import constants -from orcapod.protocols import core_protocols as cp -from orcapod.types import PythonSchema -from orcapod.utils import arrow_utils -from orcapod.utils.lazy_module import LazyModule -from orcapod.core.streams.base import StreamBase - - -if TYPE_CHECKING: - import pyarrow as pa - import polars as pl - import asyncio -else: - pa = LazyModule("pyarrow") - pl = LazyModule("polars") - asyncio = LazyModule("asyncio") - - -# TODO: consider using this instead of making copy of dicts -# from types import MappingProxyType - -logger = logging.getLogger(__name__) - - -class LazyPodResultStream(StreamBase): - """ - A fixed stream that lazily processes packets from a prepared input stream. - This is what Pod.process() returns - it's static/fixed but efficient. - """ - - def __init__(self, pod: cp.Pod, prepared_stream: cp.Stream, **kwargs): - super().__init__(source=pod, upstreams=(prepared_stream,), **kwargs) - self.pod = pod - self.prepared_stream = prepared_stream - # capture the immutable iterator from the prepared stream - self._prepared_stream_iterator = prepared_stream.iter_packets() - self._set_modified_time() # set modified time to AFTER we obtain the iterator - # note that the invocation of iter_packets on upstream likely triggeres the modified time - # to be updated on the usptream. Hence you want to set this stream's modified time after that. - - # Packet-level caching (from your PodStream) - self._cached_output_packets: dict[int, tuple[cp.Tag, cp.Packet | None]] = {} - self._cached_output_table: pa.Table | None = None - self._cached_content_hash_column: pa.Array | None = None - - def iter_packets( - self, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> Iterator[tuple[cp.Tag, cp.Packet]]: - if self._prepared_stream_iterator is not None: - for i, (tag, packet) in enumerate(self._prepared_stream_iterator): - if i in self._cached_output_packets: - # Use cached result - tag, packet = self._cached_output_packets[i] - if packet is not None: - yield tag, packet - else: - # Process packet - processed = self.pod.call( - tag, - packet, - execution_engine=execution_engine or self.execution_engine, - execution_engine_opts=execution_engine_opts - or self._execution_engine_opts, - ) - # TODO: verify the proper use of execution engine opts - if processed is not None: - # Update shared cache for future iterators (optimization) - self._cached_output_packets[i] = processed - tag, packet = processed - if packet is not None: - yield tag, packet - - # Mark completion by releasing the iterator - self._prepared_stream_iterator = None - else: - # Yield from snapshot of complete cache - for i in range(len(self._cached_output_packets)): - tag, packet = self._cached_output_packets[i] - if packet is not None: - yield tag, packet - - async def run_async( - self, - *args: Any, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - **kwargs: Any, - ) -> None: - if self._prepared_stream_iterator is not None: - pending_call_lut = {} - for i, (tag, packet) in enumerate(self._prepared_stream_iterator): - if i not in self._cached_output_packets: - # Process packet - pending_call_lut[i] = self.pod.async_call( - tag, - packet, - execution_engine=execution_engine or self.execution_engine, - execution_engine_opts=execution_engine_opts - or self._execution_engine_opts, - ) - - indices = list(pending_call_lut.keys()) - pending_calls = [pending_call_lut[i] for i in indices] - - results = await asyncio.gather(*pending_calls) - for i, result in zip(indices, results): - self._cached_output_packets[i] = result - - # Mark completion by releasing the iterator - self._prepared_stream_iterator = None - - def run( - self, - *args: Any, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - **kwargs: Any, - ) -> None: - # Fallback to synchronous run - self.flow( - execution_engine=execution_engine or self.execution_engine, - execution_engine_opts=execution_engine_opts or self._execution_engine_opts, - ) - - def keys( - self, include_system_tags: bool = False - ) -> tuple[tuple[str, ...], tuple[str, ...]]: - """ - Returns the keys of the tag and packet columns in the stream. - This is useful for accessing the columns in the stream. - """ - - tag_keys, _ = self.prepared_stream.keys(include_system_tags=include_system_tags) - packet_keys = tuple(self.pod.output_packet_types().keys()) - return tag_keys, packet_keys - - def types( - self, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - tag_typespec, _ = self.prepared_stream.types( - include_system_tags=include_system_tags - ) - # TODO: check if copying can be avoided - packet_typespec = dict(self.pod.output_packet_types()) - return tag_typespec, packet_typespec - - def as_table( - self, - include_data_context: bool = False, - include_source: bool = False, - include_system_tags: bool = False, - include_content_hash: bool | str = False, - sort_by_tags: bool = True, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> "pa.Table": - if self._cached_output_table is None: - all_tags = [] - all_packets = [] - tag_schema, packet_schema = None, None - for tag, packet in self.iter_packets( - execution_engine=execution_engine or self.execution_engine, - execution_engine_opts=execution_engine_opts - or self._execution_engine_opts, - ): - if tag_schema is None: - tag_schema = tag.arrow_schema(include_system_tags=True) - if packet_schema is None: - packet_schema = packet.arrow_schema( - include_context=True, - include_source=True, - ) - all_tags.append(tag.as_dict(include_system_tags=True)) - # FIXME: using in the pinch conversion to str from path - # replace with an appropriate semantic converter-based approach! - dict_patcket = packet.as_dict(include_context=True, include_source=True) - all_packets.append(dict_patcket) - - # TODO: re-verify the implemetation of this conversion - converter = self.data_context.type_converter - - struct_packets = converter.python_dicts_to_struct_dicts(all_packets) - all_tags_as_tables: pa.Table = pa.Table.from_pylist( - all_tags, schema=tag_schema - ) - all_packets_as_tables: pa.Table = pa.Table.from_pylist( - struct_packets, schema=packet_schema - ) - - self._cached_output_table = arrow_utils.hstack_tables( - all_tags_as_tables, all_packets_as_tables - ) - assert self._cached_output_table is not None, ( - "_cached_output_table should not be None here." - ) - - drop_columns = [] - if not include_system_tags: - # TODO: get system tags more effiicently - drop_columns.extend( - [ - c - for c in self._cached_output_table.column_names - if c.startswith(constants.SYSTEM_TAG_PREFIX) - ] - ) - if not include_source: - drop_columns.extend(f"{constants.SOURCE_PREFIX}{c}" for c in self.keys()[1]) - if not include_data_context: - drop_columns.append(constants.CONTEXT_KEY) - - output_table = self._cached_output_table.drop(drop_columns) - - # lazily prepare content hash column if requested - if include_content_hash: - if self._cached_content_hash_column is None: - content_hashes = [] - # TODO: verify that order will be preserved - for tag, packet in self.iter_packets( - execution_engine=execution_engine or self.execution_engine, - execution_engine_opts=execution_engine_opts or self._execution_engine_opts, - ): - content_hashes.append(packet.content_hash().to_string()) - self._cached_content_hash_column = pa.array( - content_hashes, type=pa.large_string() - ) - assert self._cached_content_hash_column is not None, ( - "_cached_content_hash_column should not be None here." - ) - hash_column_name = ( - "_content_hash" - if include_content_hash is True - else include_content_hash - ) - output_table = output_table.append_column( - hash_column_name, self._cached_content_hash_column - ) - - if sort_by_tags: - # TODO: reimplement using polars natively - output_table = ( - pl.DataFrame(output_table) - .sort(by=self.keys()[0], descending=False) - .to_arrow() - ) - # output_table = output_table.sort_by( - # [(column, "ascending") for column in self.keys()[0]] - # ) - return output_table diff --git a/src/orcapod/core/streams/pod_node_stream.py b/src/orcapod/core/streams/pod_node_stream.py deleted file mode 100644 index 4596bcbd..00000000 --- a/src/orcapod/core/streams/pod_node_stream.py +++ /dev/null @@ -1,422 +0,0 @@ -import logging -from collections.abc import Iterator -from typing import TYPE_CHECKING, Any - -from orcapod.core.system_constants import constants -from orcapod.protocols import core_protocols as cp, pipeline_protocols as pp -from orcapod.types import PythonSchema -from orcapod.utils import arrow_utils -from orcapod.utils.lazy_module import LazyModule -from orcapod.core.streams.base import StreamBase -from orcapod.core.streams.table_stream import TableStream - - -if TYPE_CHECKING: - import pyarrow as pa - import pyarrow.compute as pc - import polars as pl - -else: - pa = LazyModule("pyarrow") - pc = LazyModule("pyarrow.compute") - pl = LazyModule("polars") - - -# TODO: consider using this instead of making copy of dicts -# from types import MappingProxyType - -logger = logging.getLogger(__name__) - - -class PodNodeStream(StreamBase): - """ - A fixed stream that is both cached pod and pipeline storage aware - """ - - # TODO: define interface for storage or pod storage - def __init__(self, pod_node: pp.PodNode, input_stream: cp.Stream, **kwargs): - super().__init__(source=pod_node, upstreams=(input_stream,), **kwargs) - self.pod_node = pod_node - self.input_stream = input_stream - - # capture the immutable iterator from the input stream - self._prepared_stream_iterator = input_stream.iter_packets() - self._set_modified_time() # set modified time to when we obtain the iterator - - # Packet-level caching (from your PodStream) - self._cached_output_packets: list[tuple[cp.Tag, cp.Packet | None]] | None = None - self._cached_output_table: pa.Table | None = None - self._cached_content_hash_column: pa.Array | None = None - - def set_mode(self, mode: str) -> None: - return self.pod_node.set_mode(mode) - - @property - def mode(self) -> str: - return self.pod_node.mode - - async def run_async( - self, - *args: Any, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - **kwargs: Any, - ) -> None: - """ - Runs the stream, processing the input stream and preparing the output stream. - This is typically called before iterating over the packets. - """ - if self._cached_output_packets is None: - cached_results, missing = self._identify_existing_and_missing_entries(*args, - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, - **kwargs, - ) - - tag_keys = self.input_stream.keys()[0] - - pending_calls = [] - if missing is not None and missing.num_rows > 0: - for tag, packet in TableStream(missing, tag_columns=tag_keys): - # Since these packets are known to be missing, skip the cache lookup - pending = self.pod_node.async_call( - tag, - packet, - skip_cache_lookup=True, - execution_engine=execution_engine or self.execution_engine, - execution_engine_opts=execution_engine_opts - or self._execution_engine_opts, - ) - pending_calls.append(pending) - - import asyncio - completed_calls = await asyncio.gather(*pending_calls) - for result in completed_calls: - cached_results.append(result) - - self.clear_cache() - self._cached_output_packets = cached_results - self._set_modified_time() - self.pod_node.flush() - - def _identify_existing_and_missing_entries(self, - *args: Any, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - **kwargs: Any) -> tuple[list[tuple[cp.Tag, cp.Packet|None]], pa.Table | None]: - cached_results: list[tuple[cp.Tag, cp.Packet|None]] = [] - - # identify all entries in the input stream for which we still have not computed packets - if len(args) > 0 or len(kwargs) > 0: - input_stream_used = self.input_stream.polars_filter(*args, **kwargs) - else: - input_stream_used = self.input_stream - - target_entries = input_stream_used.as_table( - include_system_tags=True, - include_source=True, - include_content_hash=constants.INPUT_PACKET_HASH, - execution_engine=execution_engine or self.execution_engine, - execution_engine_opts=execution_engine_opts or self._execution_engine_opts, - ) - existing_entries = self.pod_node.get_all_cached_outputs( - include_system_columns=True - ) - if ( - existing_entries is None - or existing_entries.num_rows == 0 - or self.mode == "development" - ): - missing = target_entries.drop_columns([constants.INPUT_PACKET_HASH]) - existing = None - else: - # TODO: do more proper replacement operation - target_df = pl.DataFrame(target_entries) - existing_df = pl.DataFrame( - existing_entries.append_column( - "_exists", pa.array([True] * len(existing_entries)) - ) - ) - all_results_df = target_df.join( - existing_df, - on=constants.INPUT_PACKET_HASH, - how="left", - suffix="_right", - ) - all_results = all_results_df.to_arrow() - - missing = ( - all_results.filter(pc.is_null(pc.field("_exists"))) - .select(target_entries.column_names) - .drop_columns([constants.INPUT_PACKET_HASH]) - ) - - existing = all_results.filter( - pc.is_valid(pc.field("_exists")) - ).drop_columns( - [ - "_exists", - constants.INPUT_PACKET_HASH, - constants.PACKET_RECORD_ID, - *self.input_stream.keys()[1], # remove the input packet keys - ] - # TODO: look into NOT fetching back the record ID - ) - renamed = [ - c.removesuffix("_right") if c.endswith("_right") else c - for c in existing.column_names - ] - existing = existing.rename_columns(renamed) - - tag_keys = self.input_stream.keys()[0] - - if existing is not None and existing.num_rows > 0: - # If there are existing entries, we can cache them - # TODO: cache them based on the record ID - existing_stream = TableStream(existing, tag_columns=tag_keys) - for tag, packet in existing_stream.iter_packets(): - cached_results.append((tag, packet)) - - - - return cached_results, missing - - def run( - self, - *args: Any, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - **kwargs: Any, - ) -> None: - tag_keys = self.input_stream.keys()[0] - cached_results, missing = self._identify_existing_and_missing_entries( - *args, - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, - **kwargs, - ) - - if missing is not None and missing.num_rows > 0: - packet_record_to_output_lut: dict[str, cp.Packet | None] = {} - execution_engine_hash = ( - execution_engine.name if execution_engine is not None else "default" - ) - for tag, packet in TableStream(missing, tag_columns=tag_keys): - # compute record id - packet_record_id = self.pod_node.get_record_id( - packet, execution_engine_hash=execution_engine_hash - ) - - # Since these packets are known to be missing, skip the cache lookup - if packet_record_id in packet_record_to_output_lut: - output_packet = packet_record_to_output_lut[packet_record_id] - else: - tag, output_packet = self.pod_node.call( - tag, - packet, - record_id=packet_record_id, - skip_cache_lookup=True, - execution_engine=execution_engine or self.execution_engine, - execution_engine_opts=execution_engine_opts - or self._execution_engine_opts, - ) - packet_record_to_output_lut[packet_record_id] = output_packet - self.pod_node.add_pipeline_record( - tag, - packet, - packet_record_id, - retrieved=False, - skip_cache_lookup=True, - ) - cached_results.append((tag, output_packet)) - - - # reset the cache and set new results - self.clear_cache() - self._cached_output_packets = cached_results - self._set_modified_time() - self.pod_node.flush() - # TODO: evaluate proper handling of cache here - # self.clear_cache() - - def clear_cache(self) -> None: - self._cached_output_packets = None - self._cached_output_table = None - self._cached_content_hash_column = None - - def iter_packets( - self, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> Iterator[tuple[cp.Tag, cp.Packet]]: - """ - Processes the input stream and prepares the output stream. - This is typically called before iterating over the packets. - """ - - # if results are cached, simply return from them - if self._cached_output_packets is not None: - for tag, packet in self._cached_output_packets: - if packet is not None: - # make sure to skip over an empty packet - yield tag, packet - else: - cached_results = [] - # prepare the cache by loading from the record - total_table = self.pod_node.get_all_records(include_system_columns=True) - if total_table is None: - return # empty out - tag_types, packet_types = self.pod_node.output_types() - - for tag, packet in TableStream(total_table, tag_columns=tag_types.keys()): - cached_results.append((tag, packet)) - yield tag, packet - - # come up with a better caching mechanism - self._cached_output_packets = cached_results - self._set_modified_time() - - - def keys( - self, include_system_tags: bool = False - ) -> tuple[tuple[str, ...], tuple[str, ...]]: - """ - Returns the keys of the tag and packet columns in the stream. - This is useful for accessing the columns in the stream. - """ - - tag_keys, _ = self.input_stream.keys(include_system_tags=include_system_tags) - packet_keys = tuple(self.pod_node.output_packet_types().keys()) - return tag_keys, packet_keys - - def types( - self, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - tag_typespec, _ = self.input_stream.types( - include_system_tags=include_system_tags - ) - # TODO: check if copying can be avoided - packet_typespec = dict(self.pod_node.output_packet_types()) - return tag_typespec, packet_typespec - - def as_table( - self, - include_data_context: bool = False, - include_source: bool = False, - include_system_tags: bool = False, - include_content_hash: bool | str = False, - sort_by_tags: bool = True, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> "pa.Table": - if self._cached_output_table is None: - all_tags = [] - all_packets = [] - tag_schema, packet_schema = None, None - for tag, packet in self.iter_packets( - execution_engine=execution_engine or self.execution_engine, - execution_engine_opts=execution_engine_opts - or self._execution_engine_opts, - ): - if tag_schema is None: - tag_schema = tag.arrow_schema(include_system_tags=True) - if packet_schema is None: - packet_schema = packet.arrow_schema( - include_context=True, - include_source=True, - ) - all_tags.append(tag.as_dict(include_system_tags=True)) - # FIXME: using in the pinch conversion to str from path - # replace with an appropriate semantic converter-based approach! - dict_patcket = packet.as_dict(include_context=True, include_source=True) - all_packets.append(dict_patcket) - - converter = self.data_context.type_converter - - if len(all_tags) == 0: - tag_types, packet_types = self.pod_node.output_types( - include_system_tags=True - ) - tag_schema = converter.python_schema_to_arrow_schema(tag_types) - source_entries = { - f"{constants.SOURCE_PREFIX}{c}": str for c in packet_types.keys() - } - packet_types.update(source_entries) - packet_types[constants.CONTEXT_KEY] = str - packet_schema = converter.python_schema_to_arrow_schema(packet_types) - total_schema = arrow_utils.join_arrow_schemas(tag_schema, packet_schema) - # return an empty table with the right schema - self._cached_output_table = pa.Table.from_pylist( - [], schema=total_schema - ) - else: - struct_packets = converter.python_dicts_to_struct_dicts(all_packets) - - all_tags_as_tables: pa.Table = pa.Table.from_pylist( - all_tags, schema=tag_schema - ) - all_packets_as_tables: pa.Table = pa.Table.from_pylist( - struct_packets, schema=packet_schema - ) - - self._cached_output_table = arrow_utils.hstack_tables( - all_tags_as_tables, all_packets_as_tables - ) - assert self._cached_output_table is not None, ( - "_cached_output_table should not be None here." - ) - - if self._cached_output_table.num_rows == 0: - return self._cached_output_table - drop_columns = [] - if not include_source: - drop_columns.extend(f"{constants.SOURCE_PREFIX}{c}" for c in self.keys()[1]) - if not include_data_context: - drop_columns.append(constants.CONTEXT_KEY) - if not include_system_tags: - # TODO: come up with a more efficient approach - drop_columns.extend( - [ - c - for c in self._cached_output_table.column_names - if c.startswith(constants.SYSTEM_TAG_PREFIX) - ] - ) - - output_table = self._cached_output_table.drop_columns(drop_columns) - - # lazily prepare content hash column if requested - if include_content_hash: - if self._cached_content_hash_column is None: - content_hashes = [] - for tag, packet in self.iter_packets( - execution_engine=execution_engine or self.execution_engine, - execution_engine_opts=execution_engine_opts - or self._execution_engine_opts, - ): - content_hashes.append(packet.content_hash().to_string()) - self._cached_content_hash_column = pa.array( - content_hashes, type=pa.large_string() - ) - assert self._cached_content_hash_column is not None, ( - "_cached_content_hash_column should not be None here." - ) - hash_column_name = ( - "_content_hash" - if include_content_hash is True - else include_content_hash - ) - output_table = output_table.append_column( - hash_column_name, self._cached_content_hash_column - ) - - if sort_by_tags: - try: - # TODO: consider having explicit tag/packet properties? - output_table = output_table.sort_by( - [(column, "ascending") for column in self.keys()[0]] - ) - except pa.ArrowTypeError: - pass - - return output_table diff --git a/src/orcapod/core/streams/wrapped_stream.py b/src/orcapod/core/streams/wrapped_stream.py deleted file mode 100644 index 6ba85308..00000000 --- a/src/orcapod/core/streams/wrapped_stream.py +++ /dev/null @@ -1,93 +0,0 @@ -import logging -from collections.abc import Iterator -from typing import TYPE_CHECKING, Any - -from orcapod.protocols import core_protocols as cp -from orcapod.types import PythonSchema -from orcapod.utils.lazy_module import LazyModule -from orcapod.core.streams.base import StreamBase - - -if TYPE_CHECKING: - import pyarrow as pa -else: - pa = LazyModule("pyarrow") - - -# TODO: consider using this instead of making copy of dicts -# from types import MappingProxyType - -logger = logging.getLogger(__name__) - - -class WrappedStream(StreamBase): - def __init__( - self, - stream: cp.Stream, - source: cp.Kernel, - input_streams: tuple[cp.Stream, ...], - label: str | None = None, - **kwargs, - ) -> None: - super().__init__(source=source, upstreams=input_streams, label=label, **kwargs) - self._stream = stream - - def keys( - self, include_system_tags: bool = False - ) -> tuple[tuple[str, ...], tuple[str, ...]]: - """ - Returns the keys of the tag and packet columns in the stream. - This is useful for accessing the columns in the stream. - """ - return self._stream.keys(include_system_tags=include_system_tags) - - def types( - self, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - """ - Returns the types of the tag and packet columns in the stream. - This is useful for accessing the types of the columns in the stream. - """ - return self._stream.types(include_system_tags=include_system_tags) - - def as_table( - self, - include_data_context: bool = False, - include_source: bool = False, - include_system_tags: bool = False, - include_content_hash: bool | str = False, - sort_by_tags: bool = True, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> "pa.Table": - """ - Returns the underlying table representation of the stream. - This is useful for converting the stream to a table format. - """ - return self._stream.as_table( - include_data_context=include_data_context, - include_source=include_source, - include_system_tags=include_system_tags, - include_content_hash=include_content_hash, - sort_by_tags=sort_by_tags, - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, - ) - # TODO handle default execution engine - - def iter_packets( - self, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> Iterator[tuple[cp.Tag, cp.Packet]]: - """ - Iterates over the packets in the stream. - Each packet is represented as a tuple of (Tag, Packet). - """ - return self._stream.iter_packets( - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, - ) - - def identity_structure(self) -> Any: - return self._stream.identity_structure() diff --git a/src/orcapod/core/tracker.py b/src/orcapod/core/tracker.py new file mode 100644 index 00000000..b040d79e --- /dev/null +++ b/src/orcapod/core/tracker.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Generator +from contextlib import contextmanager +from typing import Any + +from orcapod.protocols import core_protocols as cp + + +class BasicTrackerManager: + def __init__(self) -> None: + self._active_trackers: list[cp.TrackerProtocol] = [] + self._active = True + + def set_active(self, active: bool = True) -> None: + """Set the active state of the tracker manager.""" + self._active = active + + def register_tracker(self, tracker: cp.TrackerProtocol) -> None: + """Register a new tracker in the system.""" + if tracker not in self._active_trackers: + self._active_trackers.append(tracker) + + def deregister_tracker(self, tracker: cp.TrackerProtocol) -> None: + """Remove a tracker from the system.""" + if tracker in self._active_trackers: + self._active_trackers.remove(tracker) + + def get_active_trackers(self) -> list[cp.TrackerProtocol]: + """Get the list of active trackers.""" + if not self._active: + return [] + return [t for t in self._active_trackers if t.is_active()] + + def record_operator_pod_invocation( + self, + pod: cp.PodProtocol, + upstreams: tuple[cp.StreamProtocol, ...] = (), + label: str | None = None, + ) -> None: + """Record the invocation of a pod in the tracker.""" + for tracker in self.get_active_trackers(): + tracker.record_operator_pod_invocation(pod, upstreams, label=label) + + def record_function_pod_invocation( + self, + pod: cp.FunctionPodProtocol, + input_stream: cp.StreamProtocol, + label: str | None = None, + ) -> None: + """Record the invocation of a packet function to the tracker.""" + for tracker in self.get_active_trackers(): + tracker.record_function_pod_invocation(pod, input_stream, label=label) + + @contextmanager + def no_tracking(self) -> Generator[None, Any, None]: + original_state = self._active + self.set_active(False) + try: + yield + finally: + self.set_active(original_state) + + +class AutoRegisteringContextBasedTracker(ABC): + def __init__( + self, tracker_manager: cp.TrackerManagerProtocol | None = None + ) -> None: + self._tracker_manager = tracker_manager or DEFAULT_TRACKER_MANAGER + self._active = False + + def set_active(self, active: bool = True) -> None: + if active: + self._tracker_manager.register_tracker(self) + else: + self._tracker_manager.deregister_tracker(self) + self._active = active + + def is_active(self) -> bool: + return self._active + + @abstractmethod + def record_operator_pod_invocation( + self, + pod: cp.OperatorPodProtocol, + upstreams: tuple[cp.StreamProtocol, ...] = (), + label: str | None = None, + ) -> None: ... + + @abstractmethod + def record_function_pod_invocation( + self, + pod: cp.FunctionPodProtocol, + input_stream: cp.StreamProtocol, + label: str | None = None, + ) -> None: ... + + def __enter__(self): + self.set_active(True) + return self + + def __exit__(self, exc_type, exc_val, ext_tb): + self.set_active(False) + + +DEFAULT_TRACKER_MANAGER = BasicTrackerManager() diff --git a/src/orcapod/core/trackers.py b/src/orcapod/core/trackers.py deleted file mode 100644 index 4ffe39a7..00000000 --- a/src/orcapod/core/trackers.py +++ /dev/null @@ -1,270 +0,0 @@ -from orcapod.core.base import LabeledContentIdentifiableBase -from orcapod.protocols import core_protocols as cp -from collections import defaultdict -from collections.abc import Generator -from abc import ABC, abstractmethod -from typing import Any, TYPE_CHECKING -from contextlib import contextmanager - - -if TYPE_CHECKING: - import networkx as nx - - -class BasicTrackerManager: - def __init__(self) -> None: - self._active_trackers: list[cp.Tracker] = [] - self._active = True - - def set_active(self, active: bool = True) -> None: - """ - Set the active state of the tracker manager. - This is used to enable or disable the tracker manager. - """ - self._active = active - - def register_tracker(self, tracker: cp.Tracker) -> None: - """ - Register a new tracker in the system. - This is used to add a new tracker to the list of active trackers. - """ - if tracker not in self._active_trackers: - self._active_trackers.append(tracker) - - def deregister_tracker(self, tracker: cp.Tracker) -> None: - """ - Remove a tracker from the system. - This is used to deactivate a tracker and remove it from the list of active trackers. - """ - if tracker in self._active_trackers: - self._active_trackers.remove(tracker) - - def get_active_trackers(self) -> list[cp.Tracker]: - """ - Get the list of active trackers. - This is used to retrieve the currently active trackers in the system. - """ - if not self._active: - return [] - # Filter out inactive trackers - # This is to ensure that we only return trackers that are currently active - return [t for t in self._active_trackers if t.is_active()] - - def record_kernel_invocation( - self, - kernel: cp.Kernel, - upstreams: tuple[cp.Stream, ...], - label: str | None = None, - ) -> None: - """ - Record the output stream of a kernel invocation in the tracker. - This is used to track the computational graph and the invocations of kernels. - """ - for tracker in self.get_active_trackers(): - tracker.record_kernel_invocation(kernel, upstreams, label=label) - - def record_source_invocation( - self, source: cp.Source, label: str | None = None - ) -> None: - """ - Record the output stream of a source invocation in the tracker. - This is used to track the computational graph and the invocations of sources. - """ - for tracker in self.get_active_trackers(): - tracker.record_source_invocation(source, label=label) - - def record_pod_invocation( - self, pod: cp.Pod, upstreams: tuple[cp.Stream, ...], label: str | None = None - ) -> None: - """ - Record the output stream of a pod invocation in the tracker. - This is used to track the computational graph and the invocations of pods. - """ - for tracker in self.get_active_trackers(): - tracker.record_pod_invocation(pod, upstreams, label=label) - - @contextmanager - def no_tracking(self) -> Generator[None, Any, None]: - original_state = self._active - self.set_active(False) - try: - yield - finally: - self.set_active(original_state) - - -class AutoRegisteringContextBasedTracker(ABC): - def __init__(self, tracker_manager: cp.TrackerManager | None = None) -> None: - self._tracker_manager = tracker_manager or DEFAULT_TRACKER_MANAGER - self._active = False - - def set_active(self, active: bool = True) -> None: - if active: - self._tracker_manager.register_tracker(self) - else: - self._tracker_manager.deregister_tracker(self) - self._active = active - - def is_active(self) -> bool: - return self._active - - @abstractmethod - def record_kernel_invocation( - self, - kernel: cp.Kernel, - upstreams: tuple[cp.Stream, ...], - label: str | None = None, - ) -> None: ... - - @abstractmethod - def record_source_invocation( - self, source: cp.Source, label: str | None = None - ) -> None: ... - - @abstractmethod - def record_pod_invocation( - self, pod: cp.Pod, upstreams: tuple[cp.Stream, ...], label: str | None = None - ) -> None: ... - - def __enter__(self): - self.set_active(True) - return self - - def __exit__(self, exc_type, exc_val, ext_tb): - self.set_active(False) - - -class Invocation(LabeledContentIdentifiableBase): - def __init__( - self, - kernel: cp.Kernel, - upstreams: tuple[cp.Stream, ...] = (), - label: str | None = None, - ) -> None: - """ - Represents an invocation of a kernel with its upstream streams. - This is used to track the computational graph and the invocations of kernels. - """ - super().__init__(label=label) - self.kernel = kernel - self.upstreams = upstreams - - def parents(self) -> tuple["Invocation", ...]: - parent_invoctions = [] - for stream in self.upstreams: - if stream.source is not None: - parent_invoctions.append(Invocation(stream.source, stream.upstreams)) - else: - # import JIT to avoid circular imports - from orcapod.core.sources.base import StreamSource - - source = StreamSource(stream) - parent_invoctions.append(Invocation(source)) - - return tuple(parent_invoctions) - - def computed_label(self) -> str | None: - """ - Compute a label for this invocation based on its kernel and upstreams. - If label is not explicitly set for this invocation and computed_label returns a valid value, - it will be used as label of this invocation. - """ - return self.kernel.label - - def identity_structure(self) -> Any: - """ - Return a structure that represents the identity of this invocation. - This is used to uniquely identify the invocation in the tracker. - """ - # if no upstreams, then we want to identify the source directly - if not self.upstreams: - return self.kernel.identity_structure() - return self.kernel.identity_structure(self.upstreams) - - def __repr__(self) -> str: - return f"Invocation(kernel={self.kernel}, upstreams={self.upstreams}, label={self.label})" - - -class GraphTracker(AutoRegisteringContextBasedTracker): - """ - A tracker that records the invocations of operations and generates a graph - of the invocations and their dependencies. - """ - - # Thread-local storage to track active trackers - - def __init__( - self, - tracker_manager: cp.TrackerManager | None = None, - **kwargs, - ) -> None: - super().__init__(tracker_manager=tracker_manager) - - # Dictionary to map kernels to the streams they have invoked - # This is used to track the computational graph and the invocations of kernels - self.kernel_invocations: set[Invocation] = set() - self.invocation_to_pod_lut: dict[Invocation, cp.Pod] = {} - self.invocation_to_source_lut: dict[Invocation, cp.Source] = {} - - def _record_kernel_and_get_invocation( - self, - kernel: cp.Kernel, - upstreams: tuple[cp.Stream, ...], - label: str | None = None, - ) -> Invocation: - invocation = Invocation(kernel, upstreams, label=label) - self.kernel_invocations.add(invocation) - return invocation - - def record_kernel_invocation( - self, - kernel: cp.Kernel, - upstreams: tuple[cp.Stream, ...], - label: str | None = None, - ) -> None: - """ - Record the output stream of a kernel invocation in the tracker. - This is used to track the computational graph and the invocations of kernels. - """ - self._record_kernel_and_get_invocation(kernel, upstreams, label) - - def record_source_invocation( - self, source: cp.Source, label: str | None = None - ) -> None: - """ - Record the output stream of a source invocation in the tracker. - """ - invocation = self._record_kernel_and_get_invocation(source, (), label) - self.invocation_to_source_lut[invocation] = source - - def record_pod_invocation( - self, pod: cp.Pod, upstreams: tuple[cp.Stream, ...], label: str | None = None - ) -> None: - """ - Record the output stream of a pod invocation in the tracker. - """ - invocation = self._record_kernel_and_get_invocation(pod, upstreams, label) - self.invocation_to_pod_lut[invocation] = pod - - def reset(self) -> dict[cp.Kernel, list[cp.Stream]]: - """ - Reset the tracker and return the recorded invocations. - """ - recorded_streams = self.kernel_to_invoked_stream_lut - self.kernel_to_invoked_stream_lut = defaultdict(list) - return recorded_streams - - def generate_graph(self) -> "nx.DiGraph": - import networkx as nx - - G = nx.DiGraph() - - # Add edges for each invocation - for invocation in self.kernel_invocations: - G.add_node(invocation) - for upstream_invocation in invocation.parents(): - G.add_edge(upstream_invocation, invocation) - return G - - -DEFAULT_TRACKER_MANAGER = BasicTrackerManager() diff --git a/src/orcapod/databases/__init__.py b/src/orcapod/databases/__init__.py index f47c7345..e8556e84 100644 --- a/src/orcapod/databases/__init__.py +++ b/src/orcapod/databases/__init__.py @@ -1,16 +1,21 @@ -# from .legacy.types import DataStore, ArrowDataStore -# from .legacy.legacy_arrow_data_stores import MockArrowDataStore, SimpleParquetDataStore -# from .legacy.dict_data_stores import DirDataStore, NoOpDataStore -# from .legacy.safe_dir_data_store import SafeDirDataStore +from .delta_lake_databases import DeltaTableDatabase +from .in_memory_databases import InMemoryArrowDatabase +from .noop_database import NoOpArrowDatabase -# __all__ = [ -# "DataStore", -# "ArrowDataStore", -# "DirDataStore", -# "SafeDirDataStore", -# "NoOpDataStore", -# "MockArrowDataStore", -# "SimpleParquetDataStore", -# ] +__all__ = [ + "DeltaTableDatabase", + "InMemoryArrowDatabase", + "NoOpArrowDatabase", +] -from .delta_lake_databases import DeltaTableDatabase +# Future ArrowDatabaseProtocol backends to implement: +# +# ParquetArrowDatabase -- stores each record_path as a partitioned Parquet +# directory; simpler, no Delta Lake dependency, +# suitable for write-once / read-heavy workloads. +# +# IcebergArrowDatabase -- Apache Iceberg backend for cloud-native / +# object-store deployments. +# +# All backends must satisfy the ArrowDatabaseProtocol protocol defined in +# orcapod.protocols.database_protocols. diff --git a/src/orcapod/databases/basic_delta_lake_arrow_database.py b/src/orcapod/databases/basic_delta_lake_arrow_database.py deleted file mode 100644 index 412d2479..00000000 --- a/src/orcapod/databases/basic_delta_lake_arrow_database.py +++ /dev/null @@ -1,1008 +0,0 @@ -import logging -from collections import defaultdict -from collections.abc import Collection, Mapping -from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, cast - -from deltalake import DeltaTable, write_deltalake -from deltalake.exceptions import TableNotFoundError - -from orcapod.core import constants -from orcapod.utils.lazy_module import LazyModule - -if TYPE_CHECKING: - import polars as pl - import pyarrow as pa - import pyarrow.compute as pc -else: - pa = LazyModule("pyarrow") - pl = LazyModule("polars") - pc = LazyModule("pyarrow.compute") - -# Module-level logger -logger = logging.getLogger(__name__) - - -class BasicDeltaTableArrowStore: - """ - A basic Delta Table-based Arrow data store with flexible hierarchical path support. - This store does NOT implement lazy loading or streaming capabilities, therefore - being "basic" in that sense. It is designed for simple use cases where data is written - in batches and read back as complete tables. It is worth noting that the Delta table - structure created by this store IS compatible with more advanced Delta Table-based - data stores (to be implemented) that will support lazy loading and streaming. - - Uses tuple-based source paths for robust parameter handling: - - ("source_name", "source_id") -> source_name/source_id/ - - ("org", "project", "dataset") -> org/project/dataset/ - - ("year", "month", "day", "experiment") -> year/month/day/experiment/ - """ - - RECORD_ID_COLUMN = f"{constants.META_PREFIX}record_id" - - def __init__( - self, - base_path: str | Path, - duplicate_entry_behavior: str = "error", - create_base_path: bool = True, - max_hierarchy_depth: int = 10, - batch_size: int = 100, - ): - """ - Initialize the BasicDeltaTableArrowStore. - - Args: - base_path: Base directory path where Delta tables will be stored - duplicate_entry_behavior: How to handle duplicate record_ids: - - 'error': Raise ValueError when record_id already exists - - 'overwrite': Replace existing entry with new data - create_base_path: Whether to create the base path if it doesn't exist - max_hierarchy_depth: Maximum allowed depth for source paths (safety limit) - batch_size: Number of records to batch before writing to Delta table - """ - # Validate duplicate behavior - if duplicate_entry_behavior not in ["error", "overwrite"]: - raise ValueError("duplicate_entry_behavior must be 'error' or 'overwrite'") - - self.duplicate_entry_behavior = duplicate_entry_behavior - self.base_path = Path(base_path) - self.max_hierarchy_depth = max_hierarchy_depth - self.batch_size = batch_size - - if create_base_path: - self.base_path.mkdir(parents=True, exist_ok=True) - elif not self.base_path.exists(): - raise ValueError( - f"Base path {self.base_path} does not exist and create_base_path=False" - ) - - # Cache for Delta tables to avoid repeated initialization - self._delta_table_cache: dict[str, DeltaTable] = {} - - # Batch management - self._pending_batches: dict[str, dict[str, pa.Table]] = defaultdict(dict) - - logger.info( - f"Initialized DeltaTableArrowDataStore at {self.base_path} " - f"with duplicate_entry_behavior='{duplicate_entry_behavior}', " - f"batch_size={batch_size}, as" - ) - - def flush(self) -> None: - """ - Flush all pending batches immediately. - - This method is called to ensure all pending data is written to the Delta tables. - """ - try: - self.flush_all_batches() - except Exception as e: - logger.error(f"Error during flush: {e}") - - def flush_batch(self, record_path: tuple[str, ...]) -> None: - """ - Flush pending batch for a specific source path. - - Args: - record_path: Tuple of path components - """ - logger.debug("Flushing triggered!!") - source_key = self._get_source_key(record_path) - - if ( - source_key not in self._pending_batches - or not self._pending_batches[source_key] - ): - return - - # Get all pending records - pending_tables = self._pending_batches[source_key] - self._pending_batches[source_key] = {} - - try: - # Combine all tables in the batch - combined_table = pa.concat_tables(pending_tables.values()).combine_chunks() - - table_path = self._get_table_path(record_path) - table_path.mkdir(parents=True, exist_ok=True) - - # Check if table exists - delta_table = self._get_existing_delta_table(record_path) - - if delta_table is None: - # TODO: reconsider mode="overwrite" here - write_deltalake( - table_path, - combined_table, - mode="overwrite", - ) - logger.debug( - f"Created new Delta table for {source_key} with {len(combined_table)} records" - ) - else: - if self.duplicate_entry_behavior == "overwrite": - # Get entry IDs from the batch - record_ids = combined_table.column( - self.RECORD_ID_COLUMN - ).to_pylist() - unique_record_ids = cast(list[str], list(set(record_ids))) - - # Delete existing records with these IDs - if unique_record_ids: - record_ids_str = "', '".join(unique_record_ids) - delete_predicate = ( - f"{self.RECORD_ID_COLUMN} IN ('{record_ids_str}')" - ) - try: - delta_table.delete(delete_predicate) - logger.debug( - f"Deleted {len(unique_record_ids)} existing records from {source_key}" - ) - except Exception as e: - logger.debug( - f"No existing records to delete from {source_key}: {e}" - ) - - # otherwise, only insert if same record_id does not exist yet - delta_table.merge( - source=combined_table, - predicate=f"target.{self.RECORD_ID_COLUMN} = source.{self.RECORD_ID_COLUMN}", - source_alias="source", - target_alias="target", - ).when_not_matched_insert_all().execute() - - logger.debug( - f"Appended batch of {len(combined_table)} records to {source_key}" - ) - - # Update cache - self._delta_table_cache[source_key] = DeltaTable(str(table_path)) - - except Exception as e: - logger.error(f"Error flushing batch for {source_key}: {e}") - # Put the tables back in the pending queue - self._pending_batches[source_key] = pending_tables - raise - - def flush_all_batches(self) -> None: - """Flush all pending batches.""" - source_keys = list(self._pending_batches.keys()) - - # TODO: capture and re-raise exceptions at the end - for source_key in source_keys: - record_path = tuple(source_key.split("/")) - try: - self.flush_batch(record_path) - except Exception as e: - logger.error(f"Error flushing batch for {source_key}: {e}") - - def __del__(self): - """Cleanup when object is destroyed.""" - self.flush() - - def _validate_record_path(self, record_path: tuple[str, ...]) -> None: - # TODO: consider removing this as path creation can be tried directly - """ - Validate source path components. - - Args: - record_path: Tuple of path components - - Raises: - ValueError: If path is invalid - """ - if not record_path: - raise ValueError("Source path cannot be empty") - - if len(record_path) > self.max_hierarchy_depth: - raise ValueError( - f"Source path depth {len(record_path)} exceeds maximum {self.max_hierarchy_depth}" - ) - - # Validate path components - for i, component in enumerate(record_path): - if not component or not isinstance(component, str): - raise ValueError( - f"Source path component {i} is invalid: {repr(component)}" - ) - - # Check for filesystem-unsafe characters - unsafe_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|", "\0"] - if any(char in component for char in unsafe_chars): - raise ValueError( - f"Source path {record_path} component {component} contains invalid characters: {repr(component)}" - ) - - def _get_source_key(self, record_path: tuple[str, ...]) -> str: - """Generate cache key for source storage.""" - return "/".join(record_path) - - def _get_table_path(self, record_path: tuple[str, ...]) -> Path: - """Get the filesystem path for a given source path.""" - path = self.base_path - for subpath in record_path: - path = path / subpath - return path - - def _get_existing_delta_table( - self, record_path: tuple[str, ...] - ) -> DeltaTable | None: - """ - Get or create a Delta table, handling schema initialization properly. - - Args: - record_path: Tuple of path components - - Returns: - DeltaTable instance or None if table doesn't exist - """ - source_key = self._get_source_key(record_path) - table_path = self._get_table_path(record_path) - - # Check cache first - if dt := self._delta_table_cache.get(source_key): - return dt - - try: - # Try to load existing table - delta_table = DeltaTable(str(table_path)) - self._delta_table_cache[source_key] = delta_table - logger.debug(f"Loaded existing Delta table for {source_key}") - return delta_table - except TableNotFoundError: - # Table doesn't exist - return None - except Exception as e: - logger.error(f"Error loading Delta table for {source_key}: {e}") - # Try to clear any corrupted cache and retry once - if source_key in self._delta_table_cache: - del self._delta_table_cache[source_key] - return None - - def _ensure_record_id_column( - self, arrow_data: "pa.Table", record_id: str - ) -> "pa.Table": - """Ensure the table has an record id column.""" - if self.RECORD_ID_COLUMN not in arrow_data.column_names: - # Add record_id column at the beginning - key_array = pa.array([record_id] * len(arrow_data), type=pa.large_string()) - arrow_data = arrow_data.add_column(0, self.RECORD_ID_COLUMN, key_array) - return arrow_data - - def _remove_record_id_column(self, arrow_data: "pa.Table") -> "pa.Table": - """Remove the record id column if it exists.""" - if self.RECORD_ID_COLUMN in arrow_data.column_names: - column_names = arrow_data.column_names - indices_to_keep = [ - i - for i, name in enumerate(column_names) - if name != self.RECORD_ID_COLUMN - ] - arrow_data = arrow_data.select(indices_to_keep) - return arrow_data - - def _handle_record_id_column( - self, arrow_data: "pa.Table", record_id_column: str | None = None - ) -> "pa.Table": - """ - Handle record_id column based on add_record_id_column parameter. - - Args: - arrow_data: Arrow table with record id column - record_id_column: Control entry ID column inclusion: - - """ - if not record_id_column: - # Remove the record id column - return self._remove_record_id_column(arrow_data) - - # Rename record id column - if self.RECORD_ID_COLUMN in arrow_data.column_names: - schema = arrow_data.schema - new_names = [ - record_id_column if name == self.RECORD_ID_COLUMN else name - for name in schema.names - ] - return arrow_data.rename_columns(new_names) - else: - raise ValueError( - f"Record ID column '{self.RECORD_ID_COLUMN}' not found in the table and cannot be renamed." - ) - - def _create_record_id_filter(self, record_id: str) -> list: - """ - Create a proper filter expression for Delta Lake. - - Args: - record_id: The entry ID to filter by - - Returns: - List containing the filter expression for Delta Lake - """ - return [(self.RECORD_ID_COLUMN, "=", record_id)] - - def _create_record_ids_filter(self, record_ids: list[str]) -> list: - """ - Create a proper filter expression for multiple entry IDs. - - Args: - record_ids: List of entry IDs to filter by - - Returns: - List containing the filter expression for Delta Lake - """ - return [(self.RECORD_ID_COLUMN, "in", record_ids)] - - def _read_table_with_filter( - self, - delta_table: DeltaTable, - filters: list | None = None, - ) -> "pa.Table": - """ - Read table using to_pyarrow_dataset with original schema preservation. - - Args: - delta_table: The Delta table to read from - filters: Optional filters to apply - - Returns: - Arrow table with preserved schema - """ - # Use to_pyarrow_dataset with as_large_types for Polars compatible arrow table loading - dataset = delta_table.to_pyarrow_dataset(as_large_types=True) - if filters: - # Apply filters at dataset level for better performance - import pyarrow.compute as pc - - filter_expr = None - for filt in filters: - if len(filt) == 3: - col, op, val = filt - if op == "=": - expr = pc.equal(pc.field(col), pa.scalar(val)) # type: ignore - elif op == "in": - expr = pc.is_in(pc.field(col), pa.array(val)) # type: ignore - else: - logger.warning( - f"Unsupported filter operation: {op}. Falling back to table-level filter application which may be less efficient." - ) - # Fallback to table-level filtering - return dataset.to_table()(filters=filters) - - if filter_expr is None: - filter_expr = expr - else: - filter_expr = pc.and_(filter_expr, expr) # type: ignore - - if filter_expr is not None: - return dataset.to_table(filter=filter_expr) - - return dataset.to_table() - - def add_record( - self, - record_path: tuple[str, ...], - record_id: str, - data: "pa.Table", - ignore_duplicates: bool | None = None, - overwrite_existing: bool = False, - force_flush: bool = False, - ) -> "pa.Table": - self._validate_record_path(record_path) - source_key = self._get_source_key(record_path) - - # Check for existing entry - if ignore_duplicates is None: - ignore_duplicates = self.duplicate_entry_behavior != "error" - if not ignore_duplicates: - pending_table = self._pending_batches[source_key].get(record_id, None) - if pending_table is not None: - raise ValueError( - f"Entry '{record_id}' already exists in pending batch for {source_key}. " - f"Use duplicate_entry_behavior='overwrite' to allow updates." - ) - existing_record = self.get_record_by_id(record_path, record_id, flush=False) - if existing_record is not None: - raise ValueError( - f"Entry '{record_id}' already exists in {'/'.join(record_path)}. " - f"Use duplicate_entry_behavior='overwrite' to allow updates." - ) - - # Add record_id column to the data - data_with_record_id = self._ensure_record_id_column(data, record_id) - - if force_flush: - # Write immediately - table_path = self._get_table_path(record_path) - table_path.mkdir(parents=True, exist_ok=True) - - delta_table = self._get_existing_delta_table(record_path) - - if delta_table is None: - # Create new table - save original schema first - write_deltalake(str(table_path), data_with_record_id, mode="overwrite") - logger.debug(f"Created new Delta table for {source_key}") - else: - if self.duplicate_entry_behavior == "overwrite": - try: - delta_table.delete( - f"{self.RECORD_ID_COLUMN} = '{record_id.replace(chr(39), chr(39) + chr(39))}'" - ) - logger.debug( - f"Deleted existing record {record_id} from {source_key}" - ) - except Exception as e: - logger.debug( - f"No existing record to delete for {record_id}: {e}" - ) - - write_deltalake( - table_path, - data_with_record_id, - mode="append", - schema_mode="merge", - ) - - # Update cache - self._delta_table_cache[source_key] = DeltaTable(str(table_path)) - else: - # Add to the batch for later flushing - self._pending_batches[source_key][record_id] = data_with_record_id - batch_size = len(self._pending_batches[source_key]) - - # Check if we need to flush - if batch_size >= self.batch_size: - self.flush_batch(record_path) - - logger.debug(f"Added record {record_id} to {source_key}") - return data - - def add_records( - self, - record_path: tuple[str, ...], - records: "pa.Table", - record_id_column: str | None = None, - ignore_duplicates: bool | None = None, - overwrite_existing: bool = False, - force_flush: bool = False, - ) -> list[str]: - """ - Add multiple records to the Delta table, using one column as record_id. - - Args: - record_path: Path tuple identifying the table location - records: PyArrow table containing the records to add - record_id_column: Column name to use as record_id (defaults to first column) - ignore_duplicates: Whether to ignore duplicate entries - overwrite_existing: Whether to overwrite existing records with same ID - force_flush: Whether to write immediately instead of batching - - Returns: - List of record IDs that were added - """ - self._validate_record_path(record_path) - source_key = self._get_source_key(record_path) - - # Determine record_id column - if record_id_column is None: - record_id_column = records.column_names[0] - - # Validate that the record_id column exists - if record_id_column not in records.column_names: - raise ValueError( - f"Record ID column '{record_id_column}' not found in table. " - f"Available columns: {records.column_names}" - ) - - # Rename the record_id column to the standard name - column_mapping = {record_id_column: self.RECORD_ID_COLUMN} - records_renamed = records.rename_columns( - [column_mapping.get(col, col) for col in records.column_names] - ) - - # Get unique record IDs from the data - record_ids_array = records_renamed[self.RECORD_ID_COLUMN] - unique_record_ids = pc.unique(record_ids_array).to_pylist() - - # Set default behavior for duplicates - if ignore_duplicates is None: - ignore_duplicates = self.duplicate_entry_behavior != "error" - - added_record_ids = [] - - # Check for duplicates if needed - if not ignore_duplicates: - # Check pending batches - pending_duplicates = [] - for record_id in unique_record_ids: - if record_id in self._pending_batches[source_key]: - pending_duplicates.append(record_id) - - if pending_duplicates: - raise ValueError( - f"Records {pending_duplicates} already exist in pending batch for {source_key}. " - f"Use ignore_duplicates=True or duplicate_entry_behavior='overwrite' to allow updates." - ) - - # Check existing table - existing_duplicates = [] - try: - for record_id in unique_record_ids: - existing_record = self.get_record_by_id( - record_path, str(record_id), flush=False - ) - if existing_record is not None: - existing_duplicates.append(record_id) - except Exception as e: - logger.debug(f"Error checking existing records: {e}") - - if existing_duplicates: - raise ValueError( - f"Records {existing_duplicates} already exist in {'/'.join(record_path)}. " - f"Use ignore_duplicates=True or duplicate_entry_behavior='overwrite' to allow updates." - ) - - if force_flush: - # Write immediately - table_path = self._get_table_path(record_path) - table_path.mkdir(parents=True, exist_ok=True) - - delta_table = self._get_existing_delta_table(record_path) - - if delta_table is None: - # Create new table - write_deltalake(str(table_path), records_renamed, mode="overwrite") - logger.debug(f"Created new Delta table for {source_key}") - added_record_ids = unique_record_ids - else: - # Handle existing table - if self.duplicate_entry_behavior == "overwrite" or overwrite_existing: - # Delete existing records with matching IDs - try: - # Create SQL condition for multiple record IDs - escaped_ids = [ - str(rid).replace("'", "''") for rid in unique_record_ids - ] - id_list = "', '".join(escaped_ids) - delete_condition = f"{self.RECORD_ID_COLUMN} IN ('{id_list}')" - - delta_table.delete(delete_condition) - logger.debug( - f"Deleted existing records {unique_record_ids} from {source_key}" - ) - except Exception as e: - logger.debug(f"No existing records to delete: {e}") - - # Filter out duplicates if not overwriting - if not ( - self.duplicate_entry_behavior == "overwrite" or overwrite_existing - ): - # Get existing record IDs - try: - existing_table = delta_table.to_pyarrow_table() - if len(existing_table) > 0: - existing_ids = pc.unique( - existing_table[self.RECORD_ID_COLUMN] - ) - - # Filter out records that already exist - mask = pc.invert( - pc.is_in( - records_renamed[self.RECORD_ID_COLUMN], existing_ids - ) - ) - records_renamed = pc.filter(records_renamed, mask) # type: ignore - - # Update the list of record IDs that will actually be added - if len(records_renamed) > 0: - added_record_ids = pc.unique( - records_renamed[self.RECORD_ID_COLUMN] - ).to_pylist() - else: - added_record_ids = [] - else: - added_record_ids = unique_record_ids - except Exception as e: - logger.debug(f"Error filtering duplicates: {e}") - added_record_ids = unique_record_ids - else: - added_record_ids = unique_record_ids - - # Append the (possibly filtered) records - if len(records_renamed) > 0: - write_deltalake( - table_path, - records_renamed, - mode="append", - schema_mode="merge", - ) - - # Update cache - self._delta_table_cache[source_key] = DeltaTable(str(table_path)) - - else: - # Add to batches for later flushing - # Group records by record_id for individual batch entries - for record_id in unique_record_ids: - # Filter records for this specific record_id - mask = pc.equal(records_renamed[self.RECORD_ID_COLUMN], record_id) # type: ignore - single_record = pc.filter(records_renamed, mask) # type: ignore - - # Add to pending batch (will overwrite if duplicate_entry_behavior allows) - if ( - self.duplicate_entry_behavior == "overwrite" - or overwrite_existing - or record_id not in self._pending_batches[source_key] - ): - self._pending_batches[source_key][str(record_id)] = single_record - added_record_ids.append(record_id) - elif ignore_duplicates: - logger.debug(f"Ignoring duplicate record {record_id}") - else: - # This should have been caught earlier, but just in case - logger.warning(f"Skipping duplicate record {record_id}") - - # Check if we need to flush - batch_size = len(self._pending_batches[source_key]) - if batch_size >= self.batch_size: - self.flush_batch(record_path) - - logger.debug(f"Added {len(added_record_ids)} records to {source_key}") - return [str(rid) for rid in added_record_ids] - - def get_record_by_id( - self, - record_path: tuple[str, ...], - record_id: str, - record_id_column: str | None = None, - flush: bool = False, - ) -> "pa.Table | None": - """ - Get a specific record by record_id with schema preservation. - - Args: - record_path: Tuple of path components - record_id: Unique identifier for the record - - Returns: - Arrow table for the record or None if not found - """ - - if flush: - self.flush_batch(record_path) - self._validate_record_path(record_path) - - # check if record_id is found in pending batches - source_key = self._get_source_key(record_path) - if record_id in self._pending_batches[source_key]: - # Return the pending record after removing the entry id column - return self._remove_record_id_column( - self._pending_batches[source_key][record_id] - ) - - delta_table = self._get_existing_delta_table(record_path) - if delta_table is None: - return None - - try: - # Use schema-preserving read - filter_expr = self._create_record_id_filter(record_id) - result = self._read_table_with_filter(delta_table, filters=filter_expr) - - if len(result) == 0: - return None - - # Handle (remove/rename) the record id column before returning - return self._handle_record_id_column(result, record_id_column) - - except Exception as e: - logger.error( - f"Error getting record {record_id} from {'/'.join(record_path)}: {e}" - ) - raise e - - def get_all_records( - self, - record_path: tuple[str, ...], - record_id_column: str | None = None, - retrieve_pending: bool = True, - flush: bool = False, - ) -> "pa.Table | None": - """ - Retrieve all records for a given source path as a single table with schema preservation. - - Args: - record_path: Tuple of path components - record_id_column: If not None or empty, record id is returned in the result with the specified column name - - Returns: - Arrow table containing all records with original schema, or None if no records found - """ - # TODO: this currently reads everything into memory and then return. Consider implementation that performs everything lazily - - if flush: - self.flush_batch(record_path) - self._validate_record_path(record_path) - - collected_tables = [] - if retrieve_pending: - # Check if there are pending records in the batch - for record_id, arrow_table in self._pending_batches[ - self._get_source_key(record_path) - ].items(): - collected_tables.append( - self._ensure_record_id_column(arrow_table, record_id) - ) - - delta_table = self._get_existing_delta_table(record_path) - if delta_table is not None: - try: - # Use filter-based read - result = self._read_table_with_filter(delta_table) - - if len(result) != 0: - collected_tables.append(result) - - except Exception as e: - logger.error( - f"Error getting all records from {'/'.join(record_path)}: {e}" - ) - if collected_tables: - total_table = pa.concat_tables(collected_tables) - - # Handle record_id column based on parameter - return self._handle_record_id_column(total_table, record_id_column) - - return None - - def get_records_by_ids( - self, - record_path: tuple[str, ...], - record_ids: "list[str] | pl.Series | pa.Array", - record_id_column: str | None = None, - flush: bool = False, - ) -> "pa.Table | None": - """ - Retrieve records by entry IDs as a single table with schema preservation. - - Args: - record_path: Tuple of path components - record_ids: Entry IDs to retrieve - add_record_id_column: Control entry ID column inclusion - preserve_input_order: If True, return results in input order with nulls for missing - - Returns: - Arrow table containing all found records with original schema, or None if no records found - """ - - if flush: - self.flush_batch(record_path) - - self._validate_record_path(record_path) - - # Convert input to list of strings for consistency - if isinstance(record_ids, list): - if not record_ids: - return None - record_ids_list = record_ids - elif isinstance(record_ids, pl.Series): - if len(record_ids) == 0: - return None - record_ids_list = record_ids.to_list() - elif isinstance(record_ids, (pa.Array, pa.ChunkedArray)): - if len(record_ids) == 0: - return None - record_ids_list = record_ids.to_pylist() - else: - raise TypeError( - f"record_ids must be list[str], pl.Series, or pa.Array, got {type(record_ids)}" - ) - - delta_table = self._get_existing_delta_table(record_path) - if delta_table is None: - return None - - try: - # Use schema-preserving read with filters - filter_expr = self._create_record_ids_filter( - cast(list[str], record_ids_list) - ) - result = self._read_table_with_filter(delta_table, filters=filter_expr) - - if len(result) == 0: - return None - - # Handle record_id column based on parameter - return self._handle_record_id_column(result, record_id_column) - - except Exception as e: - logger.error( - f"Error getting records by IDs from {'/'.join(record_path)}: {e}" - ) - return None - - def get_pending_batch_info(self) -> dict[str, int]: - """ - Get information about pending batches. - - Returns: - Dictionary mapping source keys to number of pending records - """ - return { - source_key: len(tables) - for source_key, tables in self._pending_batches.items() - if tables - } - - def list_sources(self) -> list[tuple[str, ...]]: - """ - List all available source paths. - - Returns: - List of source path tuples - """ - sources = [] - - def _scan_directory(current_path: Path, path_components: tuple[str, ...]): - """Recursively scan for Delta tables.""" - for item in current_path.iterdir(): - if not item.is_dir(): - continue - - new_path_components = path_components + (item.name,) - - # Check if this directory contains a Delta table - try: - DeltaTable(str(item)) - sources.append(new_path_components) - except TableNotFoundError: - # Not a Delta table, continue scanning subdirectories - if len(new_path_components) < self.max_hierarchy_depth: - _scan_directory(item, new_path_components) - - _scan_directory(self.base_path, ()) - return sources - - def delete_source(self, record_path: tuple[str, ...]) -> bool: - """ - Delete an entire source (all records for a source path). - - Args: - record_path: Tuple of path components - - Returns: - True if source was deleted, False if it didn't exist - """ - self._validate_record_path(record_path) - - # Flush any pending batches first - self.flush_batch(record_path) - - table_path = self._get_table_path(record_path) - source_key = self._get_source_key(record_path) - - if not table_path.exists(): - return False - - try: - # Remove from caches - if source_key in self._delta_table_cache: - del self._delta_table_cache[source_key] - - # Remove directory - import shutil - - shutil.rmtree(table_path) - - logger.info(f"Deleted source {source_key}") - return True - - except Exception as e: - logger.error(f"Error deleting source {source_key}: {e}") - return False - - def delete_record(self, record_path: tuple[str, ...], record_id: str) -> bool: - """ - Delete a specific record. - - Args: - record_path: Tuple of path components - record_id: ID of the record to delete - - Returns: - True if record was deleted, False if it didn't exist - """ - self._validate_record_path(record_path) - - # Flush any pending batches first - self.flush_batch(record_path) - - delta_table = self._get_existing_delta_table(record_path) - if delta_table is None: - return False - - try: - # Check if record exists using proper filter - filter_expr = self._create_record_id_filter(record_id) - existing = self._read_table_with_filter(delta_table, filters=filter_expr) - if len(existing) == 0: - return False - - # Delete the record using SQL-style predicate (this is correct for delete operations) - delta_table.delete( - f"{self.RECORD_ID_COLUMN} = '{record_id.replace(chr(39), chr(39) + chr(39))}'" - ) - - # Update cache - source_key = self._get_source_key(record_path) - self._delta_table_cache[source_key] = delta_table - - logger.debug(f"Deleted record {record_id} from {'/'.join(record_path)}") - return True - - except Exception as e: - logger.error( - f"Error deleting record {record_id} from {'/'.join(record_path)}: {e}" - ) - return False - - def get_table_info(self, record_path: tuple[str, ...]) -> dict[str, Any] | None: - """ - Get metadata information about a Delta table. - - Args: - record_path: Tuple of path components - - Returns: - Dictionary with table metadata, or None if table doesn't exist - """ - self._validate_record_path(record_path) - - delta_table = self._get_existing_delta_table(record_path) - if delta_table is None: - return None - - try: - # Get basic info - schema = delta_table.schema() - history = delta_table.history() - source_key = self._get_source_key(record_path) - - # Add pending batch info - pending_info = self.get_pending_batch_info() - pending_count = pending_info.get(source_key, 0) - - return { - "path": str(self._get_table_path(record_path)), - "record_path": record_path, - "schema": schema, - "version": delta_table.version(), - "num_files": len(delta_table.files()), - "history_length": len(history), - "latest_commit": history[0] if history else None, - "pending_records": pending_count, - } - - except Exception as e: - logger.error(f"Error getting table info for {'/'.join(record_path)}: {e}") - return None diff --git a/src/orcapod/databases/delta_lake_databases.py b/src/orcapod/databases/delta_lake_databases.py index 0270b295..0ee7f4e5 100644 --- a/src/orcapod/databases/delta_lake_databases.py +++ b/src/orcapod/databases/delta_lake_databases.py @@ -79,11 +79,25 @@ def _get_record_key(self, record_path: tuple[str, ...]) -> str: """Generate cache key for source storage.""" return "/".join(record_path) + @staticmethod + def _sanitize_path_component(component: str) -> str: + """Sanitize a path component for the current OS. + + On Windows, colons are not allowed in filenames (reserved for drive + letters). Replace them with '!' so that URIs containing ':' can still + be stored safely on all platforms. + """ + import sys + + if sys.platform == "win32": + return component.replace(":", "!") + return component + def _get_table_path(self, record_path: tuple[str, ...]) -> Path: """Get the filesystem path for a given source path.""" path = self.base_path for subpath in record_path: - path = path / subpath + path = path / self._sanitize_path_component(subpath) return path def _validate_record_path(self, record_path: tuple[str, ...]) -> None: @@ -112,8 +126,10 @@ def _validate_record_path(self, record_path: tuple[str, ...]) -> None: f"Source path component {i} is invalid: {repr(component)}" ) - # Check for filesystem-unsafe characters - unsafe_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|", "\0"] + # Check for filesystem-unsafe characters. + # ':' is handled by _sanitize_path_component (replaced on Windows), + # so it is intentionally absent from this list. + unsafe_chars = ["/", "\\", "*", "?", '"', "<", ">", "|", "\0"] if any(char in component for char in unsafe_chars): raise ValueError( f"Source path {record_path} component {component} contains invalid characters: {repr(component)}" @@ -315,6 +331,8 @@ def add_records( Raises: ValueError: If any record IDs already exist and skip_duplicates=False """ + self._validate_record_path(record_path) + if records.num_rows == 0: return @@ -804,6 +822,27 @@ def flush(self) -> None: except Exception as e: logger.error(f"Error flushing batch for {record_key}: {e}") + def to_config(self) -> dict[str, Any]: + """Serialize database configuration to a JSON-compatible dict.""" + return { + "type": "delta_table", + "base_path": str(self.base_path), + "batch_size": self.batch_size, + "max_hierarchy_depth": self.max_hierarchy_depth, + "allow_schema_evolution": self.allow_schema_evolution, + } + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "DeltaTableDatabase": + """Reconstruct a DeltaTableDatabase from a config dict.""" + return cls( + base_path=config["base_path"], + create_base_path=True, + batch_size=config.get("batch_size", 1000), + max_hierarchy_depth=config.get("max_hierarchy_depth", 10), + allow_schema_evolution=config.get("allow_schema_evolution", True), + ) + def flush_batch(self, record_path: tuple[str, ...]) -> None: """ Flush pending batch for a specific source path. diff --git a/src/orcapod/databases/in_memory_databases.py b/src/orcapod/databases/in_memory_databases.py new file mode 100644 index 00000000..9e31e0ba --- /dev/null +++ b/src/orcapod/databases/in_memory_databases.py @@ -0,0 +1,379 @@ +import logging +from collections import defaultdict +from collections.abc import Collection, Mapping +from typing import TYPE_CHECKING, Any, cast + +from orcapod.utils.lazy_module import LazyModule + +if TYPE_CHECKING: + import pyarrow as pa + import pyarrow.compute as pc +else: + pa = LazyModule("pyarrow") + pc = LazyModule("pyarrow.compute") + +logger = logging.getLogger(__name__) + + +class InMemoryArrowDatabase: + """ + A pure in-memory implementation of the ArrowDatabaseProtocol protocol. + + Records are stored in PyArrow tables held in process memory. + Data is lost when the process exits — intended for tests and ephemeral use. + + Supports the same pending-batch semantics as DeltaTableDatabase: + records are buffered in a pending batch and become part of the committed + store only after flush() is called (or flush=True is passed to a write method). + """ + + RECORD_ID_COLUMN = "__record_id" + + def __init__(self, max_hierarchy_depth: int = 10): + self.max_hierarchy_depth = max_hierarchy_depth + self._tables: dict[str, pa.Table] = {} + self._pending_batches: dict[str, pa.Table] = {} + self._pending_record_ids: dict[str, set[str]] = defaultdict(set) + + # ------------------------------------------------------------------ + # Path helpers + # ------------------------------------------------------------------ + + def _get_record_key(self, record_path: tuple[str, ...]) -> str: + return "/".join(record_path) + + def _validate_record_path(self, record_path: tuple[str, ...]) -> None: + if not record_path: + raise ValueError("record_path cannot be empty") + + if len(record_path) > self.max_hierarchy_depth: + raise ValueError( + f"record_path depth {len(record_path)} exceeds maximum {self.max_hierarchy_depth}" + ) + + # Only restrict characters that break the "/".join(record_path) key scheme. + # Unlike DeltaTableDatabase (filesystem-backed), there are no OS-level restrictions here. + unsafe_chars = ["/", "\0"] + for i, component in enumerate(record_path): + if not component or not isinstance(component, str): + raise ValueError( + f"record_path component {i} is invalid: {repr(component)}" + ) + if any(char in component for char in unsafe_chars): + raise ValueError( + f"record_path component {repr(component)} contains invalid characters" + ) + + # ------------------------------------------------------------------ + # Record-ID column helpers + # ------------------------------------------------------------------ + + def _ensure_record_id_column( + self, arrow_data: "pa.Table", record_id: str + ) -> "pa.Table": + if self.RECORD_ID_COLUMN not in arrow_data.column_names: + key_array = pa.array([record_id] * len(arrow_data), type=pa.large_string()) + arrow_data = arrow_data.add_column(0, self.RECORD_ID_COLUMN, key_array) + return arrow_data + + def _remove_record_id_column(self, arrow_data: "pa.Table") -> "pa.Table": + if self.RECORD_ID_COLUMN in arrow_data.column_names: + arrow_data = arrow_data.drop([self.RECORD_ID_COLUMN]) + return arrow_data + + def _handle_record_id_column( + self, arrow_data: "pa.Table", record_id_column: str | None = None + ) -> "pa.Table": + if not record_id_column: + return self._remove_record_id_column(arrow_data) + if self.RECORD_ID_COLUMN in arrow_data.column_names: + new_names = [ + record_id_column if name == self.RECORD_ID_COLUMN else name + for name in arrow_data.schema.names + ] + return arrow_data.rename_columns(new_names) + raise ValueError( + f"Record ID column '{self.RECORD_ID_COLUMN}' not found in the table." + ) + + # ------------------------------------------------------------------ + # Deduplication + # ------------------------------------------------------------------ + + def _deduplicate_within_table(self, table: "pa.Table") -> "pa.Table": + """Keep the last occurrence of each record ID within a single table.""" + if table.num_rows <= 1: + return table + + ROW_INDEX = "__row_index" + indices = pa.array(range(table.num_rows)) + table_with_idx = table.add_column(0, ROW_INDEX, indices) + grouped = table_with_idx.group_by([self.RECORD_ID_COLUMN]).aggregate( + [(ROW_INDEX, "max")] + ) + max_indices = grouped[f"{ROW_INDEX}_max"].to_pylist() + mask = pc.is_in(indices, pa.array(max_indices)) + return table.filter(mask) + + # ------------------------------------------------------------------ + # Internal helpers for duplicate detection + # ------------------------------------------------------------------ + + def _committed_ids(self, record_key: str) -> set[str]: + committed = self._tables.get(record_key) + if committed is None or committed.num_rows == 0: + return set() + existing_ids = committed[self.RECORD_ID_COLUMN].to_pylist() + existing_ids = [str(id) for id in existing_ids if id is not None] + # TODO: evaluate the efficiency of this implementation + return set(existing_ids) + + def _filter_existing_records( + self, record_key: str, table: "pa.Table" + ) -> "pa.Table": + """Filter out records whose IDs are already in pending or committed store.""" + input_ids = set(table[self.RECORD_ID_COLUMN].to_pylist()) + all_existing = input_ids & ( + self._pending_record_ids[record_key] | self._committed_ids(record_key) + ) + if not all_existing: + return table + mask = pc.invert( + pc.is_in(table[self.RECORD_ID_COLUMN], pa.array(list(all_existing))) + ) + return table.filter(mask) + + # ------------------------------------------------------------------ + # Write methods + # ------------------------------------------------------------------ + + def add_record( + self, + record_path: tuple[str, ...], + record_id: str, + record: "pa.Table", + skip_duplicates: bool = False, + flush: bool = False, + ) -> None: + data_with_id = self._ensure_record_id_column(record, record_id) + self.add_records( + record_path=record_path, + records=data_with_id, + record_id_column=self.RECORD_ID_COLUMN, + skip_duplicates=skip_duplicates, + flush=flush, + ) + + def add_records( + self, + record_path: tuple[str, ...], + records: "pa.Table", + record_id_column: str | None = None, + skip_duplicates: bool = False, + flush: bool = False, + ) -> None: + self._validate_record_path(record_path) + + if records.num_rows == 0: + return + + if record_id_column is None: + record_id_column = records.column_names[0] + + if record_id_column not in records.column_names: + raise ValueError( + f"record_id_column '{record_id_column}' not found in table columns: " + f"{records.column_names}" + ) + + # Normalise to internal column name + if record_id_column != self.RECORD_ID_COLUMN: + rename_map = {record_id_column: self.RECORD_ID_COLUMN} + records = records.rename_columns( + [rename_map.get(c, c) for c in records.column_names] + ) + + # Deduplicate within the incoming batch (keep last) + records = self._deduplicate_within_table(records) + + record_key = self._get_record_key(record_path) + + if skip_duplicates: + records = self._filter_existing_records(record_key, records) + if records.num_rows == 0: + return + else: + # Check for conflicts in the pending batch only + input_ids = set(records[self.RECORD_ID_COLUMN].to_pylist()) + pending_conflicts = input_ids & self._pending_record_ids[record_key] + if pending_conflicts: + raise ValueError( + f"Records with IDs {pending_conflicts} already exist in the " + f"pending batch. Use skip_duplicates=True to skip them." + ) + + # Add to pending batch + existing_pending = self._pending_batches.get(record_key) + if existing_pending is None: + self._pending_batches[record_key] = records + else: + self._pending_batches[record_key] = pa.concat_tables( + [existing_pending, records] + ) + pending_ids = cast(list[str], records[self.RECORD_ID_COLUMN].to_pylist()) + self._pending_record_ids[record_key].update(pending_ids) + + if flush: + self.flush() + + # ------------------------------------------------------------------ + # Flush + # ------------------------------------------------------------------ + + def flush(self) -> None: + for record_key in list(self._pending_batches.keys()): + pending = self._pending_batches.pop(record_key) + self._pending_record_ids.pop(record_key, None) + + committed = self._tables.get(record_key) + if committed is None: + self._tables[record_key] = pending + else: + # Insert-if-not-exists: keep committed rows not overwritten by new batch, + # then append the new batch on top. + new_ids = set(pending[self.RECORD_ID_COLUMN].to_pylist()) + mask = pc.invert( + pc.is_in(committed[self.RECORD_ID_COLUMN], pa.array(list(new_ids))) + ) + kept = committed.filter(mask) + self._tables[record_key] = pa.concat_tables([kept, pending]) + + # ------------------------------------------------------------------ + # Read helpers + # ------------------------------------------------------------------ + + def _combined_table(self, record_key: str) -> "pa.Table | None": + """Return pending + committed data for a key, or None if nothing exists.""" + parts = [] + committed = self._tables.get(record_key) + if committed is not None and committed.num_rows > 0: + parts.append(committed) + pending = self._pending_batches.get(record_key) + if pending is not None and pending.num_rows > 0: + parts.append(pending) + if not parts: + return None + return parts[0] if len(parts) == 1 else pa.concat_tables(parts) + + # ------------------------------------------------------------------ + # Read methods + # ------------------------------------------------------------------ + + def get_record_by_id( + self, + record_path: tuple[str, ...], + record_id: str, + record_id_column: str | None = None, + flush: bool = False, + ) -> "pa.Table | None": + if flush: + self.flush() + + record_key = self._get_record_key(record_path) + + # Check pending first + if record_id in self._pending_record_ids[record_key]: + pending = self._pending_batches[record_key] + filtered = pending.filter(pc.field(self.RECORD_ID_COLUMN) == record_id) + if filtered.num_rows > 0: + return self._handle_record_id_column(filtered, record_id_column) + + # Check committed store + committed = self._tables.get(record_key) + if committed is None: + return None + filtered = committed.filter(pc.field(self.RECORD_ID_COLUMN) == record_id) + if filtered.num_rows == 0: + return None + return self._handle_record_id_column(filtered, record_id_column) + + def get_all_records( + self, + record_path: tuple[str, ...], + record_id_column: str | None = None, + ) -> "pa.Table | None": + record_key = self._get_record_key(record_path) + table = self._combined_table(record_key) + if table is None: + return None + return self._handle_record_id_column(table, record_id_column) + + def get_records_by_ids( + self, + record_path: tuple[str, ...], + record_ids: "Collection[str]", + record_id_column: str | None = None, + flush: bool = False, + ) -> "pa.Table | None": + if flush: + self.flush() + + record_ids_list = list(record_ids) + if not record_ids_list: + return None + + record_key = self._get_record_key(record_path) + table = self._combined_table(record_key) + if table is None: + return None + + filtered = table.filter( + pc.is_in(table[self.RECORD_ID_COLUMN], pa.array(record_ids_list)) + ) + if filtered.num_rows == 0: + return None + return self._handle_record_id_column(filtered, record_id_column) + + def to_config(self) -> dict[str, Any]: + """Serialize database configuration to a JSON-compatible dict.""" + return { + "type": "in_memory", + "max_hierarchy_depth": self.max_hierarchy_depth, + } + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "InMemoryArrowDatabase": + """Reconstruct an InMemoryArrowDatabase from a config dict.""" + return cls( + max_hierarchy_depth=config.get("max_hierarchy_depth", 10), + ) + + def get_records_with_column_value( + self, + record_path: tuple[str, ...], + column_values: "Collection[tuple[str, Any]] | Mapping[str, Any]", + record_id_column: str | None = None, + flush: bool = False, + ) -> "pa.Table | None": + if flush: + self.flush() + + record_key = self._get_record_key(record_path) + table = self._combined_table(record_key) + if table is None: + return None + + if isinstance(column_values, Mapping): + pair_list = list(column_values.items()) + else: + pair_list = cast(list[tuple[str, Any]], list(column_values)) + + expressions = [pc.field(c) == v for c, v in pair_list] + combined_expr = expressions[0] + for expr in expressions[1:]: + combined_expr = combined_expr & expr + + filtered = table.filter(combined_expr) + if filtered.num_rows == 0: + return None + return self._handle_record_id_column(filtered, record_id_column) diff --git a/src/orcapod/databases/legacy/delta_table_arrow_data_store.py b/src/orcapod/databases/legacy/delta_table_arrow_data_store.py deleted file mode 100644 index 56bbbfa7..00000000 --- a/src/orcapod/databases/legacy/delta_table_arrow_data_store.py +++ /dev/null @@ -1,864 +0,0 @@ -import pyarrow as pa -import pyarrow.compute as pc -import pyarrow.dataset as ds -import polars as pl -from pathlib import Path -from typing import Any -import logging -from deltalake import DeltaTable, write_deltalake -from deltalake.exceptions import TableNotFoundError -from collections import defaultdict - - -# Module-level logger -logger = logging.getLogger(__name__) - - -class DeltaTableArrowDataStore: - """ - Delta Table-based Arrow data store with flexible hierarchical path support and schema preservation. - - Uses tuple-based source paths for robust parameter handling: - - ("source_name", "source_id") -> source_name/source_id/ - - ("org", "project", "dataset") -> org/project/dataset/ - - ("year", "month", "day", "experiment") -> year/month/day/experiment/ - """ - - def __init__( - self, - base_path: str | Path, - duplicate_entry_behavior: str = "error", - create_base_path: bool = True, - max_hierarchy_depth: int = 10, - batch_size: int = 100, - ): - """ - Initialize the DeltaTableArrowDataStore. - - Args: - base_path: Base directory path where Delta tables will be stored - duplicate_entry_behavior: How to handle duplicate entry_ids: - - 'error': Raise ValueError when entry_id already exists - - 'overwrite': Replace existing entry with new data - create_base_path: Whether to create the base path if it doesn't exist - max_hierarchy_depth: Maximum allowed depth for source paths (safety limit) - batch_size: Number of records to batch before writing to Delta table - auto_flush_interval: Time in seconds to auto-flush pending batches (0 to disable) - """ - # Validate duplicate behavior - if duplicate_entry_behavior not in ["error", "overwrite"]: - raise ValueError("duplicate_entry_behavior must be 'error' or 'overwrite'") - - self.duplicate_entry_behavior = duplicate_entry_behavior - self.base_path = Path(base_path) - self.max_hierarchy_depth = max_hierarchy_depth - self.batch_size = batch_size - - if create_base_path: - self.base_path.mkdir(parents=True, exist_ok=True) - elif not self.base_path.exists(): - raise ValueError( - f"Base path {self.base_path} does not exist and create_base_path=False" - ) - - # Cache for Delta tables to avoid repeated initialization - self._delta_table_cache: dict[str, DeltaTable] = {} - - # Batch management - self._pending_batches: dict[str, dict[str, pa.Table]] = defaultdict(dict) - - logger.info( - f"Initialized DeltaTableArrowDataStore at {self.base_path} " - f"with duplicate_entry_behavior='{duplicate_entry_behavior}', " - f"batch_size={batch_size}, as" - ) - - def flush(self) -> None: - """ - Flush all pending batches immediately. - - This method is called to ensure all pending data is written to the Delta tables. - """ - try: - self.flush_all_batches() - except Exception as e: - logger.error(f"Error during flush: {e}") - - def flush_batch(self, source_path: tuple[str, ...]) -> None: - """ - Flush pending batch for a specific source path. - - Args: - source_path: Tuple of path components - """ - logger.debug("Flushing triggered!!") - source_key = self._get_source_key(source_path) - - if ( - source_key not in self._pending_batches - or not self._pending_batches[source_key] - ): - return - - # Get all pending records - pending_tables = self._pending_batches[source_key] - self._pending_batches[source_key] = {} - - try: - # Combine all tables in the batch - combined_table = pa.concat_tables(pending_tables.values()).combine_chunks() - - table_path = self._get_table_path(source_path) - table_path.mkdir(parents=True, exist_ok=True) - - # Check if table exists - delta_table = self._get_existing_delta_table(source_path) - - if delta_table is None: - # TODO: reconsider mode="overwrite" here - write_deltalake( - table_path, - combined_table, - mode="overwrite", - ) - logger.debug( - f"Created new Delta table for {source_key} with {len(combined_table)} records" - ) - else: - if self.duplicate_entry_behavior == "overwrite": - # Get entry IDs from the batch - entry_ids = combined_table.column("__entry_id").to_pylist() - unique_entry_ids = list(set(entry_ids)) - - # Delete existing records with these IDs - if unique_entry_ids: - entry_ids_str = "', '".join(unique_entry_ids) - delete_predicate = f"__entry_id IN ('{entry_ids_str}')" - try: - delta_table.delete(delete_predicate) - logger.debug( - f"Deleted {len(unique_entry_ids)} existing records from {source_key}" - ) - except Exception as e: - logger.debug( - f"No existing records to delete from {source_key}: {e}" - ) - - # otherwise, only insert if same entry_id does not exist yet - delta_table.merge( - source=combined_table, - predicate="target.__entry_id = source.__entry_id", - source_alias="source", - target_alias="target", - ).when_not_matched_insert_all().execute() - - logger.debug( - f"Appended batch of {len(combined_table)} records to {source_key}" - ) - - # Update cache - self._delta_table_cache[source_key] = DeltaTable(str(table_path)) - - except Exception as e: - logger.error(f"Error flushing batch for {source_key}: {e}") - # Put the tables back in the pending queue - self._pending_batches[source_key] = pending_tables - raise - - def flush_all_batches(self) -> None: - """Flush all pending batches.""" - source_keys = list(self._pending_batches.keys()) - - # TODO: capture and re-raise exceptions at the end - for source_key in source_keys: - source_path = tuple(source_key.split("/")) - try: - self.flush_batch(source_path) - except Exception as e: - logger.error(f"Error flushing batch for {source_key}: {e}") - - def __del__(self): - """Cleanup when object is destroyed.""" - self.flush() - - def _validate_source_path(self, source_path: tuple[str, ...]) -> None: - # TODO: consider removing this as path creation can be tried directly - """ - Validate source path components. - - Args: - source_path: Tuple of path components - - Raises: - ValueError: If path is invalid - """ - if not source_path: - raise ValueError("Source path cannot be empty") - - if len(source_path) > self.max_hierarchy_depth: - raise ValueError( - f"Source path depth {len(source_path)} exceeds maximum {self.max_hierarchy_depth}" - ) - - # Validate path components - for i, component in enumerate(source_path): - if not component or not isinstance(component, str): - raise ValueError( - f"Source path component {i} is invalid: {repr(component)}" - ) - - # Check for filesystem-unsafe characters - unsafe_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|", "\0"] - if any(char in component for char in unsafe_chars): - raise ValueError( - f"Source path component contains invalid characters: {repr(component)}" - ) - - def _get_source_key(self, source_path: tuple[str, ...]) -> str: - """Generate cache key for source storage.""" - return "/".join(source_path) - - def _get_table_path(self, source_path: tuple[str, ...]) -> Path: - """Get the filesystem path for a given source path.""" - path = self.base_path - for subpath in source_path: - path = path / subpath - return path - - def _get_existing_delta_table( - self, source_path: tuple[str, ...] - ) -> DeltaTable | None: - """ - Get or create a Delta table, handling schema initialization properly. - - Args: - source_path: Tuple of path components - - Returns: - DeltaTable instance or None if table doesn't exist - """ - source_key = self._get_source_key(source_path) - table_path = self._get_table_path(source_path) - - # Check cache first - if dt := self._delta_table_cache.get(source_key): - return dt - - try: - # Try to load existing table - delta_table = DeltaTable(str(table_path)) - self._delta_table_cache[source_key] = delta_table - logger.debug(f"Loaded existing Delta table for {source_key}") - return delta_table - except TableNotFoundError: - # Table doesn't exist - return None - except Exception as e: - logger.error(f"Error loading Delta table for {source_key}: {e}") - # Try to clear any corrupted cache and retry once - if source_key in self._delta_table_cache: - del self._delta_table_cache[source_key] - return None - - def _ensure_entry_id_column(self, arrow_data: pa.Table, entry_id: str) -> pa.Table: - """Ensure the table has an __entry_id column.""" - if "__entry_id" not in arrow_data.column_names: - # Add entry_id column at the beginning - key_array = pa.array([entry_id] * len(arrow_data), type=pa.large_string()) - arrow_data = arrow_data.add_column(0, "__entry_id", key_array) - return arrow_data - - def _remove_entry_id_column(self, arrow_data: pa.Table) -> pa.Table: - """Remove the __entry_id column if it exists.""" - if "__entry_id" in arrow_data.column_names: - column_names = arrow_data.column_names - indices_to_keep = [ - i for i, name in enumerate(column_names) if name != "__entry_id" - ] - arrow_data = arrow_data.select(indices_to_keep) - return arrow_data - - def _handle_entry_id_column( - self, arrow_data: pa.Table, add_entry_id_column: bool | str = False - ) -> pa.Table: - """ - Handle entry_id column based on add_entry_id_column parameter. - - Args: - arrow_data: Arrow table with __entry_id column - add_entry_id_column: Control entry ID column inclusion: - - False: Remove __entry_id column - - True: Keep __entry_id column as is - - str: Rename __entry_id column to custom name - """ - if add_entry_id_column is False: - # Remove the __entry_id column - return self._remove_entry_id_column(arrow_data) - elif isinstance(add_entry_id_column, str): - # Rename __entry_id to custom name - if "__entry_id" in arrow_data.column_names: - schema = arrow_data.schema - new_names = [ - add_entry_id_column if name == "__entry_id" else name - for name in schema.names - ] - return arrow_data.rename_columns(new_names) - # If add_entry_id_column is True, keep __entry_id as is - return arrow_data - - def _create_entry_id_filter(self, entry_id: str) -> list: - """ - Create a proper filter expression for Delta Lake. - - Args: - entry_id: The entry ID to filter by - - Returns: - List containing the filter expression for Delta Lake - """ - return [("__entry_id", "=", entry_id)] - - def _create_entry_ids_filter(self, entry_ids: list[str]) -> list: - """ - Create a proper filter expression for multiple entry IDs. - - Args: - entry_ids: List of entry IDs to filter by - - Returns: - List containing the filter expression for Delta Lake - """ - return [("__entry_id", "in", entry_ids)] - - def _read_table_with_filter( - self, - delta_table: DeltaTable, - filters: list | None = None, - ) -> pa.Table: - """ - Read table using to_pyarrow_dataset with original schema preservation. - - Args: - delta_table: The Delta table to read from - filters: Optional filters to apply - - Returns: - Arrow table with preserved schema - """ - # Use to_pyarrow_dataset with as_large_types for Polars compatible arrow table loading - dataset: ds.Dataset = delta_table.to_pyarrow_dataset(as_large_types=True) - if filters: - # Apply filters at dataset level for better performance - import pyarrow.compute as pc - - filter_expr = None - for filt in filters: - if len(filt) == 3: - col, op, val = filt - if op == "=": - expr = pc.equal(pc.field(col), pa.scalar(val)) # type: ignore - elif op == "in": - expr = pc.is_in(pc.field(col), pa.array(val)) # type: ignore - else: - logger.warning( - f"Unsupported filter operation: {op}. Falling back to table-level filter application which may be less efficient." - ) - # Fallback to table-level filtering - return dataset.to_table()(filters=filters) - - if filter_expr is None: - filter_expr = expr - else: - filter_expr = pc.and_(filter_expr, expr) # type: ignore - - if filter_expr is not None: - return dataset.to_table(filter=filter_expr) - - return dataset.to_table() - - def add_record( - self, - source_path: tuple[str, ...], - entry_id: str, - arrow_data: pa.Table, - force_flush: bool = False, - ) -> pa.Table: - """ - Add a record to the Delta table (batched). - - Args: - source_path: Tuple of path components (e.g., ("org", "project", "dataset")) - entry_id: Unique identifier for this record - arrow_data: The Arrow table data to store - ignore_duplicate: If True, ignore duplicate entry error - force_flush: If True, immediately flush this record to disk - - Returns: - The Arrow table data that was stored - - Raises: - ValueError: If entry_id already exists and duplicate_entry_behavior is 'error' - """ - self._validate_source_path(source_path) - source_key = self._get_source_key(source_path) - - # Check for existing entry - if self.duplicate_entry_behavior == "error": - # Only check existing table, not pending batch for performance - pending_table = self._pending_batches[source_key].get(entry_id, None) - if pending_table is not None: - raise ValueError( - f"Entry '{entry_id}' already exists in pending batch for {source_key}. " - f"Use duplicate_entry_behavior='overwrite' to allow updates." - ) - existing_record = self.get_record(source_path, entry_id, flush=False) - if existing_record is not None: - raise ValueError( - f"Entry '{entry_id}' already exists in {'/'.join(source_path)}. " - f"Use duplicate_entry_behavior='overwrite' to allow updates." - ) - - # Add entry_id column to the data - data_with_entry_id = self._ensure_entry_id_column(arrow_data, entry_id) - - if force_flush: - # Write immediately - table_path = self._get_table_path(source_path) - table_path.mkdir(parents=True, exist_ok=True) - - delta_table = self._get_existing_delta_table(source_path) - - if delta_table is None: - # Create new table - save original schema first - write_deltalake(str(table_path), data_with_entry_id, mode="overwrite") - logger.debug(f"Created new Delta table for {source_key}") - else: - if self.duplicate_entry_behavior == "overwrite": - try: - delta_table.delete( - f"__entry_id = '{entry_id.replace(chr(39), chr(39) + chr(39))}'" - ) - logger.debug( - f"Deleted existing record {entry_id} from {source_key}" - ) - except Exception as e: - logger.debug( - f"No existing record to delete for {entry_id}: {e}" - ) - - write_deltalake( - table_path, - data_with_entry_id, - mode="append", - schema_mode="merge", - ) - - # Update cache - self._delta_table_cache[source_key] = DeltaTable(str(table_path)) - else: - # Add to the batch for later flushing - self._pending_batches[source_key][entry_id] = data_with_entry_id - batch_size = len(self._pending_batches[source_key]) - - # Check if we need to flush - if batch_size >= self.batch_size: - self.flush_batch(source_path) - - logger.debug(f"Added record {entry_id} to {source_key}") - return arrow_data - - def get_pending_batch_info(self) -> dict[str, int]: - """ - Get information about pending batches. - - Returns: - Dictionary mapping source keys to number of pending records - """ - return { - source_key: len(tables) - for source_key, tables in self._pending_batches.items() - if tables - } - - def get_record( - self, source_path: tuple[str, ...], entry_id: str, flush: bool = False - ) -> pa.Table | None: - """ - Get a specific record by entry_id with schema preservation. - - Args: - source_path: Tuple of path components - entry_id: Unique identifier for the record - - Returns: - Arrow table for the record or None if not found - """ - if flush: - self.flush_batch(source_path) - self._validate_source_path(source_path) - - # check if entry_id is found in pending batches - source_key = self._get_source_key(source_path) - if entry_id in self._pending_batches[source_key]: - # Return the pending record directly - return self._pending_batches[source_key][entry_id] - - delta_table = self._get_existing_delta_table(source_path) - if delta_table is None: - return None - - try: - # Use schema-preserving read - filter_expr = self._create_entry_id_filter(entry_id) - result = self._read_table_with_filter(delta_table, filters=filter_expr) - - if len(result) == 0: - return None - - # Remove the __entry_id column before returning - return self._remove_entry_id_column(result) - - except Exception as e: - logger.error( - f"Error getting record {entry_id} from {'/'.join(source_path)}: {e}" - ) - raise e - - def get_all_records( - self, - source_path: tuple[str, ...], - add_entry_id_column: bool | str = False, - retrieve_pending: bool = True, - flush: bool = False, - ) -> pa.Table | None: - """ - Retrieve all records for a given source path as a single table with schema preservation. - - Args: - source_path: Tuple of path components - add_entry_id_column: Control entry ID column inclusion: - - False: Don't include entry ID column (default) - - True: Include entry ID column as "__entry_id" - - str: Include entry ID column with custom name - - Returns: - Arrow table containing all records with original schema, or None if no records found - """ - if flush: - self.flush_batch(source_path) - self._validate_source_path(source_path) - - collected_arrays = [] - if retrieve_pending: - # Check if there are pending records in the batch - for entry_id, arrow_table in self._pending_batches[ - self._get_source_key(source_path) - ].items(): - collected_arrays.append( - self._ensure_entry_id_column(arrow_table, entry_id) - ) - - delta_table = self._get_existing_delta_table(source_path) - if delta_table is not None: - try: - # Use filter-based read - result = self._read_table_with_filter(delta_table) - - if len(result) != 0: - collected_arrays.append(result) - - except Exception as e: - logger.error( - f"Error getting all records from {'/'.join(source_path)}: {e}" - ) - if collected_arrays: - total_table = pa.Table.concatenate(collected_arrays) - - # Handle entry_id column based on parameter - return self._handle_entry_id_column(total_table, add_entry_id_column) - - return None - - def get_all_records_as_polars( - self, source_path: tuple[str, ...], flush: bool = True - ) -> pl.LazyFrame | None: - """ - Retrieve all records for a given source path as a single Polars LazyFrame. - - Args: - source_path: Tuple of path components - - Returns: - Polars LazyFrame containing all records, or None if no records found - """ - all_records = self.get_all_records(source_path, flush=flush) - if all_records is None: - return None - return pl.LazyFrame(all_records) - - def get_records_by_ids( - self, - source_path: tuple[str, ...], - entry_ids: list[str] | pl.Series | pa.Array, - add_entry_id_column: bool | str = False, - preserve_input_order: bool = False, - flush: bool = False, - ) -> pa.Table | None: - """ - Retrieve records by entry IDs as a single table with schema preservation. - - Args: - source_path: Tuple of path components - entry_ids: Entry IDs to retrieve - add_entry_id_column: Control entry ID column inclusion - preserve_input_order: If True, return results in input order with nulls for missing - - Returns: - Arrow table containing all found records with original schema, or None if no records found - """ - if flush: - self.flush_batch(source_path) - - self._validate_source_path(source_path) - - # Convert input to list of strings for consistency - if isinstance(entry_ids, list): - if not entry_ids: - return None - entry_ids_list = entry_ids - elif isinstance(entry_ids, pl.Series): - if len(entry_ids) == 0: - return None - entry_ids_list = entry_ids.to_list() - elif isinstance(entry_ids, pa.Array): - if len(entry_ids) == 0: - return None - entry_ids_list = entry_ids.to_pylist() - else: - raise TypeError( - f"entry_ids must be list[str], pl.Series, or pa.Array, got {type(entry_ids)}" - ) - - delta_table = self._get_existing_delta_table(source_path) - if delta_table is None: - return None - - try: - # Use schema-preserving read with filters - filter_expr = self._create_entry_ids_filter(entry_ids_list) - result = self._read_table_with_filter(delta_table, filters=filter_expr) - - if len(result) == 0: - return None - - if preserve_input_order: - # Need to reorder results and add nulls for missing entries - import pandas as pd - - df = result.to_pandas() - df = df.set_index("__entry_id") - - # Create a DataFrame with the desired order, filling missing with NaN - ordered_df = df.reindex(entry_ids_list) - - # Convert back to Arrow - result = pa.Table.from_pandas(ordered_df.reset_index()) - - # Handle entry_id column based on parameter - return self._handle_entry_id_column(result, add_entry_id_column) - - except Exception as e: - logger.error( - f"Error getting records by IDs from {'/'.join(source_path)}: {e}" - ) - return None - - def get_records_by_ids_as_polars( - self, - source_path: tuple[str, ...], - entry_ids: list[str] | pl.Series | pa.Array, - add_entry_id_column: bool | str = False, - preserve_input_order: bool = False, - flush: bool = False, - ) -> pl.LazyFrame | None: - """ - Retrieve records by entry IDs as a single Polars LazyFrame. - - Args: - source_path: Tuple of path components - entry_ids: Entry IDs to retrieve - add_entry_id_column: Control entry ID column inclusion - preserve_input_order: If True, return results in input order with nulls for missing - - Returns: - Polars LazyFrame containing all found records, or None if no records found - """ - arrow_result = self.get_records_by_ids( - source_path, - entry_ids, - add_entry_id_column, - preserve_input_order, - flush=flush, - ) - - if arrow_result is None: - return None - - # Convert to Polars LazyFrame - return pl.LazyFrame(arrow_result) - - # Additional utility methods - def list_sources(self) -> list[tuple[str, ...]]: - """ - List all available source paths. - - Returns: - List of source path tuples - """ - sources = [] - - def _scan_directory(current_path: Path, path_components: tuple[str, ...]): - """Recursively scan for Delta tables.""" - for item in current_path.iterdir(): - if not item.is_dir(): - continue - - new_path_components = path_components + (item.name,) - - # Check if this directory contains a Delta table - try: - DeltaTable(str(item)) - sources.append(new_path_components) - except TableNotFoundError: - # Not a Delta table, continue scanning subdirectories - if len(new_path_components) < self.max_hierarchy_depth: - _scan_directory(item, new_path_components) - - _scan_directory(self.base_path, ()) - return sources - - def delete_source(self, source_path: tuple[str, ...]) -> bool: - """ - Delete an entire source (all records for a source path). - - Args: - source_path: Tuple of path components - - Returns: - True if source was deleted, False if it didn't exist - """ - self._validate_source_path(source_path) - - # Flush any pending batches first - self.flush_batch(source_path) - - table_path = self._get_table_path(source_path) - source_key = self._get_source_key(source_path) - - if not table_path.exists(): - return False - - try: - # Remove from caches - if source_key in self._delta_table_cache: - del self._delta_table_cache[source_key] - if source_key in self._schema_cache: - del self._schema_cache[source_key] - - # Remove directory - import shutil - - shutil.rmtree(table_path) - - logger.info(f"Deleted source {source_key}") - return True - - except Exception as e: - logger.error(f"Error deleting source {source_key}: {e}") - return False - - def delete_record(self, source_path: tuple[str, ...], entry_id: str) -> bool: - """ - Delete a specific record. - - Args: - source_path: Tuple of path components - entry_id: ID of the record to delete - - Returns: - True if record was deleted, False if it didn't exist - """ - self._validate_source_path(source_path) - - # Flush any pending batches first - self._flush_batch(source_path) - - delta_table = self._get_existing_delta_table(source_path) - if delta_table is None: - return False - - try: - # Check if record exists using proper filter - filter_expr = self._create_entry_id_filter(entry_id) - existing = self._read_table_with_filter(delta_table, filters=filter_expr) - if len(existing) == 0: - return False - - # Delete the record using SQL-style predicate (this is correct for delete operations) - delta_table.delete( - f"__entry_id = '{entry_id.replace(chr(39), chr(39) + chr(39))}'" - ) - - # Update cache - source_key = self._get_source_key(source_path) - self._delta_table_cache[source_key] = delta_table - - logger.debug(f"Deleted record {entry_id} from {'/'.join(source_path)}") - return True - - except Exception as e: - logger.error( - f"Error deleting record {entry_id} from {'/'.join(source_path)}: {e}" - ) - return False - - def get_table_info(self, source_path: tuple[str, ...]) -> dict[str, Any] | None: - """ - Get metadata information about a Delta table. - - Args: - source_path: Tuple of path components - - Returns: - Dictionary with table metadata, or None if table doesn't exist - """ - self._validate_source_path(source_path) - - delta_table = self._get_existing_delta_table(source_path) - if delta_table is None: - return None - - try: - # Get basic info - schema = delta_table.schema() - history = delta_table.history() - source_key = self._get_source_key(source_path) - - # Add pending batch info - pending_info = self.get_pending_batch_info() - pending_count = pending_info.get(source_key, 0) - - return { - "path": str(self._get_table_path(source_path)), - "source_path": source_path, - "schema": schema, - "version": delta_table.version(), - "num_files": len(delta_table.files()), - "history_length": len(history), - "latest_commit": history[0] if history else None, - "pending_records": pending_count, - } - - except Exception as e: - logger.error(f"Error getting table info for {'/'.join(source_path)}: {e}") - return None diff --git a/src/orcapod/databases/legacy/dict_data_stores.py b/src/orcapod/databases/legacy/dict_data_stores.py deleted file mode 100644 index 63d79746..00000000 --- a/src/orcapod/databases/legacy/dict_data_stores.py +++ /dev/null @@ -1,229 +0,0 @@ -import json -import logging -import shutil -from os import PathLike -from pathlib import Path - -from orcapod.hashing.legacy_core import hash_packet -from orcapod.hashing.types import LegacyPacketHasher -from orcapod.hashing.defaults import get_default_composite_file_hasher -from orcapod.databases.legacy.types import DataStore -from orcapod.types import Packet, PacketLike - -logger = logging.getLogger(__name__) - - -class NoOpDataStore(DataStore): - """ - An empty data store that does not store anything. - This is useful for testing purposes or when no memoization is needed. - """ - - def __init__(self): - """ - Initialize the NoOpDataStore. - This does not require any parameters. - """ - pass - - def memoize( - self, - function_name: str, - function_hash: str, - packet: PacketLike, - output_packet: PacketLike, - overwrite: bool = False, - ) -> PacketLike: - return output_packet - - def retrieve_memoized( - self, function_name: str, function_hash: str, packet: PacketLike - ) -> PacketLike | None: - return None - - -class DirDataStore(DataStore): - def __init__( - self, - store_dir: str | PathLike = "./pod_data", - packet_hasher: LegacyPacketHasher | None = None, - copy_files=True, - preserve_filename=True, - overwrite=False, - supplement_source=False, - legacy_mode=False, - legacy_algorithm="sha256", - ) -> None: - self.store_dir = Path(store_dir) - # Create the data directory if it doesn't exist - self.store_dir.mkdir(parents=True, exist_ok=True) - self.copy_files = copy_files - self.preserve_filename = preserve_filename - self.overwrite = overwrite - self.supplement_source = supplement_source - if packet_hasher is None and not legacy_mode: - packet_hasher = get_default_composite_file_hasher(with_cache=True) - self.packet_hasher = packet_hasher - self.legacy_mode = legacy_mode - self.legacy_algorithm = legacy_algorithm - - def memoize( - self, - function_name: str, - function_hash: str, - packet: PacketLike, - output_packet: PacketLike, - ) -> PacketLike: - if self.legacy_mode: - packet_hash = hash_packet(packet, algorithm=self.legacy_algorithm) - else: - packet_hash = self.packet_hasher.hash_packet(packet) # type: ignore[no-untyped-call] - output_dir = self.store_dir / function_name / function_hash / str(packet_hash) - info_path = output_dir / "_info.json" - source_path = output_dir / "_source.json" - - if info_path.exists() and not self.overwrite: - raise ValueError( - f"Entry for packet {packet} already exists, and will not be overwritten" - ) - else: - output_dir.mkdir(parents=True, exist_ok=True) - if self.copy_files: - new_output_packet = {} - # copy the files to the output directory - for key, value in output_packet.items(): - if not isinstance(value, (str, PathLike)): - raise NotImplementedError( - f"Pathset that is not a simple path is not yet supported: {value} was given" - ) - if self.preserve_filename: - relative_output_path = Path(value).name - else: - # preserve the suffix of the original if present - relative_output_path = key + Path(value).suffix - - output_path = output_dir / relative_output_path - if output_path.exists() and not self.overwrite: - logger.warning( - f"File {relative_output_path} already exists in {output_path}" - ) - if not self.overwrite: - raise ValueError( - f"File {relative_output_path} already exists in {output_path}" - ) - else: - logger.warning( - f"Removing file {relative_output_path} in {output_path}" - ) - shutil.rmtree(output_path) - logger.info(f"Copying file {value} to {output_path}") - shutil.copy(value, output_path) - # register the key with the new path - new_output_packet[key] = str(relative_output_path) - output_packet = new_output_packet - # store the output packet in a json file - with open(info_path, "w") as f: - json.dump(output_packet, f) - # store the source packet in a json file - with open(source_path, "w") as f: - json.dump(packet, f) - logger.info(f"Stored output for packet {packet} at {output_dir}") - - # retrieve back the memoized packet and return - # TODO: consider if we want to return the original packet or the memoized one - retrieved_output_packet = self.retrieve_memoized( - function_name, function_hash, packet - ) - if retrieved_output_packet is None: - raise ValueError(f"Memoized packet {packet} not found after storing it") - return retrieved_output_packet - - def retrieve_memoized( - self, function_name: str, function_hash: str, packet: PacketLike - ) -> Packet | None: - if self.legacy_mode: - packet_hash = hash_packet(packet, algorithm=self.legacy_algorithm) - else: - assert self.packet_hasher is not None, ( - "Packer hasher should be configured if not in legacy mode" - ) - packet_hash = self.packet_hasher.hash_packet(packet) - output_dir = self.store_dir / function_name / function_hash / str(packet_hash) - info_path = output_dir / "_info.json" - source_path = output_dir / "_source.json" - - if info_path.exists(): - # TODO: perform better error handling - try: - with open(info_path, "r") as f: - output_packet = json.load(f) - # update the paths to be absolute - for key, value in output_packet.items(): - # Note: if value is an absolute path, this will not change it as - # Pathlib is smart enough to preserve the last occurring absolute path (if present) - output_packet[key] = str(output_dir / value) - logger.info(f"Retrieved output for packet {packet} from {info_path}") - # check if source json exists -- if not, supplement it - if self.supplement_source and not source_path.exists(): - with open(source_path, "w") as f: - json.dump(packet, f) - logger.info( - f"Supplemented source for packet {packet} at {source_path}" - ) - except (IOError, json.JSONDecodeError) as e: - logger.error( - f"Error loading memoized output for packet {packet} from {info_path}: {e}" - ) - return None - return output_packet - else: - logger.info(f"No memoized output found for packet {packet}") - return None - - def clear_store(self, function_name: str) -> None: - # delete the folder self.data_dir and its content - shutil.rmtree(self.store_dir / function_name) - - def clear_all_stores(self, interactive=True, function_name="", force=False) -> None: - """ - Clear all stores in the data directory. - This is a dangerous operation -- please double- and triple-check before proceeding! - - Args: - interactive (bool): If True, prompt the user for confirmation before deleting. - If False, it will delete only if `force=True`. The user will be prompted - to type in the full name of the storage (as shown in the prompt) - to confirm deletion. - function_name (str): The name of the function to delete. If not using interactive mode, - this must be set to the store_dir path in order to proceed with the deletion. - force (bool): If True, delete the store without prompting the user for confirmation. - If False and interactive is False, the `function_name` must match the store_dir - for the deletion to proceed. - """ - # delete the folder self.data_dir and its content - # This is a dangerous operation -- double prompt the user for confirmation! - if not force and interactive: - confirm = input( - f"Are you sure you want to delete all stores in {self.store_dir}? (y/n): " - ) - if confirm.lower() != "y": - logger.info("Aborting deletion of all stores") - return - function_name = input( - f"Type in the function name {self.store_dir} to confirm the deletion: " - ) - if function_name != str(self.store_dir): - logger.info("Aborting deletion of all stores") - return - - if not force and function_name != str(self.store_dir): - logger.info(f"Aborting deletion of all stores in {self.store_dir}") - return - - logger.info(f"Deleting all stores in {self.store_dir}") - try: - shutil.rmtree(self.store_dir) - except: - logger.error(f"Error during the deletion of all stores in {self.store_dir}") - raise - logger.info(f"Deleted all stores in {self.store_dir}") diff --git a/src/orcapod/databases/legacy/dict_transfer_data_store.py b/src/orcapod/databases/legacy/dict_transfer_data_store.py deleted file mode 100644 index 99709e85..00000000 --- a/src/orcapod/databases/legacy/dict_transfer_data_store.py +++ /dev/null @@ -1,70 +0,0 @@ -# Implements transfer data store that lets you transfer memoized packets between data stores. - -from orcapod.databases.legacy.types import DataStore -from orcapod.types import PacketLike - - -class TransferDataStore(DataStore): - """ - A data store that allows transferring recorded data between different data stores. - This is useful for moving data between different storage backends. - """ - - def __init__(self, source_store: DataStore, target_store: DataStore) -> None: - self.source_store = source_store - self.target_store = target_store - - def transfer( - self, function_name: str, content_hash: str, packet: PacketLike - ) -> PacketLike: - """ - Transfer a memoized packet from the source store to the target store. - """ - retrieved_packet = self.source_store.retrieve_memoized( - function_name, content_hash, packet - ) - if retrieved_packet is None: - raise ValueError("Packet not found in source store.") - - return self.target_store.memoize( - function_name, content_hash, packet, retrieved_packet - ) - - def retrieve_memoized( - self, function_name: str, function_hash: str, packet: PacketLike - ) -> PacketLike | None: - """ - Retrieve a memoized packet from the target store. - """ - # Try retrieving from the target store first - memoized_packet = self.target_store.retrieve_memoized( - function_name, function_hash, packet - ) - if memoized_packet is not None: - return memoized_packet - - # If not found, try retrieving from the source store - memoized_packet = self.source_store.retrieve_memoized( - function_name, function_hash, packet - ) - if memoized_packet is not None: - # Memoize the packet in the target store as part of the transfer - self.target_store.memoize( - function_name, function_hash, packet, memoized_packet - ) - - return memoized_packet - - def memoize( - self, - function_name: str, - function_hash: str, - packet: PacketLike, - output_packet: PacketLike, - ) -> PacketLike: - """ - Memoize a packet in the target store. - """ - return self.target_store.memoize( - function_name, function_hash, packet, output_packet - ) diff --git a/src/orcapod/databases/legacy/legacy_arrow_data_stores.py b/src/orcapod/databases/legacy/legacy_arrow_data_stores.py deleted file mode 100644 index acac1984..00000000 --- a/src/orcapod/databases/legacy/legacy_arrow_data_stores.py +++ /dev/null @@ -1,2078 +0,0 @@ -import pyarrow as pa -import pyarrow.parquet as pq -import polars as pl -import threading -from pathlib import Path -from typing import Any, cast -from dataclasses import dataclass -from datetime import datetime, timedelta -import logging -from orcapod.databases.types import DuplicateError -from pathlib import Path - -# Module-level logger -logger = logging.getLogger(__name__) - - -class MockArrowDataStore: - """ - Mock Arrow data store for testing purposes. - This class simulates the behavior of ArrowDataStore without actually saving anything. - It is useful for unit tests where you want to avoid any I/O operations or when you need - to test the behavior of your code without relying on external systems. If you need some - persistence of saved data, consider using SimpleParquetDataStore without providing a - file path instead. - """ - - def __init__(self): - logger.info("Initialized MockArrowDataStore") - - def add_record( - self, - source_pathh: tuple[str, ...], - source_id: str, - entry_id: str, - arrow_data: pa.Table, - ) -> pa.Table: - """Add a record to the mock store.""" - return arrow_data - - def get_record( - self, source_path: tuple[str, ...], source_id: str, entry_id: str - ) -> pa.Table | None: - """Get a specific record.""" - return None - - def get_all_records( - self, source_path: tuple[str, ...], source_id: str - ) -> pa.Table | None: - """Retrieve all records for a given source as a single table.""" - return None - - def get_all_records_as_polars( - self, source_path: tuple[str, ...], source_id: str - ) -> pl.LazyFrame | None: - """Retrieve all records for a given source as a single Polars LazyFrame.""" - return None - - def get_records_by_ids( - self, - source_path: tuple[str, ...], - source_id: str, - entry_ids: list[str] | pl.Series | pa.Array, - add_entry_id_column: bool | str = False, - preserve_input_order: bool = False, - ) -> pa.Table | None: - """ - Retrieve records by entry IDs as a single table. - - Args: - source_name: Name of the data source - source_id: ID of the specific dataset within the source - entry_ids: Entry IDs to retrieve. Can be: - - list[str]: List of entry ID strings - - pl.Series: Polars Series containing entry IDs - - pa.Array: PyArrow Array containing entry IDs - add_entry_id_column: Control entry ID column inclusion: - - False: Don't include entry ID column (default) - - True: Include entry ID column as "__entry_id" - - str: Include entry ID column with custom name - preserve_input_order: If True, return results in the same order as input entry_ids, - with null rows for missing entries. If False, return in storage order. - - Returns: - Arrow table containing all found records, or None if no records found - """ - return None - - def get_records_by_ids_as_polars( - self, - source_path: tuple[str, ...], - source_id: str, - entry_ids: list[str] | pl.Series | pa.Array, - add_entry_id_column: bool | str = False, - preserve_input_order: bool = False, - ) -> pl.LazyFrame | None: - return None - - -class SimpleParquetDataStore: - """ - Simple Parquet-based Arrow data store, primarily to be used for development purposes. - If no file path is provided, it will not save anything to disk. Instead, all data will be stored in memory. - If a file path is provided, it will save data to a single Parquet files in a directory structure reflecting - the provided source_path. To speed up the process, data will be stored in memory and only saved to disk - when the `flush` method is called. If used as part of pipeline, flush is automatically called - at the end of pipeline execution. - Note that this store provides only very basic functionality and is not suitable for production use. - For each distinct source_path, only a single parquet file is created to store all data entries. - Appending is not efficient as it requires reading the entire file into the memory, appending new data, - and then writing the entire file back to disk. This is not suitable for large datasets or frequent updates. - However, for development/testing purposes, this data store provides a simple way to store and retrieve - data without the overhead of a full database or file system and provides very high performance. - """ - - def __init__( - self, path: str | Path | None = None, duplicate_entry_behavior: str = "error" - ): - """ - Initialize the InMemoryArrowDataStore. - - Args: - duplicate_entry_behavior: How to handle duplicate entry_ids: - - 'error': Raise ValueError when entry_id already exists - - 'overwrite': Replace existing entry with new data - """ - # Validate duplicate behavior - if duplicate_entry_behavior not in ["error", "overwrite"]: - raise ValueError("duplicate_entry_behavior must be 'error' or 'overwrite'") - self.duplicate_entry_behavior = duplicate_entry_behavior - - # Store Arrow tables: {source_key: {entry_id: arrow_table}} - self._in_memory_store: dict[str, dict[str, pa.Table]] = {} - logger.info( - f"Initialized InMemoryArrowDataStore with duplicate_entry_behavior='{duplicate_entry_behavior}'" - ) - self.base_path = Path(path) if path else None - if self.base_path: - try: - self.base_path.mkdir(parents=True, exist_ok=True) - except Exception as e: - logger.error(f"Error creating base path {self.base_path}: {e}") - - def _get_source_key(self, source_path: tuple[str, ...]) -> str: - """Generate key for source storage.""" - return "/".join(source_path) - - def add_record( - self, - source_path: tuple[str, ...], - entry_id: str, - arrow_data: pa.Table, - ignore_duplicate: bool = False, - ) -> pa.Table: - """ - Add a record to the in-memory store. - - Args: - source_name: Name of the data source - source_id: ID of the specific dataset within the source - entry_id: Unique identifier for this record - arrow_data: The Arrow table data to store - - Returns: - arrow_data equivalent to having loaded the corresponding entry that was just saved - - Raises: - ValueError: If entry_id already exists and duplicate_entry_behavior is 'error' - """ - source_key = self._get_source_key(source_path) - - # Initialize source if it doesn't exist - if source_key not in self._in_memory_store: - self._in_memory_store[source_key] = {} - - local_data = self._in_memory_store[source_key] - - # Check for duplicate entry - if entry_id in local_data: - if not ignore_duplicate and self.duplicate_entry_behavior == "error": - raise ValueError( - f"Entry '{entry_id}' already exists in {source_key}. " - f"Use duplicate_entry_behavior='overwrite' to allow updates." - ) - - # Store the record - local_data[entry_id] = arrow_data - - action = "Updated" if entry_id in local_data else "Added" - logger.debug(f"{action} record {entry_id} in {source_key}") - return arrow_data - - def load_existing_record(self, source_path: tuple[str, ...]): - source_key = self._get_source_key(source_path) - if self.base_path is not None and source_key not in self._in_memory_store: - self.load_from_parquet(self.base_path, source_path) - - def get_record( - self, source_path: tuple[str, ...], entry_id: str - ) -> pa.Table | None: - """Get a specific record.""" - self.load_existing_record(source_path) - source_key = self._get_source_key(source_path) - local_data = self._in_memory_store.get(source_key, {}) - return local_data.get(entry_id) - - def get_all_records( - self, source_path: tuple[str, ...], add_entry_id_column: bool | str = False - ) -> pa.Table | None: - """Retrieve all records for a given source as a single table.""" - self.load_existing_record(source_path) - source_key = self._get_source_key(source_path) - local_data = self._in_memory_store.get(source_key, {}) - - if not local_data: - return None - - tables_with_keys = [] - for key, table in local_data.items(): - # Add entry_id column to each table - key_array = pa.array([key] * len(table), type=pa.large_string()) - table_with_key = table.add_column(0, "__entry_id", key_array) - tables_with_keys.append(table_with_key) - - # Concatenate all tables - if tables_with_keys: - combined_table = pa.concat_tables(tables_with_keys) - if not add_entry_id_column: - combined_table = combined_table.drop(columns=["__entry_id"]) - return combined_table - return None - - def get_all_records_as_polars( - self, source_path: tuple[str, ...] - ) -> pl.LazyFrame | None: - """Retrieve all records for a given source as a single Polars LazyFrame.""" - all_records = self.get_all_records(source_path) - if all_records is None: - return None - return pl.LazyFrame(all_records) - - def get_records_by_ids( - self, - source_path: tuple[str, ...], - entry_ids: list[str] | pl.Series | pa.Array, - add_entry_id_column: bool | str = False, - preserve_input_order: bool = False, - ) -> pa.Table | None: - """ - Retrieve records by entry IDs as a single table. - - Args: - source_name: Name of the data source - source_id: ID of the specific dataset within the source - entry_ids: Entry IDs to retrieve. Can be: - - list[str]: List of entry ID strings - - pl.Series: Polars Series containing entry IDs - - pa.Array: PyArrow Array containing entry IDs - add_entry_id_column: Control entry ID column inclusion: - - False: Don't include entry ID column (default) - - True: Include entry ID column as "__entry_id" - - str: Include entry ID column with custom name - preserve_input_order: If True, return results in the same order as input entry_ids, - with null rows for missing entries. If False, return in storage order. - - Returns: - Arrow table containing all found records, or None if no records found - """ - # Convert input to list of strings for consistency - if isinstance(entry_ids, list): - if not entry_ids: - return None - entry_ids_list = entry_ids - elif isinstance(entry_ids, pl.Series): - if len(entry_ids) == 0: - return None - entry_ids_list = entry_ids.to_list() - elif isinstance(entry_ids, pa.Array): - if len(entry_ids) == 0: - return None - entry_ids_list = entry_ids.to_pylist() - else: - raise TypeError( - f"entry_ids must be list[str], pl.Series, or pa.Array, got {type(entry_ids)}" - ) - - self.load_existing_record(source_path) - - source_key = self._get_source_key(source_path) - local_data = self._in_memory_store.get(source_key, {}) - - if not local_data: - return None - - # Collect matching tables - found_tables = [] - found_entry_ids = [] - - if preserve_input_order: - # Preserve input order, include nulls for missing entries - first_table_schema = None - - for entry_id in entry_ids_list: - if entry_id in local_data: - table = local_data[entry_id] - # Add entry_id column - key_array = pa.array([entry_id] * len(table), type=pa.string()) - table_with_key = table.add_column(0, "__entry_id", key_array) - found_tables.append(table_with_key) - found_entry_ids.append(entry_id) - - # Store schema for creating null rows - if first_table_schema is None: - first_table_schema = table_with_key.schema - else: - # Create a null row with the same schema as other tables - if first_table_schema is not None: - # Create null row - null_data = {} - for field in first_table_schema: - if field.name == "__entry_id": - null_data[field.name] = pa.array( - [entry_id], type=field.type - ) - else: - # Create null array with proper type - null_array = pa.array([None], type=field.type) - null_data[field.name] = null_array - - null_table = pa.table(null_data, schema=first_table_schema) - found_tables.append(null_table) - found_entry_ids.append(entry_id) - else: - # Storage order (faster) - only include existing entries - for entry_id in entry_ids_list: - if entry_id in local_data: - table = local_data[entry_id] - # Add entry_id column - key_array = pa.array([entry_id] * len(table), type=pa.string()) - table_with_key = table.add_column(0, "__entry_id", key_array) - found_tables.append(table_with_key) - found_entry_ids.append(entry_id) - - if not found_tables: - return None - - # Concatenate all found tables - if len(found_tables) == 1: - combined_table = found_tables[0] - else: - combined_table = pa.concat_tables(found_tables) - - # Handle entry_id column based on add_entry_id_column parameter - if add_entry_id_column is False: - # Remove the __entry_id column - column_names = combined_table.column_names - if "__entry_id" in column_names: - indices_to_keep = [ - i for i, name in enumerate(column_names) if name != "__entry_id" - ] - combined_table = combined_table.select(indices_to_keep) - elif isinstance(add_entry_id_column, str): - # Rename __entry_id to custom name - schema = combined_table.schema - new_names = [ - add_entry_id_column if name == "__entry_id" else name - for name in schema.names - ] - combined_table = combined_table.rename_columns(new_names) - # If add_entry_id_column is True, keep __entry_id as is - - return combined_table - - def get_records_by_ids_as_polars( - self, - source_path: tuple[str, ...], - entry_ids: list[str] | pl.Series | pa.Array, - add_entry_id_column: bool | str = False, - preserve_input_order: bool = False, - ) -> pl.LazyFrame | None: - """ - Retrieve records by entry IDs as a single Polars LazyFrame. - - Args: - source_name: Name of the data source - source_id: ID of the specific dataset within the source - entry_ids: Entry IDs to retrieve. Can be: - - list[str]: List of entry ID strings - - pl.Series: Polars Series containing entry IDs - - pa.Array: PyArrow Array containing entry IDs - add_entry_id_column: Control entry ID column inclusion: - - False: Don't include entry ID column (default) - - True: Include entry ID column as "__entry_id" - - str: Include entry ID column with custom name - preserve_input_order: If True, return results in the same order as input entry_ids, - with null rows for missing entries. If False, return in storage order. - - Returns: - Polars LazyFrame containing all found records, or None if no records found - """ - # Get Arrow result and convert to Polars - arrow_result = self.get_records_by_ids( - source_path, entry_ids, add_entry_id_column, preserve_input_order - ) - - if arrow_result is None: - return None - - # Convert to Polars LazyFrame - return pl.LazyFrame(arrow_result) - - def save_to_parquet(self, base_path: str | Path) -> None: - """ - Save all data to Parquet files in a directory structure. - - Directory structure: base_path/source_name/source_id/data.parquet - - Args: - base_path: Base directory path where to save the Parquet files - """ - base_path = Path(base_path) - base_path.mkdir(parents=True, exist_ok=True) - - saved_count = 0 - - for source_id, local_data in self._in_memory_store.items(): - if not local_data: - continue - - # Create directory structure - source_dir = base_path / source_id - source_dir.mkdir(parents=True, exist_ok=True) - - # Combine all tables for this source with entry_id column - tables_with_keys = [] - for entry_id, table in local_data.items(): - # Add entry_id column to each table - key_array = pa.array([entry_id] * len(table), type=pa.string()) - table_with_key = table.add_column(0, "__entry_id", key_array) - tables_with_keys.append(table_with_key) - - # Concatenate all tables - if tables_with_keys: - combined_table = pa.concat_tables(tables_with_keys) - - # Save as Parquet file - # TODO: perform safe "atomic" write - parquet_path = source_dir / "data.parquet" - import pyarrow.parquet as pq - - pq.write_table(combined_table, parquet_path) - - saved_count += 1 - logger.debug( - f"Saved {len(combined_table)} records for {source_id} to {parquet_path}" - ) - - logger.info(f"Saved {saved_count} sources to Parquet files in {base_path}") - - def load_from_parquet( - self, base_path: str | Path, source_path: tuple[str, ...] - ) -> None: - """ - Load data from Parquet files with the expected directory structure. - - Expected structure: base_path/source_name/source_id/data.parquet - - Args: - base_path: Base directory path containing the Parquet files - """ - - source_key = self._get_source_key(source_path) - target_path = Path(base_path) / source_key - - if not target_path.exists(): - logger.info(f"Base path {base_path} does not exist") - return - - loaded_count = 0 - - # Look for Parquet files in this directory - parquet_files = list(target_path.glob("*.parquet")) - if not parquet_files: - logger.debug(f"No Parquet files found in {target_path}") - return - - # Load all Parquet files and combine them - all_records = [] - - for parquet_file in parquet_files: - try: - import pyarrow.parquet as pq - - table = pq.read_table(parquet_file) - - # Validate that __entry_id column exists - if "__entry_id" not in table.column_names: - logger.warning( - f"Parquet file {parquet_file} missing __entry_id column, skipping" - ) - continue - - all_records.append(table) - logger.debug(f"Loaded {len(table)} records from {parquet_file}") - - except Exception as e: - logger.error(f"Failed to load Parquet file {parquet_file}: {e}") - continue - - # Process all records for this source - if all_records: - # Combine all tables - if len(all_records) == 1: - combined_table = all_records[0] - else: - combined_table = pa.concat_tables(all_records) - - # Split back into individual records by entry_id - local_data = {} - entry_ids = combined_table.column("__entry_id").to_pylist() - - # Group records by entry_id - entry_id_groups = {} - for i, entry_id in enumerate(entry_ids): - if entry_id not in entry_id_groups: - entry_id_groups[entry_id] = [] - entry_id_groups[entry_id].append(i) - - # Extract each entry_id's records - for entry_id, indices in entry_id_groups.items(): - # Take rows for this entry_id and remove __entry_id column - entry_table = combined_table.take(indices) - - # Remove __entry_id column - column_names = entry_table.column_names - if "__entry_id" in column_names: - indices_to_keep = [ - i for i, name in enumerate(column_names) if name != "__entry_id" - ] - entry_table = entry_table.select(indices_to_keep) - - local_data[entry_id] = entry_table - - self._in_memory_store[source_key] = local_data - loaded_count += 1 - - record_count = len(combined_table) - unique_entries = len(entry_id_groups) - logger.info( - f"Loaded {record_count} records ({unique_entries} unique entries) for {source_key}" - ) - - def flush(self): - """ - Flush all in-memory data to Parquet files in the base path. - This will overwrite existing files. - """ - if self.base_path is None: - logger.warning("Base path is not set, cannot flush data") - return - - logger.info(f"Flushing data to Parquet files in {self.base_path}") - self.save_to_parquet(self.base_path) - - -@dataclass -class RecordMetadata: - """Metadata for a stored record.""" - - source_name: str - source_id: str - entry_id: str - created_at: datetime - updated_at: datetime - schema_hash: str - parquet_path: str | None = None # Path to the specific partition - - -class SourceCache: - """Cache for a specific source_name/source_id combination.""" - - def __init__( - self, - source_name: str, - source_id: str, - base_path: Path, - partition_prefix_length: int = 2, - ): - self.source_name = source_name - self.source_id = source_id - self.base_path = base_path - self.source_dir = base_path / source_name / source_id - self.partition_prefix_length = partition_prefix_length - - # In-memory data - only for this source - self._memory_table: pl.DataFrame | None = None - self._loaded = False - self._dirty = False - self._last_access = datetime.now() - - # Track which entries are in memory vs on disk - self._memory_entries: set[str] = set() - self._disk_entries: set[str] = set() - - # Track which partitions are dirty (need to be rewritten) - self._dirty_partitions: set[str] = set() - - self._lock = threading.RLock() - - def _get_partition_key(self, entry_id: str) -> str: - """Get the partition key for an entry_id.""" - if len(entry_id) < self.partition_prefix_length: - return entry_id.ljust(self.partition_prefix_length, "0") - return entry_id[: self.partition_prefix_length] - - def _get_partition_path(self, entry_id: str) -> Path: - """Get the partition directory for an entry_id.""" - partition_key = self._get_partition_key(entry_id) - # Use prefix_ instead of entry_id= to avoid Hive partitioning issues - return self.source_dir / f"prefix_{partition_key}" - - def _get_partition_parquet_path(self, entry_id: str) -> Path: - """Get the Parquet file path for a partition.""" - partition_dir = self._get_partition_path(entry_id) - partition_key = self._get_partition_key(entry_id) - return partition_dir / f"partition_{partition_key}.parquet" - - def _load_from_disk_lazy(self) -> None: - """Lazily load data from disk only when first accessed.""" - if self._loaded: - return - - with self._lock: - if self._loaded: # Double-check after acquiring lock - return - - logger.debug(f"Lazy loading {self.source_name}/{self.source_id}") - - all_tables = [] - - if self.source_dir.exists(): - # Scan all partition directories - for partition_dir in self.source_dir.iterdir(): - if not partition_dir.is_dir() or not ( - partition_dir.name.startswith("entry_id=") - or partition_dir.name.startswith("prefix_") - ): - continue - - # Load the partition Parquet file (one per partition) - if partition_dir.name.startswith("entry_id="): - partition_key = partition_dir.name.split("=")[1] - else: # prefix_XX format - partition_key = partition_dir.name.split("_")[1] - - parquet_file = partition_dir / f"partition_{partition_key}.parquet" - - if parquet_file.exists(): - try: - table = pq.read_table(parquet_file) - if len(table) > 0: - polars_df = pl.from_arrow(table) - all_tables.append(polars_df) - - logger.debug( - f"Loaded partition {parquet_file}: {len(table)} rows, {len(table.columns)} columns" - ) - logger.debug(f" Columns: {table.column_names}") - - # Track disk entries from this partition - if "__entry_id" in table.column_names: - entry_ids = set( - table.column("__entry_id").to_pylist() - ) - self._disk_entries.update(entry_ids) - - except Exception as e: - logger.error(f"Failed to load {parquet_file}: {e}") - - # Combine all tables - if all_tables: - self._memory_table = pl.concat(all_tables) - self._memory_entries = self._disk_entries.copy() - logger.debug( - f"Combined loaded data: {len(self._memory_table)} rows, {len(self._memory_table.columns)} columns" - ) - logger.debug(f" Final columns: {self._memory_table.columns}") - - self._loaded = True - self._last_access = datetime.now() - - def add_entry( - self, - entry_id: str, - table_with_metadata: pa.Table, - allow_overwrite: bool = False, - ) -> None: - """Add an entry to this source cache.""" - with self._lock: - self._load_from_disk_lazy() # Ensure we're loaded - - # Check if entry already exists - entry_exists = ( - entry_id in self._memory_entries or entry_id in self._disk_entries - ) - - if entry_exists and not allow_overwrite: - raise ValueError( - f"Entry {entry_id} already exists in {self.source_name}/{self.source_id}" - ) - - # We know this returns DataFrame since we're passing a Table - polars_table = cast(pl.DataFrame, pl.from_arrow(table_with_metadata)) - - if self._memory_table is None: - self._memory_table = polars_table - else: - # Remove existing entry if it exists (for overwrite case) - if entry_id in self._memory_entries: - mask = self._memory_table["__entry_id"] != entry_id - self._memory_table = self._memory_table.filter(mask) - logger.debug(f"Removed existing entry {entry_id} for overwrite") - - # Debug schema mismatch - existing_cols = self._memory_table.columns - new_cols = polars_table.columns - - if len(existing_cols) != len(new_cols): - logger.error(f"Schema mismatch for entry {entry_id}:") - logger.error( - f" Existing columns ({len(existing_cols)}): {existing_cols}" - ) - logger.error(f" New columns ({len(new_cols)}): {new_cols}") - logger.error( - f" Missing in new: {set(existing_cols) - set(new_cols)}" - ) - logger.error( - f" Extra in new: {set(new_cols) - set(existing_cols)}" - ) - - raise ValueError( - f"Schema mismatch: existing table has {len(existing_cols)} columns, " - f"new table has {len(new_cols)} columns" - ) - - # Ensure column order matches - if existing_cols != new_cols: - logger.debug("Reordering columns to match existing schema") - polars_table = polars_table.select(existing_cols) - - # Add new entry - self._memory_table = pl.concat([self._memory_table, polars_table]) - - self._memory_entries.add(entry_id) - self._dirty = True - - # Mark the partition as dirty - partition_key = self._get_partition_key(entry_id) - self._dirty_partitions.add(partition_key) - - self._last_access = datetime.now() - - if entry_exists: - logger.info(f"Overwrote existing entry {entry_id}") - else: - logger.debug(f"Added new entry {entry_id}") - - def get_entry(self, entry_id: str) -> pa.Table | None: - """Get a specific entry.""" - with self._lock: - self._load_from_disk_lazy() - - if self._memory_table is None: - return None - - mask = self._memory_table["__entry_id"] == entry_id - filtered = self._memory_table.filter(mask) - - if len(filtered) == 0: - return None - - self._last_access = datetime.now() - return filtered.to_arrow() - - def get_all_entries(self) -> pa.Table | None: - """Get all entries for this source.""" - with self._lock: - self._load_from_disk_lazy() - - if self._memory_table is None: - return None - - self._last_access = datetime.now() - return self._memory_table.to_arrow() - - def get_all_entries_as_polars(self) -> pl.LazyFrame | None: - """Get all entries as a Polars LazyFrame.""" - with self._lock: - self._load_from_disk_lazy() - - if self._memory_table is None: - return None - - self._last_access = datetime.now() - return self._memory_table.lazy() - - def sync_to_disk(self) -> None: - """Sync dirty partitions to disk using efficient Parquet files.""" - with self._lock: - if not self._dirty or self._memory_table is None: - return - - logger.debug(f"Syncing {self.source_name}/{self.source_id} to disk") - - # Only sync dirty partitions - for partition_key in self._dirty_partitions: - try: - # Get all entries for this partition - partition_mask = ( - self._memory_table["__entry_id"].str.slice( - 0, self.partition_prefix_length - ) - == partition_key - ) - partition_data = self._memory_table.filter(partition_mask) - - if len(partition_data) == 0: - continue - - logger.debug(f"Syncing partition {partition_key}:") - logger.debug(f" Rows: {len(partition_data)}") - logger.debug(f" Columns: {partition_data.columns}") - logger.debug( - f" Sample __entry_id values: {partition_data['__entry_id'].head(3).to_list()}" - ) - - # Ensure partition directory exists - partition_dir = self.source_dir / f"prefix_{partition_key}" - partition_dir.mkdir(parents=True, exist_ok=True) - - # Write entire partition to single Parquet file - partition_path = ( - partition_dir / f"partition_{partition_key}.parquet" - ) - arrow_table = partition_data.to_arrow() - - logger.debug( - f" Arrow table columns before write: {arrow_table.column_names}" - ) - logger.debug(f" Arrow table shape: {arrow_table.shape}") - - pq.write_table(arrow_table, partition_path) - - # Verify what was written - verification_table = pq.read_table(partition_path) - logger.debug( - f" Verification - columns after write: {verification_table.column_names}" - ) - logger.debug(f" Verification - shape: {verification_table.shape}") - - entry_count = len(set(partition_data["__entry_id"].to_list())) - logger.debug( - f"Wrote partition {partition_key} with {entry_count} entries ({len(partition_data)} rows)" - ) - - except Exception as e: - logger.error(f"Failed to write partition {partition_key}: {e}") - import traceback - - logger.error(f"Traceback: {traceback.format_exc()}") - - # Clear dirty markers - self._dirty_partitions.clear() - self._dirty = False - - def is_loaded(self) -> bool: - """Check if this cache is loaded in memory.""" - return self._loaded - - def get_last_access(self) -> datetime: - """Get the last access time.""" - return self._last_access - - def unload(self) -> None: - """Unload from memory (after syncing if dirty).""" - with self._lock: - if self._dirty: - self.sync_to_disk() - - self._memory_table = None - self._loaded = False - self._memory_entries.clear() - # Keep _disk_entries for reference - - def entry_exists(self, entry_id: str) -> bool: - """Check if an entry exists (in memory or on disk).""" - with self._lock: - self._load_from_disk_lazy() - return entry_id in self._memory_entries or entry_id in self._disk_entries - - def list_entries(self) -> set[str]: - """List all entry IDs in this source.""" - with self._lock: - self._load_from_disk_lazy() - return self._memory_entries | self._disk_entries - - def get_stats(self) -> dict[str, Any]: - """Get statistics for this cache.""" - with self._lock: - return { - "source_name": self.source_name, - "source_id": self.source_id, - "loaded": self._loaded, - "dirty": self._dirty, - "memory_entries": len(self._memory_entries), - "disk_entries": len(self._disk_entries), - "memory_rows": len(self._memory_table) - if self._memory_table is not None - else 0, - "last_access": self._last_access.isoformat(), - } - - -class ParquetArrowDataStore: - """ - Lazy-loading, append-only Arrow data store with entry_id partitioning. - - Features: - - Lazy loading: Only loads source data when first accessed - - Separate memory management per source_name/source_id - - Entry_id partitioning: Multiple entries per Parquet file based on prefix - - Configurable duplicate entry_id handling (error or overwrite) - - Automatic cache eviction for memory management - - Single-row constraint: Each record must contain exactly one row - """ - - _system_columns = [ - "__source_name", - "__source_id", - "__entry_id", - "__created_at", - "__updated_at", - "__schema_hash", - ] - - def __init__( - self, - base_path: str | Path, - sync_interval_seconds: int = 300, # 5 minutes default - auto_sync: bool = True, - max_loaded_sources: int = 100, - cache_eviction_hours: int = 2, - duplicate_entry_behavior: str = "error", - partition_prefix_length: int = 2, - ): - """ - Initialize the ParquetArrowDataStore. - - Args: - base_path: Directory path for storing Parquet files - sync_interval_seconds: How often to sync dirty caches to disk - auto_sync: Whether to automatically sync on a timer - max_loaded_sources: Maximum number of source caches to keep in memory - cache_eviction_hours: Hours of inactivity before evicting from memory - duplicate_entry_behavior: How to handle duplicate entry_ids: - - 'error': Raise ValueError when entry_id already exists - - 'overwrite': Replace existing entry with new data - partition_prefix_length: Number of characters from entry_id to use for partitioning (default 2) - """ - self.base_path = Path(base_path) - self.base_path.mkdir(parents=True, exist_ok=True) - self.sync_interval = sync_interval_seconds - self.auto_sync = auto_sync - self.max_loaded_sources = max_loaded_sources - self.cache_eviction_hours = cache_eviction_hours - self.partition_prefix_length = max( - 1, min(8, partition_prefix_length) - ) # Clamp between 1-8 - - # Validate duplicate behavior - if duplicate_entry_behavior not in ["error", "overwrite"]: - raise ValueError("duplicate_entry_behavior must be 'error' or 'overwrite'") - self.duplicate_entry_behavior = duplicate_entry_behavior - - # Cache management - self._source_caches: dict[str, SourceCache] = {} # key: "source_name:source_id" - self._global_lock = threading.RLock() - - # Record metadata (always in memory for fast lookups) - self._record_metadata: dict[str, RecordMetadata] = {} - self._load_metadata_index() - - # Sync management - self._sync_timer: threading.Timer | None = None - self._shutdown = False - - # Start auto-sync and cleanup if enabled - if self.auto_sync: - self._start_sync_timer() - - logger.info(f"Initialized lazy ParquetArrowDataStore at {base_path}") - - def _get_source_key(self, source_name: str, source_id: str) -> str: - """Generate key for source cache.""" - return f"{source_name}:{source_id}" - - def _get_record_key(self, source_name: str, source_id: str, entry_id: str) -> str: - """Generate unique key for a record.""" - return f"{source_name}:{source_id}:{entry_id}" - - def _load_metadata_index(self) -> None: - """Load metadata index from disk (lightweight - just file paths and timestamps).""" - logger.info("Loading metadata index...") - - if not self.base_path.exists(): - return - - for source_name_dir in self.base_path.iterdir(): - if not source_name_dir.is_dir(): - continue - - source_name = source_name_dir.name - - for source_id_dir in source_name_dir.iterdir(): - if not source_id_dir.is_dir(): - continue - - source_id = source_id_dir.name - - # Scan partition directories for parquet files - for partition_dir in source_id_dir.iterdir(): - if not partition_dir.is_dir() or not ( - partition_dir.name.startswith("entry_id=") - or partition_dir.name.startswith("prefix_") - ): - continue - - for parquet_file in partition_dir.glob("partition_*.parquet"): - try: - # Read the parquet file to extract entry IDs - table = pq.read_table(parquet_file) - if "__entry_id" in table.column_names: - entry_ids = set(table.column("__entry_id").to_pylist()) - - # Get file stats - stat = parquet_file.stat() - created_at = datetime.fromtimestamp(stat.st_ctime) - updated_at = datetime.fromtimestamp(stat.st_mtime) - - for entry_id in entry_ids: - record_key = self._get_record_key( - source_name, source_id, entry_id - ) - self._record_metadata[record_key] = RecordMetadata( - source_name=source_name, - source_id=source_id, - entry_id=entry_id, - created_at=created_at, - updated_at=updated_at, - schema_hash="unknown", # Will be computed if needed - parquet_path=str(parquet_file), - ) - except Exception as e: - logger.error( - f"Failed to read metadata from {parquet_file}: {e}" - ) - - logger.info(f"Loaded metadata for {len(self._record_metadata)} records") - - def _get_or_create_source_cache( - self, source_name: str, source_id: str - ) -> SourceCache: - """Get or create a source cache, handling eviction if needed.""" - source_key = self._get_source_key(source_name, source_id) - - with self._global_lock: - if source_key not in self._source_caches: - # Check if we need to evict old caches - if len(self._source_caches) >= self.max_loaded_sources: - self._evict_old_caches() - - # Create new cache with partition configuration - self._source_caches[source_key] = SourceCache( - source_name, source_id, self.base_path, self.partition_prefix_length - ) - logger.debug(f"Created cache for {source_key}") - - return self._source_caches[source_key] - - def _evict_old_caches(self) -> None: - """Evict old caches based on last access time.""" - cutoff_time = datetime.now() - timedelta(hours=self.cache_eviction_hours) - - to_evict = [] - for source_key, cache in self._source_caches.items(): - if cache.get_last_access() < cutoff_time: - to_evict.append(source_key) - - for source_key in to_evict: - cache = self._source_caches.pop(source_key) - cache.unload() # This will sync if dirty - logger.debug(f"Evicted cache for {source_key}") - - def _compute_schema_hash(self, table: pa.Table) -> str: - """Compute a hash of the table schema.""" - import hashlib - - schema_str = str(table.schema) - return hashlib.sha256(schema_str.encode()).hexdigest()[:16] - - def _add_system_columns( - self, table: pa.Table, metadata: RecordMetadata - ) -> pa.Table: - """Add system columns to track record metadata.""" - # Keep all system columns for self-describing data - # Use large_string for all string columns - large_string_type = pa.large_string() - - system_columns = [ - ( - "__source_name", - pa.array([metadata.source_name] * len(table), type=large_string_type), - ), - ( - "__source_id", - pa.array([metadata.source_id] * len(table), type=large_string_type), - ), - ( - "__entry_id", - pa.array([metadata.entry_id] * len(table), type=large_string_type), - ), - ("__created_at", pa.array([metadata.created_at] * len(table))), - ("__updated_at", pa.array([metadata.updated_at] * len(table))), - ( - "__schema_hash", - pa.array([metadata.schema_hash] * len(table), type=large_string_type), - ), - ] - - # Combine user columns + system columns in consistent order - new_columns = list(table.columns) + [col[1] for col in system_columns] - new_names = table.column_names + [col[0] for col in system_columns] - - result = pa.table(new_columns, names=new_names) - logger.debug( - f"Added system columns: {len(table.columns)} -> {len(result.columns)} columns" - ) - return result - - def _remove_system_columns(self, table: pa.Table) -> pa.Table: - """Remove system columns to get original user data.""" - return table.drop(self._system_columns) - - def add_record( - self, source_name: str, source_id: str, entry_id: str, arrow_data: pa.Table - ) -> pa.Table: - """ - Add or update a record (append-only operation). - - Args: - source_name: Name of the data source - source_id: ID of the specific dataset within the source - entry_id: Unique identifier for this record (typically 32-char hash) - arrow_data: The Arrow table data to store (MUST contain exactly 1 row) - - Returns: - The original arrow_data table - - Raises: - ValueError: If entry_id already exists and duplicate_entry_behavior is 'error' - ValueError: If arrow_data contains more than 1 row - ValueError: If arrow_data schema doesn't match existing data for this source - """ - # normalize arrow_data to conform to polars string. TODO: consider a clearner approach - arrow_data = pl.DataFrame(arrow_data).to_arrow() - - # CRITICAL: Enforce single-row constraint - if len(arrow_data) != 1: - raise ValueError( - f"Each record must contain exactly 1 row, got {len(arrow_data)} rows. " - f"This constraint ensures that for each source_name/source_id combination, " - f"there is only one valid entry per entry_id." - ) - - # Validate entry_id format (assuming 8+ char identifier) - if not entry_id or len(entry_id) < 8: - raise ValueError( - f"entry_id must be at least 8 characters long, got: '{entry_id}'" - ) - - # Check if this source already has data and validate schema compatibility - cache = self._get_or_create_source_cache(source_name, source_id) - - # Load existing data to check schema compatibility - cache._load_from_disk_lazy() - - if cache._memory_table is not None: - # Extract user columns from existing data (remove system columns) - existing_arrow = cache._memory_table.to_arrow() - existing_user_data = self._remove_system_columns(existing_arrow) - - # Check if schemas match - existing_schema = existing_user_data.schema - new_schema = arrow_data.schema - - if not existing_schema.equals(new_schema): - existing_cols = existing_user_data.column_names - new_cols = arrow_data.column_names - - logger.error(f"Schema mismatch for {source_name}/{source_id}:") - logger.error(f" Existing user columns: {existing_cols}") - logger.error(f" New user columns: {new_cols}") - logger.error(f" Missing in new: {set(existing_cols) - set(new_cols)}") - logger.error(f" Extra in new: {set(new_cols) - set(existing_cols)}") - - raise ValueError( - f"Schema mismatch for {source_name}/{source_id}. " - f"Existing data has columns {existing_cols}, " - f"but new data has columns {new_cols}. " - f"All records in a source must have the same schema." - ) - - now = datetime.now() - record_key = self._get_record_key(source_name, source_id, entry_id) - - # Check for existing entry - existing_metadata = self._record_metadata.get(record_key) - entry_exists = existing_metadata is not None - - if entry_exists and self.duplicate_entry_behavior == "error": - raise DuplicateError( - f"Entry '{entry_id}' already exists in {source_name}/{source_id}. " - f"Use duplicate_entry_behavior='overwrite' to allow updates." - ) - - # Create/update metadata - schema_hash = self._compute_schema_hash(arrow_data) - metadata = RecordMetadata( - source_name=source_name, - source_id=source_id, - entry_id=entry_id, - created_at=existing_metadata.created_at if existing_metadata else now, - updated_at=now, - schema_hash=schema_hash, - ) - - # Add system columns - table_with_metadata = self._add_system_columns(arrow_data, metadata) - - # Get or create source cache and add entry - allow_overwrite = self.duplicate_entry_behavior == "overwrite" - - try: - cache.add_entry(entry_id, table_with_metadata, allow_overwrite) - except ValueError as e: - # Re-raise with more context - raise ValueError(f"Failed to add record: {e}") - - # Update metadata - self._record_metadata[record_key] = metadata - - action = "Updated" if entry_exists else "Added" - logger.info(f"{action} record {record_key} with {len(arrow_data)} rows") - return arrow_data - - def get_record( - self, source_name: str, source_id: str, entry_id: str - ) -> pa.Table | None: - """Retrieve a specific record.""" - record_key = self._get_record_key(source_name, source_id, entry_id) - - if record_key not in self._record_metadata: - return None - - cache = self._get_or_create_source_cache(source_name, source_id) - table = cache.get_entry(entry_id) - - if table is None: - return None - - return self._remove_system_columns(table) - - def get_all_records( - self, source_name: str, source_id: str, _keep_system_columns: bool = False - ) -> pa.Table | None: - """Retrieve all records for a given source as a single Arrow table.""" - cache = self._get_or_create_source_cache(source_name, source_id) - table = cache.get_all_entries() - - if table is None: - return None - - if _keep_system_columns: - return table - return self._remove_system_columns(table) - - def get_all_records_as_polars( - self, source_name: str, source_id: str, _keep_system_columns: bool = False - ) -> pl.LazyFrame | None: - """Retrieve all records for a given source as a Polars LazyFrame.""" - cache = self._get_or_create_source_cache(source_name, source_id) - lazy_frame = cache.get_all_entries_as_polars() - - if lazy_frame is None: - return None - - if _keep_system_columns: - return lazy_frame - - return lazy_frame.drop(self._system_columns) - - def get_records_by_ids( - self, - source_name: str, - source_id: str, - entry_ids: list[str] | pl.Series | pa.Array, - add_entry_id_column: bool | str = False, - preserve_input_order: bool = False, - ) -> pa.Table | None: - """ - Retrieve multiple records by their entry_ids as a single Arrow table. - - Args: - source_name: Name of the data source - source_id: ID of the specific dataset within the source - entry_ids: Entry IDs to retrieve. Can be: - - list[str]: List of entry ID strings - - pl.Series: Polars Series containing entry IDs - - pa.Array: PyArrow Array containing entry IDs - add_entry_id_column: Control entry ID column inclusion: - - False: Don't include entry ID column (default) - - True: Include entry ID column as "__entry_id" - - str: Include entry ID column with custom name - preserve_input_order: If True, return results in the same order as input entry_ids, - with null rows for missing entries. If False, return in storage order. - - Returns: - Arrow table containing all found records, or None if no records found - When preserve_input_order=True, table length equals input length - When preserve_input_order=False, records are in storage order - """ - # Get Polars result using the Polars method - polars_result = self.get_records_by_ids_as_polars( - source_name, source_id, entry_ids, add_entry_id_column, preserve_input_order - ) - - if polars_result is None: - return None - - # Convert to Arrow table - return polars_result.collect().to_arrow() - - def get_records_by_ids_as_polars( - self, - source_name: str, - source_id: str, - entry_ids: list[str] | pl.Series | pa.Array, - add_entry_id_column: bool | str = False, - preserve_input_order: bool = False, - ) -> pl.LazyFrame | None: - """ - Retrieve multiple records by their entry_ids as a Polars LazyFrame. - - Args: - source_name: Name of the data source - source_id: ID of the specific dataset within the source - entry_ids: Entry IDs to retrieve. Can be: - - list[str]: List of entry ID strings - - pl.Series: Polars Series containing entry IDs - - pa.Array: PyArrow Array containing entry IDs - add_entry_id_column: Control entry ID column inclusion: - - False: Don't include entry ID column (default) - - True: Include entry ID column as "__entry_id" - - str: Include entry ID column with custom name - preserve_input_order: If True, return results in the same order as input entry_ids, - with null rows for missing entries. If False, return in storage order. - - Returns: - Polars LazyFrame containing all found records, or None if no records found - When preserve_input_order=True, frame length equals input length - When preserve_input_order=False, records are in storage order (existing behavior) - """ - # Convert input to Polars Series - if isinstance(entry_ids, list): - if not entry_ids: - return None - entry_ids_series = pl.Series("entry_id", entry_ids) - elif isinstance(entry_ids, pl.Series): - if len(entry_ids) == 0: - return None - entry_ids_series = entry_ids - elif isinstance(entry_ids, pa.Array): - if len(entry_ids) == 0: - return None - entry_ids_series = pl.Series( - "entry_id", entry_ids - ) # Direct from Arrow array - else: - raise TypeError( - f"entry_ids must be list[str], pl.Series, or pa.Array, got {type(entry_ids)}" - ) - - cache = self._get_or_create_source_cache(source_name, source_id) - lazy_frame = cache.get_all_entries_as_polars() - - if lazy_frame is None: - return None - - # Define system columns that are always excluded (except optionally __entry_id) - system_cols = [ - "__source_name", - "__source_id", - "__created_at", - "__updated_at", - "__schema_hash", - ] - - # Add __entry_id to system columns if we don't want it in the result - if add_entry_id_column is False: - system_cols.append("__entry_id") - - # Handle input order preservation vs filtering - if preserve_input_order: - # Create ordered DataFrame with input IDs and join to preserve order with nulls - ordered_df = pl.DataFrame({"__entry_id": entry_ids_series}).lazy() - # Join with all data to get results in input order with nulls for missing - result_frame = ordered_df.join(lazy_frame, on="__entry_id", how="left") - else: - # Standard filtering approach for storage order -- should be faster in general - result_frame = lazy_frame.filter( - pl.col("__entry_id").is_in(entry_ids_series) - ) - - # Apply column selection (same for both paths) - result_frame = result_frame.drop(system_cols) - - # Rename __entry_id column if custom name provided - if isinstance(add_entry_id_column, str): - result_frame = result_frame.rename({"__entry_id": add_entry_id_column}) - - return result_frame - - def _sync_all_dirty_caches(self) -> None: - """Sync all dirty caches to disk.""" - with self._global_lock: - dirty_count = 0 - for cache in self._source_caches.values(): - if cache._dirty: - cache.sync_to_disk() - dirty_count += 1 - - if dirty_count > 0: - logger.info(f"Synced {dirty_count} dirty caches to disk") - - def _start_sync_timer(self) -> None: - """Start the automatic sync timer.""" - if self._shutdown: - return - - self._sync_timer = threading.Timer( - self.sync_interval, self._sync_and_reschedule - ) - self._sync_timer.daemon = True - self._sync_timer.start() - - def _sync_and_reschedule(self) -> None: - """Sync dirty caches and reschedule.""" - try: - self._sync_all_dirty_caches() - self._evict_old_caches() - except Exception as e: - logger.error(f"Auto-sync failed: {e}") - finally: - if not self._shutdown: - self._start_sync_timer() - - def force_sync(self) -> None: - """Manually trigger a sync of all dirty caches.""" - self._sync_all_dirty_caches() - - def entry_exists(self, source_name: str, source_id: str, entry_id: str) -> bool: - """Check if a specific entry exists.""" - record_key = self._get_record_key(source_name, source_id, entry_id) - - # Check metadata first (fast) - if record_key in self._record_metadata: - return True - - # If not in metadata, check if source cache knows about it - source_key = self._get_source_key(source_name, source_id) - if source_key in self._source_caches: - cache = self._source_caches[source_key] - return cache.entry_exists(entry_id) - - # Not loaded and not in metadata - doesn't exist - return False - - def list_entries(self, source_name: str, source_id: str) -> set[str]: - """List all entry IDs for a specific source.""" - cache = self._get_or_create_source_cache(source_name, source_id) - return cache.list_entries() - - def list_sources(self) -> set[tuple[str, str]]: - """List all (source_name, source_id) combinations.""" - sources = set() - - # From metadata - for metadata in self._record_metadata.values(): - sources.add((metadata.source_name, metadata.source_id)) - - return sources - - def get_stats(self) -> dict[str, Any]: - """Get comprehensive statistics about the data store.""" - with self._global_lock: - loaded_caches = len(self._source_caches) - dirty_caches = sum( - 1 for cache in self._source_caches.values() if cache._dirty - ) - - cache_stats = [cache.get_stats() for cache in self._source_caches.values()] - - return { - "total_records": len(self._record_metadata), - "loaded_source_caches": loaded_caches, - "dirty_caches": dirty_caches, - "max_loaded_sources": self.max_loaded_sources, - "sync_interval": self.sync_interval, - "auto_sync": self.auto_sync, - "cache_eviction_hours": self.cache_eviction_hours, - "base_path": str(self.base_path), - "duplicate_entry_behavior": self.duplicate_entry_behavior, - "partition_prefix_length": self.partition_prefix_length, - "cache_details": cache_stats, - } - - def shutdown(self) -> None: - """Shutdown the data store, ensuring all data is synced.""" - logger.info("Shutting down ParquetArrowDataStore...") - self._shutdown = True - - if self._sync_timer: - self._sync_timer.cancel() - - # Final sync of all caches - self._sync_all_dirty_caches() - - logger.info("Shutdown complete") - - def __del__(self): - """Ensure cleanup on destruction.""" - if not self._shutdown: - self.shutdown() - - -# Example usage and testing -def demo_single_row_constraint(): - """Demonstrate the single-row constraint in the ParquetArrowDataStore.""" - import tempfile - import random - from datetime import timedelta - - def create_single_row_record(entry_id: str, value: float | None = None) -> pa.Table: - """Create a single-row Arrow table.""" - if value is None: - value = random.uniform(0, 100) - - return pa.table( - { - "entry_id": [entry_id], - "timestamp": [datetime.now()], - "value": [value], - "category": [random.choice(["A", "B", "C"])], - } - ) - - def create_multi_row_record(entry_id: str, num_rows: int = 3) -> pa.Table: - """Create a multi-row Arrow table (should be rejected).""" - return pa.table( - { - "entry_id": [entry_id] * num_rows, - "timestamp": [ - datetime.now() + timedelta(seconds=i) for i in range(num_rows) - ], - "value": [random.uniform(0, 100) for _ in range(num_rows)], - "category": [random.choice(["A", "B", "C"]) for _ in range(num_rows)], - } - ) - - print("Testing Single-Row Constraint...") - - with tempfile.TemporaryDirectory() as temp_dir: - store = ParquetArrowDataStore( - base_path=temp_dir, - sync_interval_seconds=10, - auto_sync=False, # Manual sync for testing - duplicate_entry_behavior="overwrite", - ) - - try: - print("\n=== Testing Valid Single-Row Records ===") - - # Test 1: Add valid single-row records - valid_entries = [ - "entry_001_abcdef1234567890abcdef1234567890", - "entry_002_abcdef1234567890abcdef1234567890", - "entry_003_abcdef1234567890abcdef1234567890", - ] - - for i, entry_id in enumerate(valid_entries): - data = create_single_row_record(entry_id, value=100.0 + i) - store.add_record("experiments", "dataset_A", entry_id, data) - print( - f"✓ Added single-row record {entry_id[:16]}... (value: {100.0 + i})" - ) - - print(f"\nTotal records stored: {len(store._record_metadata)}") - - print("\n=== Testing Invalid Multi-Row Records ===") - - # Test 2: Try to add multi-row record (should fail) - invalid_entry = "entry_004_abcdef1234567890abcdef1234567890" - try: - invalid_data = create_multi_row_record(invalid_entry, num_rows=3) - store.add_record( - "experiments", "dataset_A", invalid_entry, invalid_data - ) - print("✗ ERROR: Multi-row record was accepted!") - except ValueError as e: - print(f"✓ Correctly rejected multi-row record: {str(e)[:80]}...") - - # Test 3: Try to add empty record (should fail) - empty_entry = "entry_005_abcdef1234567890abcdef1234567890" - try: - empty_data = pa.table({"col1": pa.array([], type=pa.int64())}) - store.add_record("experiments", "dataset_A", empty_entry, empty_data) - print("✗ ERROR: Empty record was accepted!") - except ValueError as e: - print(f"✓ Correctly rejected empty record: {str(e)[:80]}...") - - print("\n=== Testing Retrieval ===") - - # Test 4: Retrieve records - retrieved = store.get_record("experiments", "dataset_A", valid_entries[0]) - if retrieved and len(retrieved) == 1: - print(f"✓ Retrieved single record: {len(retrieved)} row") - print(f" Value: {retrieved.column('value')[0].as_py()}") - else: - print("✗ Failed to retrieve record or wrong size") - - # Test 5: Get all records - all_records = store.get_all_records("experiments", "dataset_A") - if all_records: - print(f"✓ Retrieved all records: {len(all_records)} rows total") - unique_entries = len(set(all_records.column("entry_id").to_pylist())) - print(f" Unique entries: {unique_entries}") - - # Verify each entry appears exactly once - entry_counts = {} - for entry_id in all_records.column("entry_id").to_pylist(): - entry_counts[entry_id] = entry_counts.get(entry_id, 0) + 1 - - all_single = all(count == 1 for count in entry_counts.values()) - if all_single: - print( - "✓ Each entry appears exactly once (single-row constraint maintained)" - ) - else: - print("✗ Some entries appear multiple times!") - - print("\n=== Testing Overwrite Behavior ===") - - # Test 6: Overwrite existing single-row record - overwrite_data = create_single_row_record(valid_entries[0], value=999.0) - store.add_record( - "experiments", "dataset_A", valid_entries[0], overwrite_data - ) - print("✓ Overwrote existing record") - - # Verify overwrite - updated_record = store.get_record( - "experiments", "dataset_A", valid_entries[0] - ) - if updated_record and updated_record.column("value")[0].as_py() == 999.0: - print( - f"✓ Overwrite successful: new value = {updated_record.column('value')[0].as_py()}" - ) - - # Sync and show final stats - store.force_sync() - stats = store.get_stats() - print("\n=== Final Statistics ===") - print(f"Total records: {stats['total_records']}") - print(f"Loaded caches: {stats['loaded_source_caches']}") - print(f"Dirty caches: {stats['dirty_caches']}") - - finally: - store.shutdown() - - print("\n✓ Single-row constraint testing completed successfully!") - - -class InMemoryPolarsDataStore: - """ - In-memory Arrow data store using Polars DataFrames for efficient storage and retrieval. - This class provides the same interface as InMemoryArrowDataStore but uses Polars internally - for better performance with large datasets and complex queries. - - Uses dict of Polars DataFrames for efficient storage and retrieval. - Each DataFrame contains all records for a source with an __entry_id column. - """ - - def __init__(self, duplicate_entry_behavior: str = "error"): - """ - Initialize the InMemoryPolarsDataStore. - - Args: - duplicate_entry_behavior: How to handle duplicate entry_ids: - - 'error': Raise ValueError when entry_id already exists - - 'overwrite': Replace existing entry with new data - """ - # Validate duplicate behavior - if duplicate_entry_behavior not in ["error", "overwrite"]: - raise ValueError("duplicate_entry_behavior must be 'error' or 'overwrite'") - self.duplicate_entry_behavior = duplicate_entry_behavior - - # Store Polars DataFrames: {source_key: polars_dataframe} - # Each DataFrame has an __entry_id column plus user data columns - self._in_memory_store: dict[str, pl.DataFrame] = {} - logger.info( - f"Initialized InMemoryPolarsDataStore with duplicate_entry_behavior='{duplicate_entry_behavior}'" - ) - - def _get_source_key(self, source_name: str, source_id: str) -> str: - """Generate key for source storage.""" - return f"{source_name}:{source_id}" - - def add_record( - self, - source_name: str, - source_id: str, - entry_id: str, - arrow_data: pa.Table, - ) -> pa.Table: - """ - Add a record to the in-memory store. - - Args: - source_name: Name of the data source - source_id: ID of the specific dataset within the source - entry_id: Unique identifier for this record - arrow_data: The Arrow table data to store - - Returns: - arrow_data equivalent to having loaded the corresponding entry that was just saved - - Raises: - ValueError: If entry_id already exists and duplicate_entry_behavior is 'error' - """ - source_key = self._get_source_key(source_name, source_id) - - # Convert Arrow table to Polars DataFrame and add entry_id column - polars_data = cast(pl.DataFrame, pl.from_arrow(arrow_data)) - - # Add __entry_id column - polars_data = polars_data.with_columns(pl.lit(entry_id).alias("__entry_id")) - - # Check if source exists - if source_key not in self._in_memory_store: - # First record for this source - self._in_memory_store[source_key] = polars_data - logger.debug(f"Created new source {source_key} with entry {entry_id}") - else: - existing_df = self._in_memory_store[source_key] - - # Check for duplicate entry - entry_exists = ( - existing_df.filter(pl.col("__entry_id") == entry_id).shape[0] > 0 - ) - - if entry_exists: - if self.duplicate_entry_behavior == "error": - raise ValueError( - f"Entry '{entry_id}' already exists in {source_name}/{source_id}. " - f"Use duplicate_entry_behavior='overwrite' to allow updates." - ) - else: # validity of value is checked in constructor so it must be "ovewrite" - # Remove existing entry and add new one - existing_df = existing_df.filter(pl.col("__entry_id") != entry_id) - self._in_memory_store[source_key] = pl.concat( - [existing_df, polars_data] - ) - logger.debug(f"Overwrote entry {entry_id} in {source_key}") - else: - # Append new entry - try: - self._in_memory_store[source_key] = pl.concat( - [existing_df, polars_data] - ) - logger.debug(f"Added entry {entry_id} to {source_key}") - except Exception as e: - # Handle schema mismatch - existing_cols = set(existing_df.columns) - {"__entry_id"} - new_cols = set(polars_data.columns) - {"__entry_id"} - - if existing_cols != new_cols: - raise ValueError( - f"Schema mismatch for {source_key}. " - f"Existing columns: {sorted(existing_cols)}, " - f"New columns: {sorted(new_cols)}" - ) from e - else: - raise e - - return arrow_data - - def get_record( - self, source_name: str, source_id: str, entry_id: str - ) -> pa.Table | None: - """Get a specific record.""" - source_key = self._get_source_key(source_name, source_id) - - if source_key not in self._in_memory_store: - return None - - df = self._in_memory_store[source_key] - - # Filter for the specific entry_id - filtered_df = df.filter(pl.col("__entry_id") == entry_id) - - if filtered_df.shape[0] == 0: - return None - - # Remove __entry_id column and convert to Arrow - result_df = filtered_df.drop("__entry_id") - return result_df.to_arrow() - - def get_all_records( - self, source_name: str, source_id: str, add_entry_id_column: bool | str = False - ) -> pa.Table | None: - """Retrieve all records for a given source as a single table.""" - df = self.get_all_records_as_polars( - source_name, source_id, add_entry_id_column=add_entry_id_column - ) - if df is None: - return None - return df.collect().to_arrow() - - def get_all_records_as_polars( - self, source_name: str, source_id: str, add_entry_id_column: bool | str = False - ) -> pl.LazyFrame | None: - """Retrieve all records for a given source as a single Polars LazyFrame.""" - source_key = self._get_source_key(source_name, source_id) - - if source_key not in self._in_memory_store: - return None - - df = self._in_memory_store[source_key] - - if df.shape[0] == 0: - return None - - # perform column selection lazily - df = df.lazy() - - # Handle entry_id column based on parameter - if add_entry_id_column is False: - # Remove __entry_id column - result_df = df.drop("__entry_id") - elif add_entry_id_column is True: - # Keep __entry_id column as is - result_df = df - elif isinstance(add_entry_id_column, str): - # Rename __entry_id to custom name - result_df = df.rename({"__entry_id": add_entry_id_column}) - else: - raise ValueError( - f"add_entry_id_column must be a bool or str but {add_entry_id_column} was given" - ) - - return result_df - - def get_records_by_ids( - self, - source_name: str, - source_id: str, - entry_ids: list[str] | pl.Series | pa.Array, - add_entry_id_column: bool | str = False, - preserve_input_order: bool = False, - ) -> pa.Table | None: - """ - Retrieve records by entry IDs as a single table. - - Args: - source_name: Name of the data source - source_id: ID of the specific dataset within the source - entry_ids: Entry IDs to retrieve. Can be: - - list[str]: List of entry ID strings - - pl.Series: Polars Series containing entry IDs - - pa.Array: PyArrow Array containing entry IDs - add_entry_id_column: Control entry ID column inclusion: - - False: Don't include entry ID column (default) - - True: Include entry ID column as "__entry_id" - - str: Include entry ID column with custom name - preserve_input_order: If True, return results in the same order as input entry_ids, - with null rows for missing entries. If False, return in storage order. - - Returns: - Arrow table containing all found records, or None if no records found - """ - # Convert input to Polars Series - if isinstance(entry_ids, list): - if not entry_ids: - return None - entry_ids_series = pl.Series("entry_id", entry_ids) - elif isinstance(entry_ids, pl.Series): - if len(entry_ids) == 0: - return None - entry_ids_series = entry_ids - elif isinstance(entry_ids, pa.Array): - if len(entry_ids) == 0: - return None - entry_ids_series: pl.Series = pl.from_arrow( - pa.table({"entry_id": entry_ids}) - )["entry_id"] # type: ignore - else: - raise TypeError( - f"entry_ids must be list[str], pl.Series, or pa.Array, got {type(entry_ids)}" - ) - - source_key = self._get_source_key(source_name, source_id) - - if source_key not in self._in_memory_store: - return None - - df = self._in_memory_store[source_key] - - if preserve_input_order: - # Create DataFrame with input order and join to preserve order with nulls - ordered_df = pl.DataFrame({"__entry_id": entry_ids_series}) - result_df = ordered_df.join(df, on="__entry_id", how="left") - else: - # Filter for matching entry_ids (storage order) - result_df = df.filter(pl.col("__entry_id").is_in(entry_ids_series)) - - if result_df.shape[0] == 0: - return None - - # Handle entry_id column based on parameter - if add_entry_id_column is False: - # Remove __entry_id column - result_df = result_df.drop("__entry_id") - elif add_entry_id_column is True: - # Keep __entry_id column as is - pass - elif isinstance(add_entry_id_column, str): - # Rename __entry_id to custom name - result_df = result_df.rename({"__entry_id": add_entry_id_column}) - - return result_df.to_arrow() - - def get_records_by_ids_as_polars( - self, - source_name: str, - source_id: str, - entry_ids: list[str] | pl.Series | pa.Array, - add_entry_id_column: bool | str = False, - preserve_input_order: bool = False, - ) -> pl.LazyFrame | None: - """ - Retrieve records by entry IDs as a single Polars LazyFrame. - - Args: - source_name: Name of the data source - source_id: ID of the specific dataset within the source - entry_ids: Entry IDs to retrieve. Can be: - - list[str]: List of entry ID strings - - pl.Series: Polars Series containing entry IDs - - pa.Array: PyArrow Array containing entry IDs - add_entry_id_column: Control entry ID column inclusion: - - False: Don't include entry ID column (default) - - True: Include entry ID column as "__entry_id" - - str: Include entry ID column with custom name - preserve_input_order: If True, return results in the same order as input entry_ids, - with null rows for missing entries. If False, return in storage order. - - Returns: - Polars LazyFrame containing all found records, or None if no records found - """ - # Get Arrow result and convert to Polars LazyFrame - arrow_result = self.get_records_by_ids( - source_name, source_id, entry_ids, add_entry_id_column, preserve_input_order - ) - - if arrow_result is None: - return None - - # Convert to Polars LazyFrame - df = cast(pl.DataFrame, pl.from_arrow(arrow_result)) - return df.lazy() - - def entry_exists(self, source_name: str, source_id: str, entry_id: str) -> bool: - """Check if a specific entry exists.""" - source_key = self._get_source_key(source_name, source_id) - - if source_key not in self._in_memory_store: - return False - - df = self._in_memory_store[source_key] - return df.filter(pl.col("__entry_id") == entry_id).shape[0] > 0 - - def list_entries(self, source_name: str, source_id: str) -> set[str]: - """List all entry IDs for a specific source.""" - source_key = self._get_source_key(source_name, source_id) - - if source_key not in self._in_memory_store: - return set() - - df = self._in_memory_store[source_key] - return set(df["__entry_id"].to_list()) - - def list_sources(self) -> set[tuple[str, str]]: - """List all (source_name, source_id) combinations.""" - sources = set() - for source_key in self._in_memory_store.keys(): - if ":" in source_key: - source_name, source_id = source_key.split(":", 1) - sources.add((source_name, source_id)) - return sources - - def clear_source(self, source_name: str, source_id: str) -> None: - """Clear all records for a specific source.""" - source_key = self._get_source_key(source_name, source_id) - if source_key in self._in_memory_store: - del self._in_memory_store[source_key] - logger.debug(f"Cleared source {source_key}") - - def clear_all(self) -> None: - """Clear all records from the store.""" - self._in_memory_store.clear() - logger.info("Cleared all records from store") - - def get_stats(self) -> dict[str, Any]: - """Get comprehensive statistics about the data store.""" - total_records = 0 - total_memory_mb = 0 - source_stats = [] - - for source_key, df in self._in_memory_store.items(): - record_count = df.shape[0] - total_records += record_count - - # Estimate memory usage (rough approximation) - memory_bytes = df.estimated_size() - memory_mb = memory_bytes / (1024 * 1024) - total_memory_mb += memory_mb - - source_stats.append( - { - "source_key": source_key, - "record_count": record_count, - "column_count": df.shape[1] - 1, # Exclude __entry_id - "memory_mb": round(memory_mb, 2), - "columns": [col for col in df.columns if col != "__entry_id"], - } - ) - - return { - "total_records": total_records, - "total_sources": len(self._in_memory_store), - "total_memory_mb": round(total_memory_mb, 2), - "duplicate_entry_behavior": self.duplicate_entry_behavior, - "source_details": source_stats, - } - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) - demo_single_row_constraint() diff --git a/src/orcapod/databases/legacy/safe_dir_data_store.py b/src/orcapod/databases/legacy/safe_dir_data_store.py deleted file mode 100644 index 72f8ef05..00000000 --- a/src/orcapod/databases/legacy/safe_dir_data_store.py +++ /dev/null @@ -1,492 +0,0 @@ -# safedirstore.py - SafeDirDataStore implementation - -import errno -import fcntl -import json -import logging -import os -import time -from contextlib import contextmanager -from pathlib import Path -from typing import Optional, Union - -from ..file_utils import atomic_copy, atomic_write - -logger = logging.getLogger(__name__) - - -class FileLockError(Exception): - """Exception raised when a file lock cannot be acquired""" - - pass - - -@contextmanager -def file_lock( - lock_path: str | Path, - shared: bool = False, - timeout: float = 30.0, - delay: float = 0.1, - stale_threshold: float = 3600.0, -): - """ - A context manager for file locking that supports both shared and exclusive locks. - - Args: - lock_path: Path to the lock file - shared: If True, acquire a shared (read) lock; if False, acquire an exclusive (write) lock - timeout: Maximum time to wait for the lock in seconds - delay: Time between retries in seconds - stale_threshold: Time in seconds after which a lock is considered stale - - Yields: - None when the lock is acquired - - Raises: - FileLockError: If the lock cannot be acquired within the timeout - """ - lock_path = Path(lock_path) - lock_file = f"{lock_path}.lock" - - # Ensure parent directory exists - lock_path.parent.mkdir(parents=True, exist_ok=True) - - # Choose lock type based on shared flag - lock_type = fcntl.LOCK_SH if shared else fcntl.LOCK_EX - - # Add non-blocking flag for the initial attempt - lock_type_nb = lock_type | fcntl.LOCK_NB - - fd = None - start_time = time.time() - - try: - while True: - try: - # Open the lock file (create if it doesn't exist) - fd = os.open(lock_file, os.O_CREAT | os.O_RDWR) - - try: - # Try to acquire the lock in non-blocking mode - fcntl.flock(fd, lock_type_nb) - - # If we get here, lock was acquired - if not shared: # For exclusive locks only - # Write PID and timestamp to lock file - os.ftruncate(fd, 0) # Clear the file - os.write(fd, f"{os.getpid()},{time.time()}".encode()) - - break # Exit the retry loop - we got the lock - - except IOError as e: - # Close the file descriptor if we couldn't acquire the lock - if fd is not None: - os.close(fd) - fd = None - - if e.errno != errno.EAGAIN: - # If it's not "resource temporarily unavailable", re-raise - raise - - # Check if the lock file is stale (only for exclusive locks) - if os.path.exists(lock_file) and not shared: - try: - with open(lock_file, "r") as f: - content = f.read().strip() - if "," in content: - pid_str, timestamp_str = content.split(",", 1) - lock_pid = int(pid_str) - lock_time = float(timestamp_str) - - # Check if process exists - process_exists = True - try: - os.kill(lock_pid, 0) - except OSError: - process_exists = False - - # Check if lock is stale - if ( - not process_exists - or time.time() - lock_time > stale_threshold - ): - logger.warning( - f"Removing stale lock: {lock_file}" - ) - os.unlink(lock_file) - continue # Try again immediately - except (ValueError, IOError): - # If we can't read the lock file properly, continue with retry - pass - except Exception as e: - logger.debug( - f"Error while trying to acquire lock {lock_file}: {str(e)}" - ) - - # If fd was opened, make sure it's closed - if fd is not None: - os.close(fd) - fd = None - - # Check if we've exceeded the timeout - if time.time() - start_time >= timeout: - if fd is not None: - os.close(fd) - lock_type_name = "shared" if shared else "exclusive" - raise FileLockError( - f"Couldn't acquire {lock_type_name} lock on {lock_file} " - f"after {timeout} seconds" - ) - - # Sleep before retrying - time.sleep(delay) - - # If we get here, we've acquired the lock - logger.debug( - f"Acquired {'shared' if shared else 'exclusive'} lock on {lock_file}" - ) - - # Yield control back to the caller - yield - - finally: - # Release the lock and close the file descriptor - if fd is not None: - fcntl.flock(fd, fcntl.LOCK_UN) - os.close(fd) - - # Remove the lock file only if it was an exclusive lock - if not shared: - try: - os.unlink(lock_file) - except OSError as e: - logger.warning(f"Failed to remove lock file {lock_file}: {str(e)}") - - logger.debug( - f"Released {'shared' if shared else 'exclusive'} lock on {lock_file}" - ) - - -class SafeDirDataStore: - """ - A thread-safe and process-safe directory-based data store for memoization. - Uses file locks and atomic operations to ensure consistency. - """ - - def __init__( - self, - store_dir="./pod_data", - copy_files=True, - preserve_filename=True, - overwrite=False, - lock_timeout=30, - lock_stale_threshold=3600, - ): - """ - Initialize the data store. - - Args: - store_dir: Base directory for storing data - copy_files: Whether to copy files to the data store - preserve_filename: Whether to preserve original filenames - overwrite: Whether to overwrite existing entries - lock_timeout: Timeout for acquiring locks in seconds - lock_stale_threshold: Time in seconds after which a lock is considered stale - """ - self.store_dir = Path(store_dir) - self.copy_files = copy_files - self.preserve_filename = preserve_filename - self.overwrite = overwrite - self.lock_timeout = lock_timeout - self.lock_stale_threshold = lock_stale_threshold - - # Create the data directory if it doesn't exist - self.store_dir.mkdir(parents=True, exist_ok=True) - - def _get_output_dir(self, function_name, content_hash, packet): - """Get the output directory for a specific packet""" - from orcapod.hashing.legacy_core import hash_dict - - packet_hash = hash_dict(packet) - return self.store_dir / function_name / content_hash / str(packet_hash) - - def memoize( - self, - function_name: str, - content_hash: str, - packet: dict, - output_packet: dict, - ) -> dict: - """ - Memoize the output packet for a given store, content hash, and input packet. - Uses file locking to ensure thread safety and process safety. - - Args: - function_name: Name of the function - content_hash: Hash of the function/operation - packet: Input packet - output_packet: Output packet to memoize - - Returns: - The memoized output packet with paths adjusted to the store - - Raises: - FileLockError: If the lock cannot be acquired - ValueError: If the entry already exists and overwrite is False - """ - output_dir = self._get_output_dir(function_name, content_hash, packet) - info_path = output_dir / "_info.json" - lock_path = output_dir / "_lock" - completion_marker = output_dir / "_complete" - - # Create the output directory - output_dir.mkdir(parents=True, exist_ok=True) - - # First check if we already have a completed entry (with a shared lock) - try: - with file_lock(lock_path, shared=True, timeout=self.lock_timeout): - if completion_marker.exists() and not self.overwrite: - logger.info(f"Entry already exists for packet {packet}") - return self.retrieve_memoized(function_name, content_hash, packet) - except FileLockError: - logger.warning("Could not acquire shared lock to check completion status") - # Continue to try with exclusive lock - - # Now try to acquire an exclusive lock for writing - with file_lock( - lock_path, - shared=False, - timeout=self.lock_timeout, - stale_threshold=self.lock_stale_threshold, - ): - # Double-check if the entry already exists (another process might have created it) - if completion_marker.exists() and not self.overwrite: - logger.info( - f"Entry already exists for packet {packet} (verified with exclusive lock)" - ) - return self.retrieve_memoized(function_name, content_hash, packet) - - # Check for partial results and clean up if necessary - partial_marker = output_dir / "_partial" - if partial_marker.exists(): - partial_time = float(partial_marker.read_text().strip()) - if time.time() - partial_time > self.lock_stale_threshold: - logger.warning( - f"Found stale partial results in {output_dir}, cleaning up" - ) - for item in output_dir.glob("*"): - if item.name not in ("_lock", "_lock.lock"): - if item.is_file(): - item.unlink(missing_ok=True) - else: - import shutil - - shutil.rmtree(item, ignore_errors=True) - - # Create partial marker - atomic_write(partial_marker, str(time.time())) - - try: - # Process files - new_output_packet = {} - if self.copy_files: - for key, value in output_packet.items(): - value_path = Path(value) - - if self.preserve_filename: - relative_output_path = value_path.name - else: - # Preserve the suffix of the original if present - relative_output_path = key + value_path.suffix - - output_path = output_dir / relative_output_path - - # Use atomic copy to ensure consistency - atomic_copy(value_path, output_path) - - # Register the key with the new path - new_output_packet[key] = str(relative_output_path) - else: - new_output_packet = output_packet.copy() - - # Write info JSON atomically - atomic_write(info_path, json.dumps(new_output_packet, indent=2)) - - # Create completion marker (atomic write ensures it's either fully there or not at all) - atomic_write(completion_marker, str(time.time())) - - logger.info(f"Stored output for packet {packet} at {output_dir}") - - # Retrieve the memoized packet to ensure consistency - # We don't need to acquire a new lock since we already have an exclusive lock - return self._retrieve_without_lock( - function_name, content_hash, packet, output_dir - ) - - finally: - # Remove partial marker if it exists - if partial_marker.exists(): - partial_marker.unlink(missing_ok=True) - - def retrieve_memoized( - self, function_name: str, content_hash: str, packet: dict - ) -> Optional[dict]: - """ - Retrieve a memoized output packet. - - Uses a shared lock to allow concurrent reads while preventing writes during reads. - - Args: - function_name: Name of the function - content_hash: Hash of the function/operation - packet: Input packet - - Returns: - The memoized output packet with paths adjusted to absolute paths, - or None if the packet is not found - """ - output_dir = self._get_output_dir(function_name, content_hash, packet) - lock_path = output_dir / "_lock" - - # Use a shared lock for reading to allow concurrent reads - try: - with file_lock(lock_path, shared=True, timeout=self.lock_timeout): - return self._retrieve_without_lock( - function_name, content_hash, packet, output_dir - ) - except FileLockError: - logger.warning(f"Could not acquire shared lock to read {output_dir}") - return None - - def _retrieve_without_lock( - self, function_name: str, content_hash: str, packet: dict, output_dir: Path - ) -> Optional[dict]: - """ - Helper to retrieve a memoized packet without acquiring a lock. - - This is used internally when we already have a lock. - - Args: - function_name: Name of the function - content_hash: Hash of the function/operation - packet: Input packet - output_dir: Directory containing the output - - Returns: - The memoized output packet with paths adjusted to absolute paths, - or None if the packet is not found - """ - info_path = output_dir / "_info.json" - completion_marker = output_dir / "_complete" - - # Only return if the completion marker exists - if not completion_marker.exists(): - logger.info(f"No completed output found for packet {packet}") - return None - - if not info_path.exists(): - logger.warning( - f"Completion marker exists but info file missing for {packet}" - ) - return None - - try: - with open(info_path, "r") as f: - output_packet = json.load(f) - - # Update paths to be absolute - for key, value in output_packet.items(): - file_path = output_dir / value - if not file_path.exists(): - logger.warning(f"Referenced file {file_path} does not exist") - return None - output_packet[key] = str(file_path) - - logger.info(f"Retrieved output for packet {packet} from {info_path}") - return output_packet - - except json.JSONDecodeError: - logger.error(f"Error decoding JSON from {info_path}") - return None - except Exception as e: - logger.error(f"Error loading memoized output for packet {packet}: {e}") - return None - - def clear_store(self, function_name: str) -> None: - """ - Clear a specific store. - - Args: - function_name: Name of the function to clear - """ - import shutil - - store_path = self.store_dir / function_name - if store_path.exists(): - shutil.rmtree(store_path) - - def clear_all_stores(self) -> None: - """Clear all stores""" - import shutil - - if self.store_dir.exists(): - shutil.rmtree(self.store_dir) - self.store_dir.mkdir(parents=True, exist_ok=True) - - def clean_stale_data(self, function_name=None, max_age=86400): - """ - Clean up stale data in the store. - - Args: - function_name: Optional name of the function to clean, or None for all functions - max_age: Maximum age of data in seconds before it's considered stale - """ - import shutil - - if function_name is None: - # Clean all stores - for store_dir in self.store_dir.iterdir(): - if store_dir.is_dir(): - self.clean_stale_data(store_dir.name, max_age) - return - - store_path = self.store_dir / function_name - if not store_path.is_dir(): - return - - now = time.time() - - # Find all directories with partial markers - for content_hash_dir in store_path.iterdir(): - if not content_hash_dir.is_dir(): - continue - - for packet_hash_dir in content_hash_dir.iterdir(): - if not packet_hash_dir.is_dir(): - continue - - # Try to acquire an exclusive lock with a short timeout - lock_path = packet_hash_dir / "_lock" - try: - with file_lock(lock_path, shared=False, timeout=1.0): - partial_marker = packet_hash_dir / "_partial" - completion_marker = packet_hash_dir / "_complete" - - # Check for partial results with no completion marker - if partial_marker.exists() and not completion_marker.exists(): - try: - partial_time = float(partial_marker.read_text().strip()) - if now - partial_time > max_age: - logger.info( - f"Cleaning up stale data in {packet_hash_dir}" - ) - shutil.rmtree(packet_hash_dir) - except (ValueError, IOError): - # If we can't read the marker, assume it's stale - logger.info( - f"Cleaning up invalid partial data in {packet_hash_dir}" - ) - shutil.rmtree(packet_hash_dir) - except FileLockError: - # Skip if we couldn't acquire the lock - continue diff --git a/src/orcapod/databases/legacy/types.py b/src/orcapod/databases/legacy/types.py deleted file mode 100644 index 42b0ed57..00000000 --- a/src/orcapod/databases/legacy/types.py +++ /dev/null @@ -1,86 +0,0 @@ -from typing import Protocol, runtime_checkable - -from orcapod.types import Tag, PacketLike -import pyarrow as pa -import polars as pl - - -class DuplicateError(ValueError): - pass - - -@runtime_checkable -class DataStore(Protocol): - """ - Protocol for data stores that can memoize and retrieve packets. - This is used to define the interface for data stores like DirDataStore. - """ - - def __init__(self, *args, **kwargs) -> None: ... - def memoize( - self, - function_name: str, - function_hash: str, - packet: PacketLike, - output_packet: PacketLike, - ) -> PacketLike: ... - - def retrieve_memoized( - self, function_name: str, function_hash: str, packet: PacketLike - ) -> PacketLike | None: ... - - -@runtime_checkable -class ArrowDataStore(Protocol): - """ - Protocol for data stores that can memoize and retrieve packets. - This is used to define the interface for data stores like DirDataStore. - """ - - def __init__(self, *args, **kwargs) -> None: ... - - def add_record( - self, - source_path: tuple[str, ...], - entry_id: str, - arrow_data: pa.Table, - ignore_duplicate: bool = False, - ) -> pa.Table: ... - - def get_record( - self, source_path: tuple[str, ...], entry_id: str - ) -> pa.Table | None: ... - - def get_all_records(self, source_path: tuple[str, ...]) -> pa.Table | None: - """Retrieve all records for a given source as a single table.""" - ... - - def get_all_records_as_polars( - self, source_path: tuple[str, ...] - ) -> pl.LazyFrame | None: - """Retrieve all records for a given source as a single Polars DataFrame.""" - ... - - def get_records_by_ids( - self, - source_path: tuple[str, ...], - entry_ids: list[str] | pl.Series | pa.Array, - add_entry_id_column: bool | str = False, - preserve_input_order: bool = False, - ) -> pa.Table | None: - """Retrieve records by entry IDs as a single table.""" - ... - - def get_records_by_ids_as_polars( - self, - source_path: tuple[str, ...], - entry_ids: list[str] | pl.Series | pa.Array, - add_entry_id_column: bool | str = False, - preserve_input_order: bool = False, - ) -> pl.LazyFrame | None: - """Retrieve records by entry IDs as a single Polars DataFrame.""" - ... - - def flush(self) -> None: - """Flush all pending writes/saves to the data store.""" - ... diff --git a/src/orcapod/databases/noop_database.py b/src/orcapod/databases/noop_database.py new file mode 100644 index 00000000..a65ef88e --- /dev/null +++ b/src/orcapod/databases/noop_database.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from collections.abc import Collection, Mapping +from typing import TYPE_CHECKING, Any + +from orcapod.utils.lazy_module import LazyModule + +if TYPE_CHECKING: + import pyarrow as pa +else: + pa = LazyModule("pyarrow") + + +class NoOpArrowDatabase: + """ + An ArrowDatabaseProtocol implementation that performs no real storage. + + All write operations are silently discarded. All read operations return + None (empty / not found). Useful as a placeholder where a database + dependency is required by an interface but persistence is unwanted — + e.g. dry-run pipelines, testing that code paths execute without I/O, + or benchmarking pure compute overhead. + """ + + def add_record( + self, + record_path: tuple[str, ...], + record_id: str, + record: "pa.Table", + skip_duplicates: bool = False, + flush: bool = False, + ) -> None: + pass + + def add_records( + self, + record_path: tuple[str, ...], + records: "pa.Table", + record_id_column: str | None = None, + skip_duplicates: bool = False, + flush: bool = False, + ) -> None: + pass + + def get_record_by_id( + self, + record_path: tuple[str, ...], + record_id: str, + record_id_column: str | None = None, + flush: bool = False, + ) -> "pa.Table | None": + return None + + def get_all_records( + self, + record_path: tuple[str, ...], + record_id_column: str | None = None, + ) -> "pa.Table | None": + return None + + def get_records_by_ids( + self, + record_path: tuple[str, ...], + record_ids: Collection[str], + record_id_column: str | None = None, + flush: bool = False, + ) -> "pa.Table | None": + return None + + def get_records_with_column_value( + self, + record_path: tuple[str, ...], + column_values: Collection[tuple[str, Any]] | Mapping[str, Any], + record_id_column: str | None = None, + flush: bool = False, + ) -> "pa.Table | None": + return None + + def flush(self) -> None: + pass + + def to_config(self) -> dict[str, Any]: + """Serialize database configuration to a JSON-compatible dict.""" + return {"type": "noop"} + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "NoOpArrowDatabase": + """Reconstruct a NoOpArrowDatabase from a config dict.""" + return cls() diff --git a/src/orcapod/errors.py b/src/orcapod/errors.py index 3775ee9e..9d1c05cf 100644 --- a/src/orcapod/errors.py +++ b/src/orcapod/errors.py @@ -9,3 +9,19 @@ class DuplicateTagError(ValueError): """Raised when duplicate tag values are found and skip_duplicates=False""" pass + + +class FieldNotResolvableError(LookupError): + """ + Raised when a source cannot resolve a field value for a given record ID. + + This may happen because: + - The source is transient or randomly generated (no stable backing data) + - The record ID is not found in the source + - The field name does not exist in the source schema + - The source type does not support field resolution + + The exception message should describe which condition applies. + """ + + pass diff --git a/src/orcapod/hashing/__init__.py b/src/orcapod/hashing/__init__.py index b90f228e..4cf7e3a7 100644 --- a/src/orcapod/hashing/__init__.py +++ b/src/orcapod/hashing/__init__.py @@ -1,28 +1,152 @@ -# from .defaults import ( -# get_default_object_hasher, -# get_default_arrow_hasher, -# ) +""" +OrcaPod hashing package. +Public API +---------- + BaseSemanticHasher -- content-based recursive object hasher (concrete) + SemanticHasherProtocol -- protocol for semantic hashers + TypeHandlerRegistry -- registry mapping types to TypeHandlerProtocol instances + get_default_semantic_hasher -- global default SemanticHasherProtocol factory + get_default_type_handler_registry -- global default TypeHandlerRegistry factory + ContentIdentifiableMixin -- convenience mixin for content-identifiable objects + +Built-in handlers (importable for custom registry setup): + PathContentHandler + UUIDHandler + BytesHandler + FunctionHandler + TypeObjectHandler + register_builtin_handlers + +Legacy names (kept for backward compatibility): + HashableMixin -- legacy mixin from legacy_core (deprecated) + +Utility: + FileContentHasherProtocol + StringCacherProtocol + FunctionInfoExtractorProtocol + ArrowHasherProtocol +""" + +# --------------------------------------------------------------------------- +# New API -- SemanticHasherProtocol, registry, mixin +# --------------------------------------------------------------------------- + +# --------------------------------------------------------------------------- +# Default hasher factories +# --------------------------------------------------------------------------- +from orcapod.hashing.defaults import ( + get_default_arrow_hasher, + get_default_semantic_hasher, + get_default_type_handler_registry, +) + +# --------------------------------------------------------------------------- +# File hashing utilities +# --------------------------------------------------------------------------- +from orcapod.hashing.file_hashers import BasicFileHasher, CachedFileHasher +from orcapod.hashing.hash_utils import hash_file +from orcapod.hashing.semantic_hashing.builtin_handlers import ( + BytesHandler, + FunctionHandler, + PathContentHandler, + TypeObjectHandler, + UUIDHandler, + register_builtin_handlers, +) +from orcapod.hashing.semantic_hashing.content_identifiable_mixin import ( + ContentIdentifiableMixin, +) + +# --------------------------------------------------------------------------- +# Legacy API (deprecated -- kept for backward compatibility) +# These imports are guarded because legacy_core.py has pre-existing import +# issues (e.g. references to removed types) that should not block the new API. +# --------------------------------------------------------------------------- +try: + from orcapod.hashing.legacy_core import ( + HashableMixin, + function_content_hash, + get_function_signature, + hash_function, + hash_packet, + hash_pathset, + hash_to_hex, + hash_to_int, + hash_to_uuid, + ) +except ImportError: + HashableMixin = None # type: ignore[assignment,misc] + function_content_hash = None # type: ignore[assignment] + get_function_signature = None # type: ignore[assignment] + hash_function = None # type: ignore[assignment] + hash_packet = None # type: ignore[assignment] + hash_pathset = None # type: ignore[assignment] + hash_to_hex = None # type: ignore[assignment] + hash_to_int = None # type: ignore[assignment] + hash_to_uuid = None # type: ignore[assignment] +from orcapod.hashing.semantic_hashing.semantic_hasher import BaseSemanticHasher +from orcapod.hashing.semantic_hashing.type_handler_registry import ( + BuiltinTypeHandlerRegistry, + TypeHandlerRegistry, +) + +# --------------------------------------------------------------------------- +# Protocols (re-exported for convenience) +# --------------------------------------------------------------------------- +from orcapod.protocols.hashing_protocols import ( + ArrowHasherProtocol, + ContentIdentifiableProtocol, + FileContentHasherProtocol, + FunctionInfoExtractorProtocol, + SemanticHasherProtocol, + SemanticTypeHasherProtocol, + StringCacherProtocol, + TypeHandlerProtocol, +) + +# --------------------------------------------------------------------------- +# __all__ -- defines the public surface of this package +# --------------------------------------------------------------------------- __all__ = [ - "FileContentHasher", - "LegacyPacketHasher", - "StringCacher", - "ObjectHasher", - "LegacyCompositeFileHasher", - "FunctionInfoExtractor", + # ---- New API: concrete implementation ---- + "BaseSemanticHasher", + "TypeHandlerRegistry", + "BuiltinTypeHandlerRegistry", + "get_default_type_handler_registry", + "get_default_semantic_hasher", + "ContentIdentifiableMixin", + # Built-in handlers + "PathContentHandler", + "UUIDHandler", + "BytesHandler", + "FunctionHandler", + "TypeObjectHandler", + "register_builtin_handlers", + # ---- Protocols ---- + "SemanticHasherProtocol", + "ContentIdentifiableProtocol", + "TypeHandlerProtocol", + "FileContentHasherProtocol", + "ArrowHasherProtocol", + "StringCacherProtocol", + "FunctionInfoExtractorProtocol", + "SemanticTypeHasherProtocol", + # ---- File hashing ---- + "BasicFileHasher", + "CachedFileHasher", "hash_file", - "hash_pathset", - "hash_packet", + # ---- Legacy / backward-compatible ---- + # TODO: remove legacy section + "get_default_arrow_hasher", + "HashableMixin", "hash_to_hex", "hash_to_int", "hash_to_uuid", "hash_function", "get_function_signature", "function_content_hash", - "HashableMixin", - "get_default_composite_file_hasher", - "get_default_object_hasher", - "get_default_arrow_hasher", - "ContentIdentifiableBase", + "hash_pathset", + "hash_packet", ] diff --git a/src/orcapod/hashing/arrow_hashers.py b/src/orcapod/hashing/arrow_hashers.py index 8576f836..77cb3b49 100644 --- a/src/orcapod/hashing/arrow_hashers.py +++ b/src/orcapod/hashing/arrow_hashers.py @@ -1,14 +1,15 @@ import hashlib +import json +from collections.abc import Callable from typing import Any + import pyarrow as pa -import json -from orcapod.semantic_types import SemanticTypeRegistry + from orcapod.hashing import arrow_serialization -from collections.abc import Callable from orcapod.hashing.visitors import SemanticHashingVisitor +from orcapod.semantic_types import SemanticTypeRegistry +from orcapod.types import ContentHash from orcapod.utils import arrow_utils -from orcapod.protocols.hashing_protocols import ContentHash - SERIALIZATION_METHOD_LUT: dict[str, Callable[[pa.Table], bytes]] = { "logical": arrow_serialization.serialize_table_logical, diff --git a/src/orcapod/hashing/arrow_serialization.py b/src/orcapod/hashing/arrow_serialization.py index fa0500f7..d271970c 100644 --- a/src/orcapod/hashing/arrow_serialization.py +++ b/src/orcapod/hashing/arrow_serialization.py @@ -725,7 +725,7 @@ def _serialize_array_fallback( try: value = array[i].as_py() _serialize_complex_value(buffer, value, data_type) - except Exception as e: + except Exception: # If .as_py() fails, try alternative approaches try: # For some array types, we can access scalar values differently @@ -1152,7 +1152,7 @@ def run_comprehensive_tests(): row_indep_hash = hashes["Row-order independent"] full_indep_hash = hashes["Fully order-independent"] - print(f"\nHash uniqueness:") + print("\nHash uniqueness:") print(f" Default != Col-independent: {default_hash != col_indep_hash}") print(f" Default != Row-independent: {default_hash != row_indep_hash}") print(f" Default != Fully independent: {default_hash != full_indep_hash}") diff --git a/src/orcapod/hashing/defaults.py b/src/orcapod/hashing/defaults.py index 20067616..5dd68ea7 100644 --- a/src/orcapod/hashing/defaults.py +++ b/src/orcapod/hashing/defaults.py @@ -1,62 +1,90 @@ -# A collection of utility function that provides a "default" implementation of hashers. -# This is often used as the fallback hasher in the library code. +# Default hasher accessors for the OrcaPod hashing system. +# +# All "default" hashers are obtained through the data context system, which is +# the single source of truth for versioned component configuration. The +# functions below are thin convenience wrappers around the context system so +# that call-sites don't need to import from orcapod.contexts directly. +# +# DO NOT construct hashers directly here (e.g. via versioned_hashers). +# That is the job of the context registry when it instantiates a DataContext +# from its JSON spec. Constructing them here would bypass versioning and +# produce hashers that are decoupled from the active data context. + +from orcapod.hashing.semantic_hashing.type_handler_registry import TypeHandlerRegistry from orcapod.protocols import hashing_protocols as hp -from orcapod.hashing.string_cachers import InMemoryCacher -# from orcapod.hashing.object_hashers import LegacyObjectHasher -from orcapod.hashing.function_info_extractors import FunctionInfoExtractorFactory -from orcapod.hashing.versioned_hashers import ( - get_versioned_semantic_arrow_hasher, - get_versioned_object_hasher, -) +def get_default_type_handler_registry() -> TypeHandlerRegistry: + """ + Return the TypeHandlerRegistry from the default data context. + Returns: + TypeHandlerRegistry: The type handler registry from the default data context. + """ + from orcapod.contexts import get_default_context -def get_default_arrow_hasher( - cache_file_hash: bool | hp.StringCacher = True, -) -> hp.ArrowHasher: + return get_default_context().type_handler_registry + + +def get_default_semantic_hasher() -> hp.SemanticHasherProtocol: """ - Get the default Arrow hasher with semantic type support. - If `cache_file_hash` is True, it uses an in-memory cacher for caching hash values. If a `StringCacher` is provided, it uses that for caching file hashes. + Return the SemanticHasherProtocol from the default data context. + + The hasher is owned by the active DataContext and is therefore consistent + with all other versioned components (arrow hasher, type converter, etc.) + that belong to the same context. + + Returns: + SemanticHasherProtocol: The object hasher from the default data context. """ - arrow_hasher = get_versioned_semantic_arrow_hasher() - if cache_file_hash: - # use unlimited caching - if cache_file_hash is True: - string_cacher = InMemoryCacher(max_size=None) - else: - string_cacher = cache_file_hash + # Late import to avoid circular dependencies: contexts imports from + # protocols and hashing, so we must not import contexts at module level + # inside the hashing package. + from orcapod.contexts import get_default_context - arrow_hasher.set_cacher("path", string_cacher) + return get_default_context().semantic_hasher - return arrow_hasher +def get_default_arrow_hasher( + cache_file_hash: bool | hp.StringCacherProtocol = True, +) -> hp.ArrowHasherProtocol: + """ + Return the ArrowHasherProtocol from the default data context. -def get_default_object_hasher() -> hp.ObjectHasher: - object_hasher = get_versioned_object_hasher() - return object_hasher + If ``cache_file_hash`` is True an in-memory StringCacherProtocol is attached to + the hasher so that repeated hashes of the same file path are served from + cache. Pass a ``StringCacherProtocol`` instance to use a custom caching backend + (e.g. SQLite-backed). + Note: caching is applied on top of the context's arrow hasher each time + this function is called. If you need a single shared cached instance, + obtain it once and store it yourself. -# def get_legacy_object_hasher() -> hp.ObjectHasher: -# function_info_extractor = ( -# FunctionInfoExtractorFactory.create_function_info_extractor( -# strategy="signature" -# ) -# ) -# return LegacyObjectHasher(function_info_extractor=function_info_extractor) + Args: + cache_file_hash: True to use an ephemeral in-memory cache, a + StringCacherProtocol instance to use a custom cache, or False/None to + disable caching. + Returns: + ArrowHasherProtocol: The arrow hasher from the default data context, + optionally with file-hash caching attached. + """ + from typing import Any -# def get_default_composite_file_hasher(with_cache=True) -> LegacyCompositeFileHasher: -# if with_cache: -# # use unlimited caching -# string_cacher = InMemoryCacher(max_size=None) -# return LegacyPathLikeHasherFactory.create_cached_legacy_composite(string_cacher) -# return LegacyPathLikeHasherFactory.create_basic_legacy_composite() + from orcapod.contexts import get_default_context + arrow_hasher: Any = get_default_context().arrow_hasher -# def get_default_composite_file_hasher_with_cacher( -# cacher=None, -# ) -> LegacyCompositeFileHasher: -# if cacher is None: -# cacher = InMemoryCacher(max_size=None) -# return LegacyPathLikeHasherFactory.create_cached_legacy_composite(cacher) + if cache_file_hash: + from orcapod.hashing.string_cachers import InMemoryCacher + + if cache_file_hash is True: + string_cacher: hp.StringCacherProtocol = InMemoryCacher(max_size=None) + else: + string_cacher = cache_file_hash + + # set_cacher is present on SemanticArrowHasher but not on the + # ArrowHasherProtocol protocol, so we call it via Any to avoid a type error. + arrow_hasher.set_cacher("path", string_cacher) + + return arrow_hasher diff --git a/src/orcapod/hashing/file_hashers.py b/src/orcapod/hashing/file_hashers.py index fd3cd819..7e82c063 100644 --- a/src/orcapod/hashing/file_hashers.py +++ b/src/orcapod/hashing/file_hashers.py @@ -1,9 +1,11 @@ +import os + from orcapod.hashing.hash_utils import hash_file from orcapod.protocols.hashing_protocols import ( - FileContentHasher, - StringCacher, + FileContentHasherProtocol, + StringCacherProtocol, ) -from orcapod.types import PathLike +from orcapod.types import ContentHash, PathLike class BasicFileHasher: @@ -17,7 +19,7 @@ def __init__( self.algorithm = algorithm self.buffer_size = buffer_size - def hash_file(self, file_path: PathLike) -> bytes: + def hash_file(self, file_path: PathLike) -> ContentHash: return hash_file( file_path, algorithm=self.algorithm, buffer_size=self.buffer_size ) @@ -28,192 +30,19 @@ class CachedFileHasher: def __init__( self, - file_hasher: FileContentHasher, - string_cacher: StringCacher, + file_hasher: FileContentHasherProtocol, + string_cacher: StringCacherProtocol, ): self.file_hasher = file_hasher self.string_cacher = string_cacher - def hash_file(self, file_path: PathLike) -> bytes: - cache_key = f"file:{file_path}" + def hash_file(self, file_path: PathLike) -> ContentHash: + stat = os.stat(file_path) + cache_key = f"file:{file_path}:{stat.st_mtime_ns}:{stat.st_size}" cached_value = self.string_cacher.get_cached(cache_key) if cached_value is not None: - return bytes.fromhex(cached_value) - - value = self.file_hasher.hash_file(file_path) - self.string_cacher.set_cached(cache_key, value.hex()) - return value - - -# ----------------Legacy implementations for backward compatibility----------------- - - -# class LegacyDefaultFileHasher: -# def __init__( -# self, -# algorithm: str = "sha256", -# buffer_size: int = 65536, -# ): -# self.algorithm = algorithm -# self.buffer_size = buffer_size - -# def hash_file(self, file_path: PathLike) -> str: -# return legacy_core.hash_file( -# file_path, algorithm=self.algorithm, buffer_size=self.buffer_size -# ) - - -# class LegacyCachedFileHasher: -# """File hasher with caching.""" - -# def __init__( -# self, -# file_hasher: LegacyFileHasher, -# string_cacher: StringCacher, -# ): -# self.file_hasher = file_hasher -# self.string_cacher = string_cacher - -# def hash_file(self, file_path: PathLike) -> str: -# cache_key = f"file:{file_path}" -# cached_value = self.string_cacher.get_cached(cache_key) -# if cached_value is not None: -# return cached_value - -# value = self.file_hasher.hash_file(file_path) -# self.string_cacher.set_cached(cache_key, value) -# return value - - -# class LegacyDefaultPathsetHasher: -# """Default pathset hasher that composes file hashing.""" - -# def __init__( -# self, -# file_hasher: LegacyFileHasher, -# char_count: int | None = 32, -# ): -# self.file_hasher = file_hasher -# self.char_count = char_count - -# def _hash_file_to_hex(self, file_path: PathLike) -> str: -# return self.file_hasher.hash_file(file_path) - -# def hash_pathset(self, pathset: PathSet) -> str: -# """Hash a pathset using the injected file hasher.""" -# return legacy_core.hash_pathset( -# pathset, -# char_count=self.char_count, -# file_hasher=self.file_hasher.hash_file, # Inject the method -# ) - - -# class LegacyDefaultPacketHasher: -# """Default packet hasher that composes pathset hashing.""" - -# def __init__( -# self, -# pathset_hasher: LegacyPathSetHasher, -# char_count: int | None = 32, -# prefix: str = "", -# ): -# self.pathset_hasher = pathset_hasher -# self.char_count = char_count -# self.prefix = prefix - -# def _hash_pathset_to_hex(self, pathset: PathSet): -# return self.pathset_hasher.hash_pathset(pathset) - -# def hash_packet(self, packet: PacketLike) -> str: -# """Hash a packet using the injected pathset hasher.""" -# hash_str = legacy_core.hash_packet( -# packet, -# char_count=self.char_count, -# prefix_algorithm=False, # Will apply prefix on our own -# pathset_hasher=self._hash_pathset_to_hex, # Inject the method -# ) -# return f"{self.prefix}-{hash_str}" if self.prefix else hash_str - - -# # Convenience composite implementation -# class LegacyDefaultCompositeFileHasher: -# """Composite hasher that implements all interfaces.""" - -# def __init__( -# self, -# file_hasher: LegacyFileHasher, -# char_count: int | None = 32, -# packet_prefix: str = "", -# ): -# self.file_hasher = file_hasher -# self.pathset_hasher = LegacyDefaultPathsetHasher(self.file_hasher, char_count) -# self.packet_hasher = LegacyDefaultPacketHasher( -# self.pathset_hasher, char_count, packet_prefix -# ) - -# def hash_file(self, file_path: PathLike) -> str: -# return self.file_hasher.hash_file(file_path) - -# def hash_pathset(self, pathset: PathSet) -> str: -# return self.pathset_hasher.hash_pathset(pathset) - -# def hash_packet(self, packet: PacketLike) -> str: -# return self.packet_hasher.hash_packet(packet) - - -# # Factory for easy construction -# class LegacyPathLikeHasherFactory: -# """Factory for creating various hasher combinations.""" - -# @staticmethod -# def create_basic_legacy_composite( -# algorithm: str = "sha256", -# buffer_size: int = 65536, -# char_count: int | None = 32, -# ) -> LegacyCompositeFileHasher: -# """Create a basic composite hasher.""" -# file_hasher = LegacyDefaultFileHasher(algorithm, buffer_size) -# # use algorithm as the prefix for the packet hasher -# return LegacyDefaultCompositeFileHasher( -# file_hasher, char_count, packet_prefix=algorithm -# ) - -# @staticmethod -# def create_cached_legacy_composite( -# string_cacher: StringCacher, -# algorithm: str = "sha256", -# buffer_size: int = 65536, -# char_count: int | None = 32, -# ) -> LegacyCompositeFileHasher: -# """Create a composite hasher with file caching.""" -# basic_file_hasher = LegacyDefaultFileHasher(algorithm, buffer_size) -# cached_file_hasher = LegacyCachedFileHasher(basic_file_hasher, string_cacher) -# return LegacyDefaultCompositeFileHasher( -# cached_file_hasher, char_count, packet_prefix=algorithm -# ) - -# @staticmethod -# def create_legacy_file_hasher( -# string_cacher: StringCacher | None = None, -# algorithm: str = "sha256", -# buffer_size: int = 65536, -# ) -> LegacyFileHasher: -# """Create just a file hasher, optionally with caching.""" -# default_hasher = LegacyDefaultFileHasher(algorithm, buffer_size) -# if string_cacher is None: -# return default_hasher -# else: -# return LegacyCachedFileHasher(default_hasher, string_cacher) + return ContentHash.from_string(cached_value) -# @staticmethod -# def create_file_hasher( -# string_cacher: StringCacher | None = None, -# algorithm: str = "sha256", -# buffer_size: int = 65536, -# ) -> FileContentHasher: -# """Create just a file hasher, optionally with caching.""" -# basic_hasher = BasicFileHasher(algorithm, buffer_size) -# if string_cacher is None: -# return basic_hasher -# else: -# return CachedFileHasher(basic_hasher, string_cacher) + result = self.file_hasher.hash_file(file_path) + self.string_cacher.set_cached(cache_key, result.to_string()) + return result diff --git a/src/orcapod/hashing/hash_utils.py b/src/orcapod/hashing/hash_utils.py index 292aa303..0addcb77 100644 --- a/src/orcapod/hashing/hash_utils.py +++ b/src/orcapod/hashing/hash_utils.py @@ -1,29 +1,39 @@ +import hashlib +import inspect import logging -import json +import zlib +from collections.abc import Callable, Collection from pathlib import Path -from collections.abc import Collection, Callable -import hashlib + import xxhash -import zlib -import inspect + +from orcapod.types import ContentHash logger = logging.getLogger(__name__) -# TODO: extract default char count as config def combine_hashes( *hashes: str, order: bool = False, prefix_hasher_id: bool = False, - hex_char_count: int | None = 20, + hex_char_count: int | None = None, ) -> str: - """Combine hashes into a single hash string.""" + """ + Combine multiple hash strings into a single SHA-256 hash string. - # Sort for deterministic order regardless of input order - if order: - prepared_hashes = sorted(hashes) - else: - prepared_hashes = list(hashes) + Args: + *hashes: Hash strings to combine. + order: If True, sort inputs before combining so the result is + order-independent. If False (default), insertion order + is preserved. + prefix_hasher_id: If True, prefix the result with ``"sha256@"``. + hex_char_count: Number of hex characters to return. None (default) + returns the full 64-character SHA-256 hex digest. + + Returns: + A hex string (optionally truncated / prefixed). + """ + prepared_hashes = sorted(hashes) if order else list(hashes) combined = "".join(prepared_hashes) combined_hash = hashlib.sha256(combined.encode()).hexdigest() if hex_char_count is not None: @@ -33,315 +43,28 @@ def combine_hashes( return combined_hash -def serialize_through_json(processed_obj) -> bytes: - """ - Create a deterministic string representation of a processed object structure. +def hash_file(file_path, algorithm="sha256", buffer_size=65536) -> ContentHash: + """Calculate the hash of a file using the specified algorithm. Args: - processed_obj: The processed object to serialize - - Returns: - A bytes object ready for hashing - """ - # TODO: add type check of processed obj - return json.dumps(processed_obj, sort_keys=True, separators=(",", ":")).encode( - "utf-8" - ) - - -# def process_structure( -# obj: Any, -# visited: set[int] | None = None, -# object_hasher: ObjectHasher | None = None, -# function_info_extractor: FunctionInfoExtractor | None = None, -# compressed: bool = False, -# force_hash: bool = True, -# ) -> Any: -# """ -# Recursively process a structure to prepare it for hashing. - -# Args: -# obj: The object or structure to process -# visited: Set of object ids already visited (to handle circular references) -# function_info_extractor: FunctionInfoExtractor to be used for extracting necessary function representation - -# Returns: -# A processed version of the structure suitable for stable hashing -# """ -# # Initialize the visited set if this is the top-level call -# if visited is None: -# visited = set() -# else: -# visited = visited.copy() # Copy to avoid modifying the original set - -# # Check for circular references - use object's memory address -# # NOTE: While id() is not stable across sessions, we only use it within a session -# # to detect circular references, not as part of the final hash -# obj_id = id(obj) -# if obj_id in visited: -# logger.debug( -# f"Detected circular reference for object of type {type(obj).__name__}" -# ) -# return "CircularRef" # Don't include the actual id in hash output - -# # For objects that could contain circular references, add to visited -# if isinstance(obj, (dict, list, tuple, set)) or not isinstance( -# obj, (str, int, float, bool, type(None)) -# ): -# visited.add(obj_id) - -# # Handle None -# if obj is None: -# return None - -# # TODO: currently using runtime_checkable on ContentIdentifiable protocol -# # Re-evaluate this strategy to see if a faster / more robust check could be used -# if isinstance(obj, ContentIdentifiable): -# logger.debug( -# f"Processing ContentHashableBase instance of type {type(obj).__name__}" -# ) -# if compressed: -# # if compressed, the content identifiable object is immediately replaced with -# # its hashed string identity -# if object_hasher is None: -# raise ValueError( -# "ObjectHasher must be provided to hash ContentIdentifiable objects with compressed=True" -# ) -# return object_hasher.hash_object(obj.identity_structure(), compressed=True) -# else: -# # if not compressed, replace the object with expanded identity structure and re-process -# return process_structure( -# obj.identity_structure(), -# visited, -# object_hasher=object_hasher, -# function_info_extractor=function_info_extractor, -# ) - -# # Handle basic types -# if isinstance(obj, (str, int, float, bool)): -# return obj - -# # Handle bytes and bytearray -# if isinstance(obj, (bytes, bytearray)): -# logger.debug( -# f"Converting bytes/bytearray of length {len(obj)} to hex representation" -# ) -# return obj.hex() - -# # Handle Path objects -# if isinstance(obj, Path): -# logger.debug(f"Converting Path object to string: {obj}") -# return str(obj) - -# # Handle UUID objects -# if isinstance(obj, UUID): -# logger.debug(f"Converting UUID to string: {obj}") -# return str(obj) - -# # Handle named tuples (which are subclasses of tuple) -# if hasattr(obj, "_fields") and isinstance(obj, tuple): -# logger.debug(f"Processing named tuple of type {type(obj).__name__}") -# # For namedtuples, convert to dict and then process -# d = {field: getattr(obj, field) for field in obj._fields} # type: ignore -# return process_structure( -# d, -# visited, -# object_hasher=object_hasher, -# function_info_extractor=function_info_extractor, -# compressed=compressed, -# ) - -# # Handle mappings (dict-like objects) -# if isinstance(obj, Mapping): -# # Process both keys and values -# processed_items = [ -# ( -# process_structure( -# k, -# visited, -# object_hasher=object_hasher, -# function_info_extractor=function_info_extractor, -# compressed=compressed, -# ), -# process_structure( -# v, -# visited, -# object_hasher=object_hasher, -# function_info_extractor=function_info_extractor, -# compressed=compressed, -# ), -# ) -# for k, v in obj.items() -# ] - -# # Sort by the processed keys for deterministic order -# processed_items.sort(key=lambda x: str(x[0])) - -# # Create a new dictionary with string keys based on processed keys -# # TODO: consider checking for possibly problematic values in processed_k -# # and issue a warning -# return { -# str(processed_k): processed_v -# for processed_k, processed_v in processed_items -# } - -# # Handle sets and frozensets -# if isinstance(obj, (set, frozenset)): -# logger.debug( -# f"Processing set/frozenset of type {type(obj).__name__} with {len(obj)} items" -# ) -# # Process each item first, then sort the processed results -# processed_items = [ -# process_structure( -# item, -# visited, -# object_hasher=object_hasher, -# function_info_extractor=function_info_extractor, -# compressed=compressed, -# ) -# for item in obj -# ] -# return sorted(processed_items, key=str) - -# # Handle collections (list-like objects) -# if isinstance(obj, Collection): -# logger.debug( -# f"Processing collection of type {type(obj).__name__} with {len(obj)} items" -# ) -# return [ -# process_structure( -# item, -# visited, -# object_hasher=object_hasher, -# function_info_extractor=function_info_extractor, -# compressed=compressed, -# ) -# for item in obj -# ] - -# # For functions, use the function_content_hash -# if callable(obj) and hasattr(obj, "__code__"): -# logger.debug(f"Processing function: {getattr(obj, '__name__')}") -# if function_info_extractor is not None: -# # Use the extractor to get a stable representation -# function_info = function_info_extractor.extract_function_info(obj) -# logger.debug(f"Extracted function info: {function_info} for {obj.__name__}") - -# # simply return the function info as a stable representation -# return function_info -# else: -# raise ValueError( -# f"Function {obj} encountered during processing but FunctionInfoExtractor is missing" -# ) - -# # handle data types -# if isinstance(obj, type): -# logger.debug(f"Processing class/type: {obj.__name__}") -# return f"type:{obj.__name__}" - -# # For other objects, attempt to create deterministic representation only if force_hash=True -# class_name = obj.__class__.__name__ -# module_name = obj.__class__.__module__ -# if force_hash: -# try: -# import re - -# logger.debug( -# f"Processing generic object of type {module_name}.{class_name}" -# ) - -# # Try to get a stable dict representation if possible -# if hasattr(obj, "__dict__"): -# # Sort attributes to ensure stable order -# attrs = sorted( -# (k, v) for k, v in obj.__dict__.items() if not k.startswith("_") -# ) -# # Limit to first 10 attributes to avoid extremely long representations -# if len(attrs) > 10: -# logger.debug( -# f"Object has {len(attrs)} attributes, limiting to first 10" -# ) -# attrs = attrs[:10] -# attr_strs = [f"{k}={type(v).__name__}" for k, v in attrs] -# obj_repr = f"{{{', '.join(attr_strs)}}}" -# else: -# # Get basic repr but remove memory addresses -# logger.debug( -# "Object has no __dict__, using repr() with memory address removal" -# ) -# obj_repr = repr(obj) -# if len(obj_repr) > 1000: -# logger.debug( -# f"Object repr is {len(obj_repr)} chars, truncating to 1000" -# ) -# obj_repr = obj_repr[:1000] + "..." -# # Remove memory addresses which look like '0x7f9a1c2b3d4e' -# obj_repr = re.sub(r" at 0x[0-9a-f]+", " at 0xMEMADDR", obj_repr) - -# return f"{module_name}.{class_name}:{obj_repr}" -# except Exception as e: -# # Last resort - use class name only -# logger.warning(f"Failed to process object representation: {e}") -# try: -# return f"object:{obj.__class__.__module__}.{obj.__class__.__name__}" -# except AttributeError: -# logger.error("Could not determine object class, using UnknownObject") -# return "UnknownObject" -# else: -# raise ValueError( -# f"Processing of {obj} of type {module_name}.{class_name} is not supported" -# ) - - -# def hash_object( -# obj: Any, -# function_info_extractor: FunctionInfoExtractor | None = None, -# compressed: bool = False, -# ) -> bytes: -# # Process the object to handle nested structures and HashableMixin instances -# processed = process_structure( -# obj, function_info_extractor=function_info_extractor, compressed=compressed -# ) - -# # Serialize the processed structure -# json_str = json.dumps(processed, sort_keys=True, separators=(",", ":")).encode( -# "utf-8" -# ) -# logger.debug( -# f"Successfully serialized {type(obj).__name__} using custom serializer" -# ) - -# # Create the hash -# return hashlib.sha256(json_str).digest() - - -def hash_file(file_path, algorithm="sha256", buffer_size=65536) -> bytes: - """ - Calculate the hash of a file using the specified algorithm. - - Parameters: - file_path (str): Path to the file to hash - algorithm (str): Hash algorithm to use - options include: - 'md5', 'sha1', 'sha256', 'sha512', 'xxh64', 'crc32', 'hash_path' - buffer_size (int): Size of chunks to read from the file at a time + file_path: Path to the file to hash. + algorithm: Hash algorithm to use — options include: + 'md5', 'sha1', 'sha256', 'sha512', 'xxh64', 'crc32', 'hash_path'. + buffer_size: Size of chunks to read from the file at a time. Returns: - str: Hexadecimal digest of the hash + A ContentHash with method set to the algorithm name and digest + containing the raw hash bytes. """ - # Verify the file exists if not Path(file_path).is_file(): raise FileNotFoundError(f"The file {file_path} does not exist") - # Handle special case for 'hash_path' algorithm + # Hash the path string itself rather than file content if algorithm == "hash_path": - # Hash the name of the file instead of its content - # This is useful for cases where the file content is well known or - # not relevant hasher = hashlib.sha256() - hasher.update(file_path.encode("utf-8")) - return hasher.digest() + hasher.update(str(file_path).encode("utf-8")) + return ContentHash(method=algorithm, digest=hasher.digest()) - # Handle non-cryptographic hash functions if algorithm == "xxh64": hasher = xxhash.xxh64() with open(file_path, "rb") as file: @@ -350,7 +73,7 @@ def hash_file(file_path, algorithm="sha256", buffer_size=65536) -> bytes: if not data: break hasher.update(data) - return hasher.digest() + return ContentHash(method=algorithm, digest=hasher.digest()) if algorithm == "crc32": crc = 0 @@ -360,9 +83,11 @@ def hash_file(file_path, algorithm="sha256", buffer_size=65536) -> bytes: if not data: break crc = zlib.crc32(data, crc) - return (crc & 0xFFFFFFFF).to_bytes(4, byteorder="big") + return ContentHash( + method=algorithm, + digest=(crc & 0xFFFFFFFF).to_bytes(4, byteorder="big"), + ) - # Handle cryptographic hash functions from hashlib try: hasher = hashlib.new(algorithm) except ValueError: @@ -378,7 +103,19 @@ def hash_file(file_path, algorithm="sha256", buffer_size=65536) -> bytes: break hasher.update(data) - return hasher.digest() + return ContentHash(method=algorithm, digest=hasher.digest()) + + +def _is_in_string(line: str, pos: int) -> bool: + """Helper to check if a position in a line is inside a string literal.""" + in_single = False + in_double = False + for i in range(pos): + if line[i] == "'" and not in_double and (i == 0 or line[i - 1] != "\\"): + in_single = not in_single + elif line[i] == '"' and not in_single and (i == 0 or line[i - 1] != "\\"): + in_double = not in_double + return in_single or in_double def get_function_signature( @@ -392,26 +129,23 @@ def get_function_signature( Get a stable string representation of a function's signature. Args: - func: The function to process - include_defaults: Whether to include default values - include_module: Whether to include the module name + func: The function to process. + name_override: Override the function name in the output. + include_defaults: Whether to include default parameter values. + include_module: Whether to include the module name. + output_names: Unused; reserved for future use. Returns: - A string representation of the function signature + A string representation of the function signature. """ sig = inspect.signature(func) + parts: dict[str, object] = {} - # Build the signature string - parts = {} - - # Add module if requested if include_module and hasattr(func, "__module__"): parts["module"] = func.__module__ - # Add function name parts["name"] = name_override or func.__name__ - # Add parameters param_strs = [] for name, param in sig.parameters.items(): param_str = str(param) @@ -421,30 +155,18 @@ def get_function_signature( parts["params"] = f"({', '.join(param_strs)})" - # Add return annotation if present if sig.return_annotation is not inspect.Signature.empty: parts["returns"] = sig.return_annotation - # TODO: fix return handling - fn_string = f"{parts['module'] + '.' if 'module' in parts else ''}{parts['name']}{parts['params']}" + fn_string = ( + f"{parts['module'] + '.' if 'module' in parts else ''}" + f"{parts['name']}{parts['params']}" + ) if "returns" in parts: - fn_string = fn_string + f"-> {str(parts['returns'])}" + fn_string += f"-> {parts['returns']}" return fn_string -def _is_in_string(line, pos): - """Helper to check if a position in a line is inside a string literal.""" - # This is a simplified check - would need proper parsing for robust handling - in_single = False - in_double = False - for i in range(pos): - if line[i] == "'" and not in_double and (i == 0 or line[i - 1] != "\\"): - in_single = not in_single - elif line[i] == '"' and not in_single and (i == 0 or line[i - 1] != "\\"): - in_double = not in_double - return in_single or in_double - - def get_function_components( func: Callable, name_override: str | None = None, @@ -461,65 +183,51 @@ def get_function_components( Extract the components of a function that determine its identity for hashing. Args: - func: The function to process - include_name: Whether to include the function name - include_module: Whether to include the module name - include_declaration: Whether to include the function declaration line - include_docstring: Whether to include the function's docstring - include_comments: Whether to include comments in the function body - preserve_whitespace: Whether to preserve original whitespace/indentation - include_annotations: Whether to include function type annotations - include_code_properties: Whether to include code object properties + func: The function to process. + name_override: Override the function name in the output. + include_name: Whether to include the function name. + include_module: Whether to include the module name. + include_declaration: Whether to include the function declaration line. + include_docstring: Whether to include the function's docstring. + include_comments: Whether to include comments in the function body. + preserve_whitespace: Whether to preserve original whitespace/indentation. + include_annotations: Whether to include function type annotations. + include_code_properties: Whether to include code object properties. Returns: - A list of string components + A list of string components. """ components = [] - # Add function name if include_name: components.append(f"name:{name_override or func.__name__}") - # Add module if include_module and hasattr(func, "__module__"): components.append(f"module:{func.__module__}") - # Get the function's source code try: source = inspect.getsource(func) - # Handle whitespace preservation if not preserve_whitespace: source = inspect.cleandoc(source) - # Process source code components if not include_declaration: - # Remove function declaration line lines = source.split("\n") for i, line in enumerate(lines): - if line.strip().startswith("def "): + if line.strip().startswith(("def ", "async def ")): lines.pop(i) break source = "\n".join(lines) - # Extract and handle docstring separately if needed if not include_docstring and func.__doc__: - # This approach assumes the docstring is properly indented - # For multi-line docstrings, we need more sophisticated parsing doc_str = inspect.getdoc(func) - if doc_str: - doc_lines = doc_str.split("\n") - else: - doc_lines = [] + doc_lines = doc_str.split("\n") if doc_str else [] doc_pattern = '"""' + "\\n".join(doc_lines) + '"""' - # Try different quote styles if doc_pattern not in source: doc_pattern = "'''" + "\\n".join(doc_lines) + "'''" source = source.replace(doc_pattern, "") - # Handle comments (this is more complex and may need a proper parser) if not include_comments: - # This is a simplified approach - would need a proper parser for robust handling lines = source.split("\n") for i, line in enumerate(lines): comment_pos = line.find("#") @@ -530,7 +238,6 @@ def get_function_components( components.append(f"source:{source}") except (IOError, TypeError): - # If source can't be retrieved, fall back to signature components.append(f"name:{name_override or func.__name__}") try: sig = inspect.signature(func) @@ -538,7 +245,6 @@ def get_function_components( except ValueError: components.append("builtin:True") - # Add function annotations if requested if ( include_annotations and hasattr(func, "__annotations__") @@ -548,7 +254,6 @@ def get_function_components( annotations_str = ";".join(f"{k}:{v}" for k, v in sorted_annotations) components.append(f"annotations:{annotations_str}") - # Add code object properties if requested if include_code_properties: code = func.__code__ stable_code_props = { diff --git a/src/orcapod/hashing/legacy_core.py b/src/orcapod/hashing/legacy_core.py deleted file mode 100644 index 83d172b6..00000000 --- a/src/orcapod/hashing/legacy_core.py +++ /dev/null @@ -1,1128 +0,0 @@ -import hashlib -import inspect -import json -import logging -import zlib -from orcapod.protocols.hashing_protocols import FunctionInfoExtractor -from functools import partial -from os import PathLike -from pathlib import Path -from typing import ( - Any, - Callable, - Collection, - Dict, - Literal, - Mapping, - Optional, - Set, - TypeVar, - Union, -) -from uuid import UUID - - -import xxhash - -from orcapod.types import PathSet, Packet, PacketLike -from orcapod.utils.name import find_noncolliding_name - -WARN_NONE_IDENTITY = False -""" -Stable Hashing Library -====================== - -A library for creating stable, content-based hashes that remain consistent across Python sessions, -suitable for arbitrarily nested data structures and custom objects via HashableMixin. -""" - - -# Configure logging with __name__ for proper hierarchy -logger = logging.getLogger(__name__) - -# Type for recursive dictionary structures -T = TypeVar("T") -NestedDict = Dict[ - str, Union[str, int, float, bool, None, "NestedDict", list, tuple, set] -] - - -def configure_logging(level=logging.INFO, enable_console=True, log_file=None): - """ - Optional helper to configure logging for this library. - - Users can choose to use this or configure logging themselves. - - Args: - level: The logging level (default: INFO) - enable_console: Whether to log to the console (default: True) - log_file: Path to a log file (default: None) - """ - lib_logger = logging.getLogger(__name__) - lib_logger.setLevel(level) - - # Create a formatter - formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) - - # Add console handler if requested - if enable_console: - console_handler = logging.StreamHandler() - console_handler.setFormatter(formatter) - lib_logger.addHandler(console_handler) - - # Add file handler if requested - if log_file: - file_handler = logging.FileHandler(log_file) - file_handler.setFormatter(formatter) - lib_logger.addHandler(file_handler) - - lib_logger.debug("Logging configured for stable hash library") - return lib_logger - - -def serialize_for_hashing(processed_obj): - """ - Create a deterministic string representation of a processed object structure. - - This function aims to be more stable than json.dumps() by implementing - a custom serialization approach for the specific needs of hashing. - - Args: - processed_obj: The processed object to serialize - - Returns: - A bytes object ready for hashing - """ - if processed_obj is None: - return b"null" - - if isinstance(processed_obj, bool): - return b"true" if processed_obj else b"false" - - if isinstance(processed_obj, (int, float)): - return str(processed_obj).encode("utf-8") - - if isinstance(processed_obj, str): - # Escape quotes and backslashes to ensure consistent representation - escaped = processed_obj.replace("\\", "\\\\").replace('"', '\\"') - return f'"{escaped}"'.encode("utf-8") - - if isinstance(processed_obj, list): - items = [serialize_for_hashing(item) for item in processed_obj] - return b"[" + b",".join(items) + b"]" - - if isinstance(processed_obj, dict): - # Sort keys for deterministic order - sorted_items = sorted(processed_obj.items(), key=lambda x: str(x[0])) - serialized_items = [ - serialize_for_hashing(k) + b":" + serialize_for_hashing(v) - for k, v in sorted_items - ] - return b"{" + b",".join(serialized_items) + b"}" - - # Fallback for any other type - should not happen after _process_structure - logger.warning( - f"Unhandled type in _serialize_for_hashing: {type(processed_obj).__name__}. " - "Using str() representation as fallback, which may not be stable." - ) - return str(processed_obj).encode("utf-8") - - -class HashableMixin: - """ - A mixin that provides content-based hashing functionality. - - To use this mixin: - 1. Inherit from HashableMixin in your class - 2. Override identity_structure() to return a representation of your object's content - 3. Use content_hash(), content_hash_int(), or __hash__() as needed - - Example: - class MyClass(HashableMixin): - def __init__(self, name, value): - self.name = name - self.value = value - - def identity_structure(self): - return {'name': self.name, 'value': self.value} - """ - - def identity_structure(self) -> Any: - """ - Return a structure that represents the identity of this object. - - Override this method in your subclass to provide a stable representation - of your object's content. The structure should contain all fields that - determine the object's identity. - - Returns: - Any: A structure representing this object's content, or None to use default hash - """ - return None - - def content_hash(self, char_count: Optional[int] = 16) -> str: - """ - Generate a stable string hash based on the object's content. - - Args: - char_count: Number of characters to include in the hex digest (None for full hash) - - Returns: - str: A hexadecimal digest representing the object's content - """ - # Get the identity structure - structure = self.identity_structure() - - # If no custom structure is provided, use the class name - # We avoid using id() since it's not stable across sessions - if structure is None: - if WARN_NONE_IDENTITY: - logger.warning( - f"HashableMixin.content_hash called on {self.__class__.__name__} " - "instance that returned identity_structure() of None. " - "Using class name as default identity, which may not correctly reflect object uniqueness." - ) - # Fall back to class name for consistent behavior - return f"HashableMixin-DefaultIdentity-{self.__class__.__name__}" - - # Generate a hash from the identity structure - logger.debug( - f"Generating content hash for {self.__class__.__name__} using identity structure" - ) - return hash_to_hex(structure, char_count=char_count) - - def content_hash_int(self, hexdigits: int = 16) -> int: - """ - Generate a stable integer hash based on the object's content. - - Args: - hexdigits: Number of hex digits to use for the integer conversion - - Returns: - int: An integer representing the object's content - """ - # Get the identity structure - structure = self.identity_structure() - - # If no custom structure is provided, use the class name - # We avoid using id() since it's not stable across sessions - if structure is None: - if WARN_NONE_IDENTITY: - logger.warning( - f"HashableMixin.content_hash_int called on {self.__class__.__name__} " - "instance that returned identity_structure() of None. " - "Using class name as default identity, which may not correctly reflect object uniqueness." - ) - # Use the same default identity as content_hash for consistency - default_identity = ( - f"HashableMixin-DefaultIdentity-{self.__class__.__name__}" - ) - return hash_to_int(default_identity, hexdigits=hexdigits) - - # Generate a hash from the identity structure - logger.debug( - f"Generating content hash (int) for {self.__class__.__name__} using identity structure" - ) - return hash_to_int(structure, hexdigits=hexdigits) - - def content_hash_uuid(self) -> UUID: - """ - Generate a stable UUID hash based on the object's content. - - Returns: - UUID: A UUID representing the object's content - """ - # Get the identity structure - structure = self.identity_structure() - - # If no custom structure is provided, use the class name - # We avoid using id() since it's not stable across sessions - if structure is None: - if WARN_NONE_IDENTITY: - logger.warning( - f"HashableMixin.content_hash_uuid called on {self.__class__.__name__} " - "instance without identity_structure() implementation. " - "Using class name as default identity, which may not correctly reflect object uniqueness." - ) - # Use the same default identity as content_hash for consistency - default_identity = ( - f"HashableMixin-DefaultIdentity-{self.__class__.__name__}" - ) - return hash_to_uuid(default_identity) - - # Generate a hash from the identity structure - logger.debug( - f"Generating content hash (UUID) for {self.__class__.__name__} using identity structure" - ) - return hash_to_uuid(structure) - - def __hash__(self) -> int: - """ - Hash implementation that uses the identity structure if provided, - otherwise falls back to the superclass's hash method. - - Returns: - int: A hash value based on either content or identity - """ - # Get the identity structure - structure = self.identity_structure() - - # If no custom structure is provided, use the superclass's hash - if structure is None: - logger.warning( - f"HashableMixin.__hash__ called on {self.__class__.__name__} " - "instance without identity_structure() implementation. " - "Falling back to super().__hash__() which is not stable across sessions." - ) - return super().__hash__() - - # Generate a hash and convert to integer - logger.debug( - f"Generating hash for {self.__class__.__name__} using identity structure" - ) - return hash_to_int(structure) - - -# Core hashing functions that serve as the unified interface - - -def legacy_hash( - obj: Any, function_info_extractor: FunctionInfoExtractor | None = None -) -> bytes: - # Process the object to handle nested structures and HashableMixin instances - processed = process_structure(obj, function_info_extractor=function_info_extractor) - - # Serialize the processed structure - try: - # Use custom serialization for maximum stability - json_str = serialize_for_hashing(processed) - logger.debug( - f"Successfully serialized {type(obj).__name__} using custom serializer" - ) - except Exception as e: - # Fall back to string representation if serialization fails - logger.warning( - f"Custom serialization failed for {type(obj).__name__}, " - f"falling back to string representation. Error: {e}" - ) - try: - # Try standard JSON first - json_str = json.dumps(processed, sort_keys=True).encode("utf-8") - logger.info("Successfully used standard JSON serialization as fallback") - except (TypeError, ValueError) as json_err: - # If JSON also fails, use simple string representation - logger.warning( - f"JSON serialization also failed: {json_err}. " - "Using basic string representation as last resort." - ) - json_str = str(processed).encode("utf-8") - - # Create the hash - return hashlib.sha256(json_str).digest() - - -def hash_to_hex( - obj: Any, - char_count: int | None = 32, - function_info_extractor: FunctionInfoExtractor | None = None, -) -> str: - """ - Create a stable hex hash of any object that remains consistent across Python sessions. - - Args: - obj: The object to hash - can be a primitive type, nested data structure, or - HashableMixin instance - char_count: Number of hex characters to return (None for full hash) - - Returns: - A hex string hash - """ - - # Create the hash - hash_hex = legacy_hash(obj, function_info_extractor=function_info_extractor).hex() - - # Return the requested number of characters - if char_count is not None: - logger.debug(f"Using char_count: {char_count}") - if char_count > len(hash_hex): - raise ValueError( - f"Cannot truncate to {char_count} chars, hash only has {len(hash_hex)}" - ) - return hash_hex[:char_count] - return hash_hex - - -def hash_to_int( - obj: Any, - hexdigits: int = 16, - function_info_extractor: FunctionInfoExtractor | None = None, -) -> int: - """ - Convert any object to a stable integer hash that remains consistent across Python sessions. - - Args: - obj: The object to hash - hexdigits: Number of hex digits to use for the integer conversion - - Returns: - An integer hash - """ - hash_hex = hash_to_hex( - obj, char_count=hexdigits, function_info_extractor=function_info_extractor - ) - return int(hash_hex, 16) - - -def hash_to_uuid( - obj: Any, function_info_extractor: FunctionInfoExtractor | None = None -) -> UUID: - """ - Convert any object to a stable UUID hash that remains consistent across Python sessions. - - Args: - obj: The object to hash - - Returns: - A UUID hash - """ - hash_hex = hash_to_hex( - obj, char_count=32, function_info_extractor=function_info_extractor - ) - # TODO: update this to use UUID5 with a namespace on hash bytes output instead - return UUID(hash_hex) - - -# Helper function for processing nested structures -def process_structure( - obj: Any, - visited: Optional[Set[int]] = None, - function_info_extractor: FunctionInfoExtractor | None = None, -) -> Any: - """ - Recursively process a structure to prepare it for hashing. - - Args: - obj: The object or structure to process - visited: Set of object ids already visited (to handle circular references) - - Returns: - A processed version of the structure suitable for stable hashing - """ - # Initialize the visited set if this is the top-level call - if visited is None: - visited = set() - - # Check for circular references - use object's memory address - # NOTE: While id() is not stable across sessions, we only use it within a session - # to detect circular references, not as part of the final hash - obj_id = id(obj) - if obj_id in visited: - logger.debug( - f"Detected circular reference for object of type {type(obj).__name__}" - ) - return "CircularRef" # Don't include the actual id in hash output - - # For objects that could contain circular references, add to visited - if isinstance(obj, (dict, list, tuple, set)) or not isinstance( - obj, (str, int, float, bool, type(None)) - ): - visited.add(obj_id) - - # Handle None - if obj is None: - return None - - # If the object is a HashableMixin, use its content_hash - if isinstance(obj, HashableMixin): - logger.debug(f"Processing HashableMixin instance of type {type(obj).__name__}") - return obj.content_hash() - - from .content_identifiable import ContentIdentifiableBase - - if isinstance(obj, ContentIdentifiableBase): - logger.debug( - f"Processing ContentHashableBase instance of type {type(obj).__name__}" - ) - return process_structure( - obj.identity_structure(), visited, function_info_extractor - ) - - # Handle basic types - if isinstance(obj, (str, int, float, bool)): - return obj - - # Handle bytes and bytearray - if isinstance(obj, (bytes, bytearray)): - logger.debug( - f"Converting bytes/bytearray of length {len(obj)} to hex representation" - ) - return obj.hex() - - # Handle Path objects - if isinstance(obj, Path): - logger.debug(f"Converting Path object to string: {obj}") - return str(obj) - - # Handle UUID objects - if isinstance(obj, UUID): - logger.debug(f"Converting UUID to string: {obj}") - return str(obj) - - # Handle named tuples (which are subclasses of tuple) - if hasattr(obj, "_fields") and isinstance(obj, tuple): - logger.debug(f"Processing named tuple of type {type(obj).__name__}") - # For namedtuples, convert to dict and then process - d = {field: getattr(obj, field) for field in obj._fields} # type: ignore - return process_structure(d, visited, function_info_extractor) - - # Handle mappings (dict-like objects) - if isinstance(obj, Mapping): - # Process both keys and values - processed_items = [ - ( - process_structure(k, visited, function_info_extractor), - process_structure(v, visited, function_info_extractor), - ) - for k, v in obj.items() - ] - - # Sort by the processed keys for deterministic order - processed_items.sort(key=lambda x: str(x[0])) - - # Create a new dictionary with string keys based on processed keys - return { - str(processed_k): processed_v - for processed_k, processed_v in processed_items - } - - # Handle sets and frozensets - if isinstance(obj, (set, frozenset)): - logger.debug( - f"Processing set/frozenset of type {type(obj).__name__} with {len(obj)} items" - ) - # Process each item first, then sort the processed results - processed_items = [ - process_structure(item, visited, function_info_extractor) for item in obj - ] - return sorted(processed_items, key=str) - - # Handle collections (list-like objects) - if isinstance(obj, Collection) and not isinstance(obj, str): - logger.debug( - f"Processing collection of type {type(obj).__name__} with {len(obj)} items" - ) - return [ - process_structure(item, visited, function_info_extractor) for item in obj - ] - - # For functions, use the function_content_hash - if callable(obj) and hasattr(obj, "__code__"): - logger.debug(f"Processing function: {obj.__name__}") - if function_info_extractor is not None: - # Use the extractor to get a stable representation - function_info = function_info_extractor.extract_function_info(obj) - logger.debug(f"Extracted function info: {function_info} for {obj.__name__}") - - # simply return the function info as a stable representation - return function_info - else: - # Default to using legacy function content hash - return function_content_hash(obj) - - # For other objects, create a deterministic representation - try: - import re - - class_name = obj.__class__.__name__ - module_name = obj.__class__.__module__ - - logger.debug(f"Processing generic object of type {module_name}.{class_name}") - - # Try to get a stable dict representation if possible - if hasattr(obj, "__dict__"): - # Sort attributes to ensure stable order - attrs = sorted( - (k, v) for k, v in obj.__dict__.items() if not k.startswith("_") - ) - # Limit to first 10 attributes to avoid extremely long representations - if len(attrs) > 10: - logger.debug( - f"Object has {len(attrs)} attributes, limiting to first 10" - ) - attrs = attrs[:10] - attr_strs = [f"{k}={type(v).__name__}" for k, v in attrs] - obj_repr = f"{{{', '.join(attr_strs)}}}" - else: - # Get basic repr but remove memory addresses - logger.debug( - "Object has no __dict__, using repr() with memory address removal" - ) - obj_repr = repr(obj) - if len(obj_repr) > 1000: - logger.debug( - f"Object repr is {len(obj_repr)} chars, truncating to 1000" - ) - obj_repr = obj_repr[:1000] + "..." - # Remove memory addresses which look like '0x7f9a1c2b3d4e' - obj_repr = re.sub(r" at 0x[0-9a-f]+", " at 0xMEMADDR", obj_repr) - - return f"{module_name}.{class_name}-{obj_repr}" - except Exception as e: - # Last resort - use class name only - logger.warning(f"Failed to process object representation: {e}") - try: - return f"Object-{obj.__class__.__module__}.{obj.__class__.__name__}" - except AttributeError: - logger.error("Could not determine object class, using UnknownObject") - return "UnknownObject" - - -# Function hashing utilities - - -# Legacy compatibility functions - - -def hash_dict(d: NestedDict) -> UUID: - """ - Hash a dictionary with stable results across sessions. - - Args: - d: The dictionary to hash (can be arbitrarily nested) - - Returns: - A UUID hash of the dictionary - """ - return hash_to_uuid(d) - - -def stable_hash(s: Any) -> int: - """ - Create a stable hash that returns the same integer value across sessions. - - Args: - s: The object to hash - - Returns: - An integer hash - """ - return hash_to_int(s) - - -# Hashing of packets and PathSet - - -class PathSetHasher: - def __init__(self, char_count=32): - self.char_count = char_count - - def hash_pathset(self, pathset: PathSet) -> str: - if isinstance(pathset, str) or isinstance(pathset, PathLike): - pathset = Path(pathset) - if not pathset.exists(): - raise FileNotFoundError(f"Path {pathset} does not exist") - if pathset.is_dir(): - # iterate over all entries in the directory include subdirectory (single step) - hash_dict = {} - for entry in pathset.iterdir(): - file_name = find_noncolliding_name(entry.name, hash_dict) - hash_dict[file_name] = self.hash_pathset(entry) - return hash_to_hex(hash_dict, char_count=self.char_count) - else: - # it's a file, hash it directly - return hash_file(pathset) - - if isinstance(pathset, Collection): - hash_dict = {} - for path in pathset: - # TODO: consider handling of None value - if path is None: - raise NotImplementedError( - "Case of PathSet containing None is not supported yet" - ) - file_name = find_noncolliding_name(Path(path).name, hash_dict) - hash_dict[file_name] = self.hash_pathset(path) - return hash_to_hex(hash_dict, char_count=self.char_count) - - raise ValueError(f"PathSet of type {type(pathset)} is not supported") - - def hash_file(self, filepath) -> str: ... - - def id(self) -> str: ... - - -def hash_packet_with_psh( - packet: Packet, algo: PathSetHasher, prefix_algorithm: bool = True -) -> str: - """ - Generate a hash for a packet based on its content. - - Args: - packet: The packet to hash - algorithm: The algorithm to use for hashing - prefix_algorithm: Whether to prefix the hash with the algorithm name - - Returns: - A hexadecimal digest of the packet's content - """ - hash_results = {} - for key, pathset in packet.items(): - # TODO: fix pathset handling - hash_results[key] = algo.hash_pathset(pathset) # type: ignore - - packet_hash = hash_to_hex(hash_results) - - if prefix_algorithm: - # Prefix the hash with the algorithm name - packet_hash = f"{algo.id()}-{packet_hash}" - - return packet_hash - - -def hash_packet( - packet: PacketLike, - algorithm: str = "sha256", - buffer_size: int = 65536, - char_count: Optional[int] = 32, - prefix_algorithm: bool = True, - pathset_hasher: Callable[..., str] | None = None, -) -> str: - """ - Generate a hash for a packet based on its content. - - Args: - packet: The packet to hash - - Returns: - A hexadecimal digest of the packet's content - """ - if pathset_hasher is None: - pathset_hasher = partial( - hash_pathset, - algorithm=algorithm, - buffer_size=buffer_size, - char_count=char_count, - ) - - hash_results = {} - for key, pathset in packet.items(): - # TODO: fix Pathset handling - hash_results[key] = pathset_hasher(pathset) # type: ignore - - packet_hash = hash_to_hex(hash_results, char_count=char_count) - - if prefix_algorithm: - # Prefix the hash with the algorithm name - packet_hash = f"{algorithm}-{packet_hash}" - - return packet_hash - - -def hash_pathset( - pathset: PathSet, - algorithm="sha256", - buffer_size=65536, - char_count: int | None = 32, - file_hasher: Callable[..., str] | None = None, -) -> str: - """ - Generate hash of the pathset based primarily on the content of the files. - If the pathset is a collection of files or a directory, the name of the file - will be included in the hash calculation. - - Currently only support hashing of Pathset if Pathset points to a single file. - """ - if file_hasher is None: - file_hasher = partial(hash_file, algorithm=algorithm, buffer_size=buffer_size) - - if isinstance(pathset, str) or isinstance(pathset, PathLike): - pathset = Path(pathset) - if not pathset.exists(): - raise FileNotFoundError(f"Path {pathset} does not exist") - if pathset.is_dir(): - # iterate over all entries in the directory include subdirectory (single step) - hash_dict = {} - for entry in pathset.iterdir(): - file_name = find_noncolliding_name(entry.name, hash_dict) - hash_dict[file_name] = hash_pathset( - entry, - algorithm=algorithm, - buffer_size=buffer_size, - char_count=char_count, - file_hasher=file_hasher, - ) - return hash_to_hex(hash_dict, char_count=char_count) - else: - # it's a file, hash it directly - return file_hasher(pathset) - - if isinstance(pathset, Collection): - hash_dict = {} - for path in pathset: - if path is None: - raise NotImplementedError( - "Case of PathSet containing None is not supported yet" - ) - file_name = find_noncolliding_name(Path(path).name, hash_dict) - hash_dict[file_name] = hash_pathset( - path, - algorithm=algorithm, - buffer_size=buffer_size, - char_count=char_count, - file_hasher=file_hasher, - ) - return hash_to_hex(hash_dict, char_count=char_count) - - -def hash_file(file_path, algorithm="sha256", buffer_size=65536) -> str: - """ - Calculate the hash of a file using the specified algorithm. - - Parameters: - file_path (str): Path to the file to hash - algorithm (str): Hash algorithm to use - options include: - 'md5', 'sha1', 'sha256', 'sha512', 'xxh64', 'crc32', 'hash_path' - buffer_size (int): Size of chunks to read from the file at a time - - Returns: - str: Hexadecimal digest of the hash - """ - # Verify the file exists - if not Path(file_path).is_file(): - raise FileNotFoundError(f"The file {file_path} does not exist") - - # Handle special case for 'hash_path' algorithm - if algorithm == "hash_path": - # Hash the name of the file instead of its content - # This is useful for cases where the file content is well known or - # not relevant - return hash_to_hex(file_path) - - # Handle non-cryptographic hash functions - if algorithm == "xxh64": - hasher = xxhash.xxh64() - with open(file_path, "rb") as file: - while True: - data = file.read(buffer_size) - if not data: - break - hasher.update(data) - return hasher.hexdigest() - - if algorithm == "crc32": - crc = 0 - with open(file_path, "rb") as file: - while True: - data = file.read(buffer_size) - if not data: - break - crc = zlib.crc32(data, crc) - return format(crc & 0xFFFFFFFF, "08x") # Convert to hex string - - # Handle cryptographic hash functions from hashlib - try: - hasher = hashlib.new(algorithm) - except ValueError: - valid_algorithms = ", ".join(sorted(hashlib.algorithms_available)) - raise ValueError( - f"Invalid algorithm: {algorithm}. Available algorithms: {valid_algorithms}, xxh64, crc32" - ) - - with open(file_path, "rb") as file: - while True: - data = file.read(buffer_size) - if not data: - break - hasher.update(data) - - return hasher.hexdigest() - - -def get_function_signature( - func: Callable, - name_override: str | None = None, - include_defaults: bool = True, - include_module: bool = True, - output_names: Collection[str] | None = None, -) -> str: - """ - Get a stable string representation of a function's signature. - - Args: - func: The function to process - include_defaults: Whether to include default values - include_module: Whether to include the module name - - Returns: - A string representation of the function signature - """ - sig = inspect.signature(func) - - # Build the signature string - parts = {} - - # Add module if requested - if include_module and hasattr(func, "__module__"): - parts["module"] = func.__module__ - - # Add function name - parts["name"] = name_override or func.__name__ - - # Add parameters - param_strs = [] - for name, param in sig.parameters.items(): - param_str = str(param) - if not include_defaults and "=" in param_str: - param_str = param_str.split("=")[0].strip() - param_strs.append(param_str) - - parts["params"] = f"({', '.join(param_strs)})" - - # Add return annotation if present - if sig.return_annotation is not inspect.Signature.empty: - parts["returns"] = sig.return_annotation - - # TODO: fix return handling - fn_string = f"{parts['module'] + '.' if 'module' in parts else ''}{parts['name']}{parts['params']}" - if "returns" in parts: - fn_string = fn_string + f"-> {str(parts['returns'])}" - return fn_string - - -def _is_in_string(line, pos): - """Helper to check if a position in a line is inside a string literal.""" - # This is a simplified check - would need proper parsing for robust handling - in_single = False - in_double = False - for i in range(pos): - if line[i] == "'" and not in_double and (i == 0 or line[i - 1] != "\\"): - in_single = not in_single - elif line[i] == '"' and not in_single and (i == 0 or line[i - 1] != "\\"): - in_double = not in_double - return in_single or in_double - - -def get_function_components( - func: Callable, - name_override: str | None = None, - include_name: bool = True, - include_module: bool = True, - include_declaration: bool = True, - include_docstring: bool = True, - include_comments: bool = True, - preserve_whitespace: bool = True, - include_annotations: bool = True, - include_code_properties: bool = True, -) -> list: - """ - Extract the components of a function that determine its identity for hashing. - - Args: - func: The function to process - include_name: Whether to include the function name - include_module: Whether to include the module name - include_declaration: Whether to include the function declaration line - include_docstring: Whether to include the function's docstring - include_comments: Whether to include comments in the function body - preserve_whitespace: Whether to preserve original whitespace/indentation - include_annotations: Whether to include function type annotations - include_code_properties: Whether to include code object properties - - Returns: - A list of string components - """ - components = [] - - # Add function name - if include_name: - components.append(f"name:{name_override or func.__name__}") - - # Add module - if include_module and hasattr(func, "__module__"): - components.append(f"module:{func.__module__}") - - # Get the function's source code - try: - source = inspect.getsource(func) - - # Handle whitespace preservation - if not preserve_whitespace: - source = inspect.cleandoc(source) - - # Process source code components - if not include_declaration: - # Remove function declaration line - lines = source.split("\n") - for i, line in enumerate(lines): - if line.strip().startswith("def "): - lines.pop(i) - break - source = "\n".join(lines) - - # Extract and handle docstring separately if needed - if not include_docstring and func.__doc__: - # This approach assumes the docstring is properly indented - # For multi-line docstrings, we need more sophisticated parsing - doc_str = inspect.getdoc(func) - if doc_str: - doc_lines = doc_str.split("\n") - else: - doc_lines = [] - doc_pattern = '"""' + "\\n".join(doc_lines) + '"""' - # Try different quote styles - if doc_pattern not in source: - doc_pattern = "'''" + "\\n".join(doc_lines) + "'''" - source = source.replace(doc_pattern, "") - - # Handle comments (this is more complex and may need a proper parser) - if not include_comments: - # This is a simplified approach - would need a proper parser for robust handling - lines = source.split("\n") - for i, line in enumerate(lines): - comment_pos = line.find("#") - if comment_pos >= 0 and not _is_in_string(line, comment_pos): - lines[i] = line[:comment_pos].rstrip() - source = "\n".join(lines) - - components.append(f"source:{source}") - - except (IOError, TypeError): - # If source can't be retrieved, fall back to signature - components.append(f"name:{name_override or func.__name__}") - try: - sig = inspect.signature(func) - components.append(f"signature:{str(sig)}") - except ValueError: - components.append("builtin:True") - - # Add function annotations if requested - if ( - include_annotations - and hasattr(func, "__annotations__") - and func.__annotations__ - ): - sorted_annotations = sorted(func.__annotations__.items()) - annotations_str = ";".join(f"{k}:{v}" for k, v in sorted_annotations) - components.append(f"annotations:{annotations_str}") - - # Add code object properties if requested - if include_code_properties: - code = func.__code__ - stable_code_props = { - "co_argcount": code.co_argcount, - "co_kwonlyargcount": getattr(code, "co_kwonlyargcount", 0), - "co_nlocals": code.co_nlocals, - "co_varnames": code.co_varnames[: code.co_argcount], - } - components.append(f"code_properties:{stable_code_props}") - - return components - - -def function_content_hash( - func: Callable, - include_name: bool = True, - include_module: bool = True, - include_declaration: bool = True, - char_count: Optional[int] = 32, -) -> str: - """ - Compute a stable hash based on a function's source code and other properties. - - Args: - func: The function to hash - include_name: Whether to include the function name in the hash - include_module: Whether to include the module name in the hash - include_declaration: Whether to include the function declaration line - char_count: Number of characters to include in the result - - Returns: - A hex string hash of the function's content - """ - logger.debug(f"Generating content hash for function '{func.__name__}'") - components = get_function_components( - func, - include_name=include_name, - include_module=include_module, - include_declaration=include_declaration, - ) - - # Join all components and compute hash - combined = "\n".join(components) - logger.debug(f"Function components joined, length: {len(combined)} characters") - return hash_to_hex(combined, char_count=char_count) - - -def hash_function( - function: Callable, - function_hash_mode: Literal["content", "signature", "name"] = "content", - return_type: Literal["hex", "int", "uuid"] = "hex", - name_override: Optional[str] = None, - content_kwargs=None, - hash_kwargs=None, -) -> Union[str, int, UUID]: - """ - Hash a function based on specified mode and return type. - - Args: - function: The function to hash - function_hash_mode: The mode of hashing ('content', 'signature', or 'name') - return_type: The format of the hash to return ('hex', 'int', or 'uuid') - content_kwargs: Additional arguments to pass to the mode-specific function content - extractors: - - "content": arguments for get_function_components - - "signature": arguments for get_function_signature - - "name": no underlying function used - simply function.__name__ or name_override if provided - hash_kwargs: Additional arguments for the hashing function that depends on the return type - - "hex": arguments for hash_to_hex - - "int": arguments for hash_to_int - - "uuid": arguments for hash_to_uuid - - Returns: - A hash of the function in the requested format - - Example: - >>> def example(x, y=10): return x + y - >>> hash_function(example) # Returns content hash as string - >>> hash_function(example, function_hash_mode="signature") # Returns signature hash - >>> hash_function(example, return_type="int") # Returns content hash as integer - """ - content_kwargs = content_kwargs or {} - hash_kwargs = hash_kwargs or {} - - logger.debug( - f"Hashing function '{function.__name__}' using mode '{function_hash_mode}'" - + (f" with name override '{name_override}'" if name_override else "") - ) - - if function_hash_mode == "content": - hash_content = "\n".join( - get_function_components( - function, name_override=name_override, **content_kwargs - ) - ) - elif function_hash_mode == "signature": - hash_content = get_function_signature(function, **content_kwargs) - elif function_hash_mode == "name": - hash_content = name_override or function.__name__ - else: - err_msg = f"Unknown function_hash_mode: {function_hash_mode}" - logger.error(err_msg) - raise ValueError(err_msg) - - # Convert to the requested return type - if return_type == "hex": - hash_value = hash_to_hex(hash_content, **hash_kwargs) - elif return_type == "int": - hash_value = hash_to_int(hash_content, **hash_kwargs) - elif return_type == "uuid": - hash_value = hash_to_uuid(hash_content, **hash_kwargs) - else: - err_msg = f"Unknown return_type: {return_type}" - logger.error(err_msg) - raise ValueError(err_msg) - - logger.debug(f"Generated hash value as {return_type}: {hash_value}") - return hash_value diff --git a/src/orcapod/hashing/object_hashers.py b/src/orcapod/hashing/object_hashers.py deleted file mode 100644 index 09b01ddb..00000000 --- a/src/orcapod/hashing/object_hashers.py +++ /dev/null @@ -1,306 +0,0 @@ -import hashlib -import json -import logging -import uuid -from abc import ABC, abstractmethod -from collections.abc import Collection, Mapping -from pathlib import Path -from typing import Any -from uuid import UUID - -from orcapod.protocols import hashing_protocols as hp - -logger = logging.getLogger(__name__) - - -class ObjectHasherBase(ABC): - @abstractmethod - def hash_object(self, obj: object) -> hp.ContentHash: ... - - @property - @abstractmethod - def hasher_id(self) -> str: ... - - def hash_to_hex( - self, obj: Any, char_count: int | None = None, prefix_hasher_id: bool = False - ) -> str: - content_hash = self.hash_object(obj) - hex_str = content_hash.to_hex() - - # TODO: clean up this logic, as char_count handling is messy - if char_count is not None: - if char_count > len(hex_str): - raise ValueError( - f"Cannot truncate to {char_count} chars, hash only has {len(hex_str)}" - ) - hex_str = hex_str[:char_count] - if prefix_hasher_id: - hex_str = self.hasher_id + "@" + hex_str - return hex_str - - def hash_to_int(self, obj: Any, hexdigits: int = 16) -> int: - """ - Hash an object to an integer. - - Args: - obj (Any): The object to hash. - hexdigits (int): Number of hexadecimal digits to use for the hash. - - Returns: - int: The integer representation of the hash. - """ - hex_hash = self.hash_to_hex(obj, char_count=hexdigits) - return int(hex_hash, 16) - - def hash_to_uuid( - self, - obj: Any, - namespace: uuid.UUID = uuid.NAMESPACE_OID, - ) -> uuid.UUID: - """Convert hash to proper UUID5.""" - # TODO: decide whether to use to_hex or digest here - return uuid.uuid5(namespace, self.hash_object(obj).to_hex()) - - -class BasicObjectHasher(ObjectHasherBase): - """ - Default object hasher used throughout the codebase. - """ - - def __init__( - self, - hasher_id: str, - function_info_extractor: hp.FunctionInfoExtractor | None = None, - ): - self._hasher_id = hasher_id - self.function_info_extractor = function_info_extractor - - @property - def hasher_id(self) -> str: - return self._hasher_id - - def process_structure( - self, - obj: Any, - visited: set[int] | None = None, - force_hash: bool = True, - ) -> Any: - """ - Recursively process a structure to prepare it for hashing. - - Args: - obj: The object or structure to process - visited: Set of object ids already visited (to handle circular references) - function_info_extractor: FunctionInfoExtractor to be used for extracting necessary function representation - - Returns: - A processed version of the structure suitable for stable hashing - """ - # Initialize the visited set if this is the top-level call - if visited is None: - visited = set() - else: - visited = visited.copy() # Copy to avoid modifying the original set - - # Check for circular references - use object's memory address - # NOTE: While id() is not stable across sessions, we only use it within a session - # to detect circular references, not as part of the final hash - obj_id = id(obj) - if obj_id in visited: - logger.debug( - f"Detected circular reference for object of type {type(obj).__name__}" - ) - return "CircularRef" # Don't include the actual id in hash output - - # TODO: revisit the hashing of the ContentHash - if isinstance(obj, hp.ContentHash): - return (obj.method, obj.digest.hex()) - - # For objects that could contain circular references, add to visited - if isinstance(obj, (dict, list, tuple, set)) or not isinstance( - obj, (str, int, float, bool, type(None)) - ): - visited.add(obj_id) - - # Handle None - if obj is None: - return None - - # TODO: currently using runtime_checkable on ContentIdentifiable protocol - # Re-evaluate this strategy to see if a faster / more robust check could be used - if isinstance(obj, hp.ContentIdentifiable): - logger.debug( - f"Processing ContentHashableBase instance of type {type(obj).__name__}" - ) - return self._hash_object(obj.identity_structure(), visited=visited).to_hex() - - # Handle basic types - if isinstance(obj, (str, int, float, bool)): - return obj - - # Handle bytes and bytearray - if isinstance(obj, (bytes, bytearray)): - logger.debug( - f"Converting bytes/bytearray of length {len(obj)} to hex representation" - ) - return obj.hex() - - # Handle Path objects - if isinstance(obj, Path): - logger.debug(f"Converting Path object to string: {obj}") - raise NotImplementedError( - "Path objects are not supported in this hasher. Please convert to string." - ) - return str(obj) - - # Handle UUID objects - if isinstance(obj, UUID): - logger.debug(f"Converting UUID to string: {obj}") - raise NotImplementedError( - "UUID objects are not supported in this hasher. Please convert to string." - ) - return str(obj) - - # Handle named tuples (which are subclasses of tuple) - if hasattr(obj, "_fields") and isinstance(obj, tuple): - logger.debug(f"Processing named tuple of type {type(obj).__name__}") - # For namedtuples, convert to dict and then process - d = {field: getattr(obj, field) for field in obj._fields} # type: ignore - return self.process_structure(d, visited) - - # Handle mappings (dict-like objects) - if isinstance(obj, Mapping): - # Process both keys and values - processed_items = [ - ( - self.process_structure(k, visited), - self.process_structure(v, visited), - ) - for k, v in obj.items() - ] - - # Sort by the processed keys for deterministic order - processed_items.sort(key=lambda x: str(x[0])) - - # Create a new dictionary with string keys based on processed keys - # TODO: consider checking for possibly problematic values in processed_k - # and issue a warning - return { - str(processed_k): processed_v - for processed_k, processed_v in processed_items - } - - # Handle sets and frozensets - if isinstance(obj, (set, frozenset)): - logger.debug( - f"Processing set/frozenset of type {type(obj).__name__} with {len(obj)} items" - ) - # Process each item first, then sort the processed results - processed_items = [self.process_structure(item, visited) for item in obj] - return sorted(processed_items, key=str) - - # Handle collections (list-like objects) - if isinstance(obj, Collection): - logger.debug( - f"Processing collection of type {type(obj).__name__} with {len(obj)} items" - ) - return [self.process_structure(item, visited) for item in obj] - - # For functions, use the function_content_hash - if callable(obj) and hasattr(obj, "__code__"): - logger.debug(f"Processing function: {getattr(obj, '__name__')}") - if self.function_info_extractor is not None: - # Use the extractor to get a stable representation - function_info = self.function_info_extractor.extract_function_info(obj) - logger.debug( - f"Extracted function info: {function_info} for {obj.__name__}" - ) - - # simply return the function info as a stable representation - return function_info - else: - raise ValueError( - f"Function {obj} encountered during processing but FunctionInfoExtractor is missing" - ) - - # handle data types - if isinstance(obj, type): - logger.debug(f"Processing class/type: {obj.__name__}") - return f"type:{obj.__name__}" - - # For other objects, attempt to create deterministic representation only if force_hash=True - class_name = obj.__class__.__name__ - module_name = obj.__class__.__module__ - if force_hash: - try: - import re - - logger.debug( - f"Processing generic object of type {module_name}.{class_name}" - ) - - # Try to get a stable dict representation if possible - if hasattr(obj, "__dict__"): - # Sort attributes to ensure stable order - attrs = sorted( - (k, v) for k, v in obj.__dict__.items() if not k.startswith("_") - ) - # Limit to first 10 attributes to avoid extremely long representations - if len(attrs) > 10: - logger.debug( - f"Object has {len(attrs)} attributes, limiting to first 10" - ) - attrs = attrs[:10] - attr_strs = [f"{k}={type(v).__name__}" for k, v in attrs] - obj_repr = f"{{{', '.join(attr_strs)}}}" - else: - # Get basic repr but remove memory addresses - logger.debug( - "Object has no __dict__, using repr() with memory address removal" - ) - obj_repr = repr(obj) - if len(obj_repr) > 1000: - logger.debug( - f"Object repr is {len(obj_repr)} chars, truncating to 1000" - ) - obj_repr = obj_repr[:1000] + "..." - # Remove memory addresses which look like '0x7f9a1c2b3d4e' - obj_repr = re.sub(r" at 0x[0-9a-f]+", " at 0xMEMADDR", obj_repr) - - return f"{module_name}.{class_name}:{obj_repr}" - except Exception as e: - # Last resort - use class name only - logger.warning(f"Failed to process object representation: {e}") - try: - return f"object:{obj.__class__.__module__}.{obj.__class__.__name__}" - except AttributeError: - logger.error( - "Could not determine object class, using UnknownObject" - ) - return "UnknownObject" - else: - raise ValueError( - f"Processing of {obj} of type {module_name}.{class_name} is not supported" - ) - - def _hash_object( - self, - obj: Any, - visited: set[int] | None = None, - ) -> hp.ContentHash: - # Process the object to handle nested structures and HashableMixin instances - processed = self.process_structure(obj, visited=visited) - - # Serialize the processed structure - json_str = json.dumps(processed, sort_keys=True, separators=(",", ":")).encode( - "utf-8" - ) - logger.debug( - f"Successfully serialized {type(obj).__name__} using custom serializer" - ) - - # Create the hash - return hp.ContentHash(self.hasher_id, hashlib.sha256(json_str).digest()) - - def hash_object(self, obj: object) -> hp.ContentHash: - return self._hash_object(obj) diff --git a/src/orcapod/hashing/semantic_hashing/__init__.py b/src/orcapod/hashing/semantic_hashing/__init__.py new file mode 100644 index 00000000..bc120c18 --- /dev/null +++ b/src/orcapod/hashing/semantic_hashing/__init__.py @@ -0,0 +1,66 @@ +""" +orcapod.hashing.semantic_hashing +================================= +Sub-package containing all components of the semantic hashing system: + + BaseSemanticHasher -- content-based recursive object hasher + TypeHandlerRegistry -- MRO-aware registry mapping types → TypeHandlerProtocol + BuiltinTypeHandlerRegistry -- pre-populated registry with built-in handlers + ContentIdentifiableMixin -- convenience mixin for content-identifiable objects + +Built-in TypeHandlerProtocol implementations: + PathContentHandler -- pathlib.Path → file-content hash + UUIDHandler -- uuid.UUID → canonical string + BytesHandler -- bytes/bytearray → hex string + FunctionHandler -- callable → via FunctionInfoExtractorProtocol + TypeObjectHandler -- type objects → "type:." + register_builtin_handlers -- populate a registry with all of the above + +Function info extractors (used by FunctionHandler): + FunctionNameExtractor + FunctionSignatureExtractor + FunctionInfoExtractorFactory +""" + +from orcapod.hashing.semantic_hashing.builtin_handlers import ( + BytesHandler, + FunctionHandler, + PathContentHandler, + TypeObjectHandler, + UUIDHandler, + register_builtin_handlers, +) +from orcapod.hashing.semantic_hashing.content_identifiable_mixin import ( + ContentIdentifiableMixin, +) +from orcapod.hashing.semantic_hashing.function_info_extractors import ( + FunctionInfoExtractorFactory, + FunctionNameExtractor, + FunctionSignatureExtractor, +) +from orcapod.hashing.semantic_hashing.semantic_hasher import BaseSemanticHasher +from orcapod.hashing.semantic_hashing.type_handler_registry import ( + BuiltinTypeHandlerRegistry, + TypeHandlerRegistry, +) + +__all__ = [ + # Core hasher + "BaseSemanticHasher", + # Registry + "TypeHandlerRegistry", + "BuiltinTypeHandlerRegistry", + # Mixin + "ContentIdentifiableMixin", + # Built-in handlers + "PathContentHandler", + "UUIDHandler", + "BytesHandler", + "FunctionHandler", + "TypeObjectHandler", + "register_builtin_handlers", + # Function info extractors + "FunctionNameExtractor", + "FunctionSignatureExtractor", + "FunctionInfoExtractorFactory", +] diff --git a/src/orcapod/hashing/semantic_hashing/builtin_handlers.py b/src/orcapod/hashing/semantic_hashing/builtin_handlers.py new file mode 100644 index 00000000..931d7cc5 --- /dev/null +++ b/src/orcapod/hashing/semantic_hashing/builtin_handlers.py @@ -0,0 +1,392 @@ +""" +Built-in TypeHandlerProtocol implementations for the SemanticHasherProtocol system. + +This module provides handlers for all Python types that the SemanticHasherProtocol +knows how to process out of the box: + + - PathContentHandler -- pathlib.Path: returns ContentHash of file content + - UUIDHandler -- uuid.UUID: canonical string representation + - BytesHandler -- bytes / bytearray: hex string representation + - FunctionHandler -- callable with __code__: via FunctionInfoExtractorProtocol + - TypeObjectHandler -- type objects (classes): stable "type:" string + +Note: ContentHash requires no handler -- it is recognised as a terminal by +``hash_object`` and returned as-is. + +The module also exposes ``register_builtin_handlers(registry)`` which is +called automatically when the global default registry is first accessed. + +Extending the system +-------------------- +To add a handler for a third-party type, create a class that implements the +TypeHandlerProtocol protocol (a single ``handle(obj, hasher)`` method) and register +it: + + from orcapod.hashing.semantic_hashing.type_handler_registry import get_default_type_handler_registry + get_default_type_handler_registry().register(MyType, MyTypeHandler()) +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING, Any +from uuid import UUID + +from orcapod.types import PathLike, Schema + +if TYPE_CHECKING: + from orcapod.hashing.semantic_hashing.type_handler_registry import ( + TypeHandlerRegistry, + ) + from orcapod.protocols.hashing_protocols import ( + ArrowHasherProtocol, + FileContentHasherProtocol, + SemanticHasherProtocol, + ) + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Individual handlers +# --------------------------------------------------------------------------- + + +class PathContentHandler: + """ + Handler for pathlib.Path objects. + + Hashes the *content* of the file at the given path using the injected + FileContentHasherProtocol, producing a stable content-addressed identifier. + The resulting bytes are stored as a hex string embedded in the resolved + structure. + + The path must refer to an existing, readable file. Directories and + missing paths are not supported and will raise an error -- if you need + a path-as-string handler, register a separate handler for that use case + or return a ``str`` from ``identity_structure()`` instead of a ``Path``. + + Args: + file_hasher: Any object with a ``hash_file(path) -> ContentHash`` + method (satisfies the FileContentHasherProtocol protocol). + """ + + def __init__(self, file_hasher: FileContentHasherProtocol) -> None: + self.file_hasher = file_hasher + + def handle(self, obj: PathLike, hasher: "SemanticHasherProtocol") -> Any: + path: Path = Path(obj) + + if not path.exists(): + raise FileNotFoundError( + f"PathContentHandler: path does not exist: {path!r}. " + "Paths must refer to existing files for content-based hashing. " + "If you intended to hash the path string, return str(path) from " + "identity_structure() instead of a Path object." + ) + + if path.is_dir(): + raise IsADirectoryError( + f"PathContentHandler: path is a directory: {path!r}. " + "Only regular files are supported for content-based hashing." + ) + + logger.debug("PathContentHandler: hashing file content at %s", path) + return self.file_hasher.hash_file(path) + + +class UUIDHandler: + """ + Handler for uuid.UUID objects. + + Converts the UUID to its canonical hyphenated string representation + (e.g. ``"550e8400-e29b-41d4-a716-446655440000"``), which is stable, + human-readable, and unambiguous. + """ + + def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> Any: + return str(obj) + + +class BytesHandler: + """ + Handler for bytes and bytearray objects. + + Converts binary data to its lowercase hex string representation. This + avoids JSON serialisation issues with raw bytes while preserving the + exact byte sequence in the hash input. + """ + + def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> Any: + if isinstance(obj, (bytes, bytearray)): + return obj.hex() + raise TypeError(f"BytesHandler: expected bytes or bytearray, got {type(obj)!r}") + + +class FunctionHandler: + """ + Handler for Python functions / callables that carry a ``__code__`` attribute. + + Delegates to a FunctionInfoExtractorProtocol to produce a stable, serialisable + dict representation of the function. The extractor is responsible for + deciding which parts of the function (name, signature, source body, etc.) + are included. + + Args: + function_info_extractor: Any object with an + ``extract_function_info(func) -> dict`` method (satisfies the + FunctionInfoExtractorProtocol protocol). + """ + + def __init__(self, function_info_extractor: Any) -> None: + self.function_info_extractor = function_info_extractor + + def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> Any: + if not (callable(obj) and hasattr(obj, "__code__")): + raise TypeError( + f"FunctionHandler: expected a callable with __code__, got {type(obj)!r}" + ) + func_name = getattr(obj, "__name__", repr(obj)) + logger.debug("FunctionHandler: extracting info for function %r", func_name) + info: dict[str, Any] = self.function_info_extractor.extract_function_info(obj) + return info + + +class TypeObjectHandler: + """ + Handler for type objects (i.e. classes passed as values). + + Returns a stable string of the form ``"type:."`` so + that different classes always produce different hash inputs and the + result is human-readable. + """ + + def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> Any: + if not isinstance(obj, type): + raise TypeError( + f"TypeObjectHandler: expected a type/class, got {type(obj)!r}" + ) + module: str = obj.__module__ or "" + qualname: str = obj.__qualname__ + return f"type:{module}.{qualname}" + + +class SpecialFormHandler: + """ + Handler for ``typing._SpecialForm`` objects such as ``typing.Union`` and + ``typing.ClassVar``. + + These appear as the ``__origin__`` of typing generics — for example, + ``Optional[int]`` is ``Union[int, None]``, whose ``__origin__`` is + ``typing.Union``. Returns a stable string of the form + ``"special_form:typing."`` so they can be safely embedded as the + origin component inside a ``GenericAliasHandler`` result. + """ + + def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> Any: + name = getattr(obj, "_name", None) or repr(obj) + return f"special_form:typing.{name}" + + +class GenericAliasHandler: + """ + Handler for generic alias type annotations such as ``dict[int, list[int]]`` + (``types.GenericAlias``) and ``typing`` generics (``typing._GenericAlias``). + + Produces a stable dict containing the origin type and a list of hashed + argument types so that structurally identical generic annotations always + yield the same hash, and structurally different ones yield different hashes. + """ + + def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> Any: + origin = getattr(obj, "__origin__", None) + args = getattr(obj, "__args__", None) or () + if origin is None: + return f"generic_alias:{obj!r}" + return { + "__type__": "generic_alias", + "origin": hasher.hash_object(origin).to_string(), + "args": [hasher.hash_object(arg).to_string() for arg in args], + } + + +class ArrowTableHandler: + """ + Handler for ``pa.Table`` and ``pa.RecordBatch`` objects. + + Delegates to the injected ``ArrowHasherProtocol`` to produce a stable, + content-addressed ``ContentHash`` of the Arrow table data. The returned + ``ContentHash`` is recognised as a terminal by ``hash_object`` and + returned as-is — no further recursion occurs. + + Args: + arrow_hasher: Any object satisfying ArrowHasherProtocol (i.e. has a + ``hash_table(table) -> ContentHash`` method). + """ + + def __init__(self, arrow_hasher: ArrowHasherProtocol) -> None: + self.arrow_hasher = arrow_hasher + + def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> Any: + import pyarrow as _pa + + if isinstance(obj, _pa.RecordBatch): + obj = _pa.Table.from_batches([obj]) + if not isinstance(obj, _pa.Table): + raise TypeError( + f"ArrowTableHandler: expected pa.Table or pa.RecordBatch, got {type(obj)!r}" + ) + return self.arrow_hasher.hash_table(obj) + + +class SchemaHandler: + """ + Handler for :class:`~orcapod.types.Schema` objects. + + Produces a stable dict containing both the field-type mapping and the + sorted list of optional field names, so that two schemas differing only + in which fields are optional produce different hashes. + """ + + def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> Any: + if not isinstance(obj, Schema): + raise TypeError(f"SchemaHandler: expected a Schema, got {type(obj)!r}") + # schema handler is not implemented yet + raise NotImplementedError() + # visited: frozenset[int] = frozenset() + + # return { + # "fields": {k: hasher._expand_element(v, visited) for k, v in obj.items()}, + # "optional_fields": sorted(obj.optional_fields), + # } + + +# --------------------------------------------------------------------------- +# Registration helper +# --------------------------------------------------------------------------- + + +def register_builtin_handlers( + registry: "TypeHandlerRegistry", + file_hasher: Any = None, + function_info_extractor: Any = None, + arrow_hasher: "ArrowHasherProtocol | None" = None, +) -> None: + """ + Register all built-in TypeHandlers into *registry*. + + This function is called automatically when the global default registry is + first accessed via ``get_default_type_handler_registry()``. It can also + be called manually to populate a custom registry. + + Path, function, and Arrow table handling require auxiliary objects. + When these are not supplied, sensible defaults are constructed: + + - ``BasicFileHasher`` (SHA-256, 64 KiB buffer) for Path handling. + - ``FunctionSignatureExtractor`` for function handling. + - ``SemanticArrowHasher`` (SHA-256, logical serialisation) for Arrow table handling. + + Args: + registry: + The TypeHandlerRegistry to populate. + file_hasher: + Optional object satisfying FileContentHasherProtocol (i.e. has a + ``hash_file(path) -> ContentHash`` method). Defaults to a + ``BasicFileHasher`` configured with SHA-256. + function_info_extractor: + Optional object satisfying FunctionInfoExtractorProtocol (i.e. has an + ``extract_function_info(func) -> dict`` method). Defaults to + ``FunctionSignatureExtractor``. + arrow_hasher: + Optional object satisfying ArrowHasherProtocol (i.e. has a + ``hash_table(table) -> ContentHash`` method). Defaults to a + ``SemanticArrowHasher`` configured with SHA-256 and logical serialisation. + Should be the data context's arrow hasher when called from a versioned + context so that hashing is consistent across all components. + """ + # Resolve defaults for auxiliary objects ---------------------------- + if file_hasher is None: + from orcapod.hashing.file_hashers import BasicFileHasher # stays in hashing/ + + file_hasher = BasicFileHasher(algorithm="sha256") + + if function_info_extractor is None: + from orcapod.hashing.semantic_hashing.function_info_extractors import ( + FunctionSignatureExtractor, + ) + + function_info_extractor = FunctionSignatureExtractor( + include_module=True, + include_defaults=True, + ) + + if arrow_hasher is None: + from orcapod.hashing.arrow_hashers import SemanticArrowHasher + from orcapod.semantic_types.semantic_registry import SemanticTypeRegistry + + arrow_hasher = SemanticArrowHasher( + semantic_registry=SemanticTypeRegistry(), + hasher_id="arrow_v0.1", + hash_algorithm="sha256", + serialization_method="logical", + ) + + # Register handlers ------------------------------------------------- + + # bytes / bytearray + bytes_handler = BytesHandler() + registry.register(bytes, bytes_handler) + registry.register(bytearray, bytes_handler) + + # pathlib.Path (and subclasses such as PosixPath / WindowsPath) + registry.register(Path, PathContentHandler(file_hasher)) + + # uuid.UUID + registry.register(UUID, UUIDHandler()) + + # Note: ContentHash needs no handler -- SemanticHasherProtocol treats it as + # a terminal in hash_object() and returns it as-is. + + # Functions -- register types.FunctionType so MRO lookup works for + # plain ``def`` functions, plus built-in functions and bound methods. + import types as _types + + function_handler = FunctionHandler(function_info_extractor) + registry.register(_types.FunctionType, function_handler) + registry.register(_types.BuiltinFunctionType, function_handler) + registry.register(_types.MethodType, function_handler) + + # type objects (classes used as values, e.g. passed in a dict) + registry.register(type, TypeObjectHandler()) + + # generic alias type annotations: dict[int, str], list[str], etc. + generic_alias_handler = GenericAliasHandler() + registry.register(_types.GenericAlias, generic_alias_handler) + # typing._GenericAlias covers Optional[X], Union[X, Y], Dict[K, V], etc. + # typing._SpecialForm covers typing.Union, typing.ClassVar, etc. which + # appear as __origin__ on those generics (e.g. Optional[int].__origin__ + # is typing.Union, a _SpecialForm). + try: + import typing as _typing + + registry.register(_typing._GenericAlias, generic_alias_handler) # type: ignore[attr-defined] + registry.register(_typing._SpecialForm, SpecialFormHandler()) # type: ignore[attr-defined] + except AttributeError: + pass + + # Schema objects -- must come after type handler so Schema is matched + # specifically rather than falling through to the Mapping expansion path + registry.register(Schema, SchemaHandler()) + + # Arrow tables and record batches -- delegate to the injected arrow hasher + import pyarrow as _pa + + arrow_table_handler = ArrowTableHandler(arrow_hasher) + registry.register(_pa.Table, arrow_table_handler) + registry.register(_pa.RecordBatch, arrow_table_handler) + + logger.debug( + "register_builtin_handlers: registered %d built-in handlers", + len(registry), + ) diff --git a/src/orcapod/hashing/semantic_hashing/content_identifiable_mixin.py b/src/orcapod/hashing/semantic_hashing/content_identifiable_mixin.py new file mode 100644 index 00000000..f4bd04ce --- /dev/null +++ b/src/orcapod/hashing/semantic_hashing/content_identifiable_mixin.py @@ -0,0 +1,259 @@ +""" +ContentIdentifiableMixin -- convenience base class for content-identifiable objects. + +Any class that implements ``identity_structure()`` can inherit from this mixin +to gain a full suite of content-based identity helpers without having to wire +up a BaseSemanticHasher manually: + + - ``content_hash()`` -- returns a stable ContentHash for the object + - ``__hash__()`` -- Python hash based on content (int) + - ``__eq__()`` -- equality via content_hash comparison + +The mixin uses the global default BaseSemanticHasher by default, but accepts an +injected hasher for testing or custom configurations. + +Usage +----- +Simple usage with the global default hasher:: + + class MyRecord(ContentIdentifiableMixin): + def __init__(self, name: str, value: int) -> None: + self.name = name + self.value = value + + def identity_structure(self): + return {"name": self.name, "value": self.value} + + r1 = MyRecord("foo", 42) + r2 = MyRecord("foo", 42) + assert r1 == r2 + assert hash(r1) == hash(r2) + print(r1.content_hash()) # ContentHash(method='object_v0.1', digest=...) + +With an injected hasher (e.g. in tests):: + + hasher = BaseSemanticHasher(hasher_id="test", strict=True) + record = MyRecord("foo", 42) + record._semantic_hasher = hasher + print(record.content_hash()) + +Design notes +------------ +- The mixin stores a lazily-computed ``_cached_content_hash`` to avoid + recomputing the hash on every call. The cache is invalidated by calling + ``_invalidate_content_hash_cache()``, which subclasses should call whenever + a mutation changes the semantic content of the object. + +- ``__eq__`` compares ContentHash objects (not identity structures directly) + for efficiency: if two objects have the same hash they are considered equal. + This is a deliberate trade-off -- hash collisions are astronomically rare + for SHA-256. + +- The mixin deliberately does *not* inherit from ABC or impose any abstract + method requirements. ``identity_structure()`` is expected to be present on + the concrete class; if it is missing a clear AttributeError will surface at + call time. + +- When used alongside other base classes in a multiple-inheritance chain, + ensure that ``ContentIdentifiableMixin.__init__`` is cooperative (it calls + ``super().__init__(**kwargs)``). Pass ``semantic_hasher=`` as a keyword + argument if needed. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from orcapod.hashing.semantic_hashing.semantic_hasher import BaseSemanticHasher +from orcapod.types import ContentHash + +logger = logging.getLogger(__name__) + + +class ContentIdentifiableMixin: + """ + Mixin that provides content-based identity to any class implementing + ``identity_structure()``. + + Subclasses must implement:: + + def identity_structure(self) -> Any: + ... + + The returned structure is recursively resolved and hashed by the + BaseSemanticHasher to produce a stable ContentHash. + + Parameters (passed as keyword arguments to ``__init__``) + --------------------------------------------------------- + semantic_hasher: + Optional BaseSemanticHasher instance to use. When omitted, the hasher + is obtained from the default data context via + ``orcapod.contexts.get_default_context().semantic_hasher``, which is + the single source of truth for versioned component configuration. + """ + + def __init__( + self, *, semantic_hasher: BaseSemanticHasher | None = None, **kwargs: Any + ) -> None: + # Cooperative MRO-friendly init -- forward remaining kwargs up the chain. + super().__init__(**kwargs) + # Store injected hasher (may be None; resolved lazily on first use). + self._semantic_hasher = semantic_hasher + # Content hash cache keyed by hasher_id. + self._content_hash_cache: dict[str, ContentHash] = {} + + # ------------------------------------------------------------------ + # Core content-hash API + # ------------------------------------------------------------------ + + def content_hash(self, hasher=None) -> ContentHash: + """ + Return a stable ContentHash based on the object's semantic content. + + The hasher is used for the entire recursive computation — all nested + ContentIdentifiable objects are resolved using the same hasher. + Results are cached by hasher_id so repeated calls with the same + hasher are free. + + Args: + hasher: Optional semantic hasher to use. When omitted, resolved + via _get_hasher() (injected hasher or default context). + + Returns: + ContentHash: Deterministic, content-based hash of this object. + """ + if hasher is None: + hasher = self._get_hasher() + cache_key = hasher.hasher_id + if cache_key not in self._content_hash_cache: + structure = self.identity_structure() # type: ignore[attr-defined] + logger.debug( + "ContentIdentifiableMixin.content_hash: computing hash for %s", + type(self).__name__, + ) + resolver = lambda obj: obj.content_hash(hasher) + self._content_hash_cache[cache_key] = hasher.hash_object( + structure, resolver=resolver + ) + return self._content_hash_cache[cache_key] + + def identity_structure(self) -> Any: + """ + Return a structure representing the semantic identity of this object. + + Subclasses MUST override this method. The default implementation + raises NotImplementedError to make the missing override visible + immediately rather than silently producing a meaningless hash. + + Returns: + Any: A deterministic Python structure whose content fully captures + the semantic identity of this object. + + Raises: + NotImplementedError: Always, unless overridden by a subclass. + """ + raise NotImplementedError( + f"{type(self).__name__} must implement identity_structure() to use " + "ContentIdentifiableMixin. Override this method and return a " + "deterministic Python structure representing the object's semantic " + "content." + ) + + # ------------------------------------------------------------------ + # Python data model integration + # ------------------------------------------------------------------ + + def __hash__(self) -> int: + """ + Return a Python integer hash derived from the content hash. + + Uses the first 16 hex characters (64 bits) of the SHA-256 digest + converted to an integer. This provides a good distribution while + fitting within Python's hash range on all platforms. + + Returns: + int: A stable, content-based hash integer. + """ + return self.content_hash().to_int(hexdigits=16) + + def __eq__(self, other: object) -> bool: + """ + Compare this object to *other* based on content hash equality. + + Two ContentIdentifiableProtocol objects are considered equal if and only if + their content hashes are identical. Objects of a different type that + do not inherit ContentIdentifiableMixin are never equal to a mixin + instance (returns NotImplemented to allow the other object to decide). + + Args: + other: The object to compare against. + + Returns: + bool: True if both objects have the same content hash. + NotImplemented: If *other* does not implement content_hash(). + """ + if not isinstance(other, ContentIdentifiableMixin): + return NotImplemented + return self.content_hash() == other.content_hash() + + # ------------------------------------------------------------------ + # Cache management + # ------------------------------------------------------------------ + + def _invalidate_content_hash_cache(self) -> None: + """ + Invalidate the cached content hash. + + Call this after any mutation that changes the object's semantic + content so that the next call to ``content_hash()`` recomputes from + scratch. + """ + self._content_hash_cache.clear() + + # ------------------------------------------------------------------ + # Hasher resolution + # ------------------------------------------------------------------ + + def _get_hasher(self) -> BaseSemanticHasher: + """ + Return the BaseSemanticHasher to use for this object. + + Resolution order: + 1. The instance-level ``_semantic_hasher`` attribute (set at + construction or injected directly). + 2. The semantic hasher from the default data context, obtained via + ``orcapod.contexts.get_default_context().semantic_hasher``. + The data context is the single source of truth for versioned + component configuration; going through it ensures that the + hasher is consistent with all other components (arrow hasher, + type converter, etc.) that belong to the same context. + + Returns: + BaseSemanticHasher: The hasher to use. + """ + if self._semantic_hasher is not None: + return self._semantic_hasher + + # Late import to avoid circular dependencies: contexts imports from + # protocols and hashing, so we must not import it at module level. + from orcapod.contexts import get_default_context + + return get_default_context().semantic_hasher # type: ignore[return-value] + + # ------------------------------------------------------------------ + # Repr helper + # ------------------------------------------------------------------ + + def __repr__(self) -> str: + """ + Return a human-readable representation including the short content hash. + + Uses only 8 hex characters to keep the repr concise. Subclasses are + encouraged to override this if they need a more informative repr. + """ + try: + short_hash = self.content_hash().to_hex(char_count=8) + except Exception: + short_hash = "" + return f"{type(self).__name__}(content_hash={short_hash!r})" diff --git a/src/orcapod/hashing/function_info_extractors.py b/src/orcapod/hashing/semantic_hashing/function_info_extractors.py similarity index 90% rename from src/orcapod/hashing/function_info_extractors.py rename to src/orcapod/hashing/semantic_hashing/function_info_extractors.py index 0b5d4488..6191a684 100644 --- a/src/orcapod/hashing/function_info_extractors.py +++ b/src/orcapod/hashing/semantic_hashing/function_info_extractors.py @@ -1,8 +1,9 @@ -from orcapod.protocols.hashing_protocols import FunctionInfoExtractor +import inspect from collections.abc import Callable from typing import Any, Literal -from orcapod.types import PythonSchema -import inspect + +from orcapod.protocols.hashing_protocols import FunctionInfoExtractorProtocol +from orcapod.types import Schema class FunctionNameExtractor: @@ -14,8 +15,8 @@ def extract_function_info( self, func: Callable[..., Any], function_name: str | None = None, - input_typespec: PythonSchema | None = None, - output_typespec: PythonSchema | None = None, + input_typespec: Schema | None = None, + output_typespec: Schema | None = None, ) -> dict[str, Any]: if not callable(func): raise TypeError("Provided object is not callable") @@ -38,8 +39,8 @@ def extract_function_info( self, func: Callable[..., Any], function_name: str | None = None, - input_typespec: PythonSchema | None = None, - output_typespec: PythonSchema | None = None, + input_typespec: Schema | None = None, + output_typespec: Schema | None = None, ) -> dict[str, Any]: if not callable(func): raise TypeError("Provided object is not callable") @@ -80,7 +81,7 @@ class FunctionInfoExtractorFactory: @staticmethod def create_function_info_extractor( strategy: Literal["name", "signature"] = "signature", - ) -> FunctionInfoExtractor: + ) -> FunctionInfoExtractorProtocol: """Create a basic composite extractor.""" if strategy == "name": return FunctionNameExtractor() diff --git a/src/orcapod/hashing/semantic_hashing/semantic_hasher.py b/src/orcapod/hashing/semantic_hashing/semantic_hasher.py new file mode 100644 index 00000000..ceb13315 --- /dev/null +++ b/src/orcapod/hashing/semantic_hashing/semantic_hasher.py @@ -0,0 +1,427 @@ +""" +BaseSemanticHasher -- content-based recursive object hasher. + +Algorithm +--------- +``hash_object(obj)`` is the single public entry point. It is mutually +recursive with ``_expand_structure``: + +``hash_object(obj)`` + Produces a ContentHash for *any* Python object. + + - ContentHash → terminal; returned as-is (already a hash) + - Primitive → JSON-serialise + SHA-256 + - Structure → delegate to ``_expand_structure``, then + JSON-serialise the resulting tagged tree + SHA-256 + - Handler match → call handler.handle(obj), recurse via hash_object + - ContentIdentifiableProtocol→ call identity_structure(), recurse via hash_object + - Fallback → strict error or best-effort string, then hash + +``_expand_structure(obj)`` + Structural expansion only -- called exclusively for container types + (list, tuple, dict, set, frozenset, namedtuple). Returns a + JSON-serialisable value where: + + - Primitive elements → passed through as-is (become leaves in the tree) + - Nested structures → recurse via ``_expand_structure`` + - Everything else → call ``hash_object``, embed the resulting + ContentHash.to_string() as a string leaf + +The boundary between the two functions encodes a key semantic distinction: +a ContentIdentifiableProtocol object X whose identity_structure returns [A, B] +embedded inside [X, C] contributes only its hash token to the parent -- +it is NOT the same as [[A, B], C]. The parent's structure is opaque to +the expansion that produced X's hash. + +Container type serialisation +---------------------------- +Native JSON container types (list and dict) are kept in their natural JSON +form. Python-only container types that have no unambiguous JSON equivalent +are wrapped in a ``{"__type__": ..., ...}`` tagged object so that +structurally similar but type-distinct containers produce different hashes: + + list → [...] # native JSON array + dict → {...} # native JSON object; keys sorted + tuple → {"__type__": "tuple", "items": [...]} + set → {"__type__": "set", "items": [...]} # items sorted by str() + frozenset → {"__type__": "set", "items": [...]} # same tag as set + namedtuple → {"__type__": "namedtuple","name": "T", + "fields": {...}} # sorted by field name + +This means a ``list`` and a ``tuple`` with the same elements will hash +differently (the tuple carries a type tag), while a plain ``list`` and a +plain JSON array embedded anywhere in a structure are indistinguishable -- +which is exactly the desired semantics for interoperability. + +Circular-reference detection +----------------------------- +Container ids are tracked in a ``_visited`` frozenset threaded through +``_expand_structure``. When an already-visited id is encountered the +sentinel string ``"CircularRef"`` is embedded as the leaf value. +""" + +from __future__ import annotations + +import hashlib +import json +import logging +import re +from collections.abc import Callable, Mapping +from typing import Any + +from orcapod.hashing.semantic_hashing.type_handler_registry import TypeHandlerRegistry +from orcapod.protocols import hashing_protocols as hp +from orcapod.types import ContentHash + +logger = logging.getLogger(__name__) + +_CIRCULAR_REF_SENTINEL = "CircularRef" +_MEMADDR_RE = re.compile(r" at 0x[0-9a-fA-F]+") + + +class BaseSemanticHasher: + """ + Content-based recursive hasher. + + Parameters + ---------- + hasher_id: + A short string identifying this hasher version/configuration. + Embedded in every ContentHash produced. + type_handler_registry: + TypeHandlerRegistry for MRO-aware lookup of TypeHandlerProtocol instances. + If None, the default registry from the active DataContext is used. + strict: + When True (default) raises TypeError for unhandled types. + When False falls back to a best-effort string representation. + """ + + def __init__( + self, + hasher_id: str, + type_handler_registry: TypeHandlerRegistry | None = None, + strict: bool = True, + ) -> None: + self._hasher_id = hasher_id + self._strict = strict + + if type_handler_registry is None: + from orcapod.hashing.defaults import get_default_type_handler_registry + + self._registry = get_default_type_handler_registry() # stays in hashing/ + else: + self._registry = type_handler_registry + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + @property + def hasher_id(self) -> str: + return self._hasher_id + + @property + def strict(self) -> bool: + return self._strict + + def hash_object( + self, + obj: Any, + resolver: Callable[[Any], ContentHash] | None = None, + ) -> ContentHash: + """ + Hash *obj* based on its semantic content. + + This is the single recursive entry point for the hashing system. + Returns a ContentHash for any Python object. + + - ContentHash → terminal; returned as-is + - Primitive → JSON-serialised and hashed directly + - Structure → structurally expanded then hashed + - Handler match → handler produces a value, recurse + - ContentIdentifiableProtocol→ resolver(obj) if resolver provided, else obj.content_hash() + - Unknown type → TypeError in strict mode; best-effort otherwise + + Args: + obj: The object to hash. + resolver: Optional callable invoked for any ContentIdentifiableProtocol + object encountered during hashing. When provided it overrides the + default ``obj.content_hash()`` call, allowing the caller to control + which identity chain is used (e.g. pipeline_hash vs content_hash) + and to propagate a consistent semantic hasher through the full + recursive computation. + + Returns: + ContentHash: Stable, content-based hash of the object. + """ + # Terminal: already a hash -- return as-is. + if isinstance(obj, ContentHash): + return obj + + # Primitives: hash their direct JSON representation. + if isinstance(obj, (type(None), bool, int, float, str)): + return self._hash_to_content_hash(obj) + + # Structures: expand into a tagged tree, then hash the tree. + if _is_structure(obj): + expanded = self._expand_structure( + obj, _visited=frozenset(), resolver=resolver + ) + return self._hash_to_content_hash(expanded) + + # Handler dispatch: the handler produces a new value; recurse. + handler = self._registry.get_handler(obj) + if handler is not None: + logger.debug( + "hash_object: dispatching %s to handler %s", + type(obj).__name__, + type(handler).__name__, + ) + return self.hash_object(handler.handle(obj, self), resolver=resolver) + + # ContentIdentifiableProtocol: use resolver if provided, else content_hash(). + if isinstance(obj, hp.ContentIdentifiableProtocol): + if resolver is not None: + logger.debug( + "hash_object: resolving ContentIdentifiableProtocol %s via resolver", + type(obj).__name__, + ) + return resolver(obj) + else: + logger.debug( + "hash_object: using ContentIdentifiableProtocol %s's content_hash", + type(obj).__name__, + ) + return obj.content_hash() + + # Fallback for unhandled types. + fallback = self._handle_unknown(obj) + return self._hash_to_content_hash(fallback) + + # ------------------------------------------------------------------ + # Private: structural expansion + # ------------------------------------------------------------------ + + def _expand_structure( + self, + obj: Any, + _visited: frozenset[int], + resolver: Callable[[Any], ContentHash] | None = None, + ) -> Any: + """ + Expand a container object into a JSON-serialisable tagged tree. + + Only called for structural types (list, tuple, dict, set, frozenset, + namedtuple). Within nested structures this function recurses into + itself for container elements and calls ``hash_object`` for all + non-container, non-primitive elements, embedding the resulting + ContentHash.to_string() as a string leaf. + + Primitives are passed through as-is. + + Args: + obj: The object to expand. Must be a structure or primitive. + _visited: Set of container ids already on the current traversal + path, for circular-reference detection. + + Returns: + A JSON-serialisable dict (with ``__type__`` tag) for containers, + or the primitive value itself. + """ + # Primitives are leaves -- pass through. + if isinstance(obj, (type(None), bool, int, float, str)): + return obj + + # ContentHash is a terminal leaf -- embed as its string form. + if isinstance(obj, ContentHash): + return obj.to_string() + + # Circular-reference guard for containers. + obj_id = id(obj) + if obj_id in _visited: + logger.debug( + "_expand_structure: circular reference detected for %s", + type(obj).__name__, + ) + return _CIRCULAR_REF_SENTINEL + _visited = _visited | {obj_id} + + if _is_namedtuple(obj): + return self._expand_namedtuple(obj, _visited, resolver=resolver) + + if isinstance(obj, (dict, Mapping)): + return self._expand_mapping(obj, _visited, resolver=resolver) + + if isinstance(obj, list): + return [ + self._expand_element(item, _visited, resolver=resolver) for item in obj + ] + + if isinstance(obj, tuple): + return { + "__type__": "tuple", + "items": [ + self._expand_element(item, _visited, resolver=resolver) + for item in obj + ], + } + + if isinstance(obj, (set, frozenset)): + expanded_items = [ + self._expand_element(item, _visited, resolver=resolver) for item in obj + ] + return { + "__type__": "set", + "items": sorted(expanded_items, key=str), + } + + # Should not be reached if _is_structure() is consistent. + raise TypeError(f"_expand_structure called on non-structure type {type(obj)!r}") + + def _expand_element( + self, + obj: Any, + _visited: frozenset[int], + resolver: Callable[[Any], ContentHash] | None = None, + ) -> Any: + """ + Expand a single element within a structure. + + - Primitives and ContentHash → handled by _expand_structure (leaf) + - Nested structures → recurse via _expand_structure + - Everything else → call hash_object, embed to_string() as leaf + """ + if isinstance(obj, (type(None), bool, int, float, str, ContentHash)): + return self._expand_structure(obj, _visited, resolver=resolver) + + if _is_structure(obj): + return self._expand_structure(obj, _visited, resolver=resolver) + + # Non-structure, non-primitive: hash independently and embed token. + return self.hash_object(obj, resolver=resolver).to_string() + + def _expand_mapping( + self, + obj: Mapping, + _visited: frozenset[int], + resolver: Callable[[Any], ContentHash] | None = None, + ) -> dict: + """Expand a dict/Mapping into a sorted native JSON object.""" + items: dict[str, Any] = {} + for k, v in obj.items(): + str_key = str(self._expand_element(k, _visited, resolver=resolver)) + items[str_key] = self._expand_element(v, _visited, resolver=resolver) + # Sort for determinism regardless of insertion order. + return dict(sorted(items.items())) + + def _expand_namedtuple( + self, + obj: Any, + _visited: frozenset[int], + resolver: Callable[[Any], ContentHash] | None = None, + ) -> dict: + """Expand a namedtuple into a tagged dict preserving field names.""" + fields: tuple[str, ...] = obj._fields + expanded_fields = { + field: self._expand_element( + getattr(obj, field), _visited, resolver=resolver + ) + for field in fields + } + return { + "__type__": "namedtuple", + "name": type(obj).__name__, + "fields": dict(sorted(expanded_fields.items())), + } + + # ------------------------------------------------------------------ + # Private: hashing + # ------------------------------------------------------------------ + + def _hash_to_content_hash(self, obj: Any) -> ContentHash: + """ + JSON-serialise *obj* and compute a SHA-256 ContentHash. + + *obj* must already be a JSON-serialisable primitive or tagged tree + (the output of _expand_structure or a raw primitive). + """ + try: + json_bytes = json.dumps( + obj, + sort_keys=True, + separators=(",", ":"), + ensure_ascii=False, + ).encode("utf-8") + except (TypeError, ValueError) as exc: + raise TypeError( + f"BaseSemanticHasher: failed to JSON-serialise object of type " + f"{type(obj).__name__!r}. Ensure all TypeHandlers and " + "identity_structure() implementations return JSON-serialisable " + "primitives or structures." + ) from exc + + digest = hashlib.sha256(json_bytes).digest() + return ContentHash(self._hasher_id, digest) + + # ------------------------------------------------------------------ + # Private: fallback for unhandled types + # ------------------------------------------------------------------ + + def _handle_unknown(self, obj: Any) -> str: + """ + Produce a best-effort string for an unregistered, non-ContentIdentifiableProtocol + type. Raises TypeError in strict mode. + """ + class_name = type(obj).__name__ + module_name = getattr(type(obj), "__module__", "") + qualified = f"{module_name}.{class_name}" + + if self._strict: + raise TypeError( + f"BaseSemanticHasher (strict): no TypeHandlerProtocol registered for type " + f"'{qualified}' and it does not implement ContentIdentifiableProtocol. " + "Register a TypeHandlerProtocol via the TypeHandlerRegistry or implement " + "identity_structure() on the class." + ) + + logger.warning( + "SemanticHasherProtocol (non-strict): no handler for type '%s'. " + "Falling back to best-effort string representation.", + qualified, + ) + + if hasattr(obj, "__dict__"): + attrs = sorted( + (k, type(v).__name__) + for k, v in obj.__dict__.items() + if not k.startswith("_") + ) + attr_str = ", ".join(f"{k}={t}" for k, t in attrs[:10]) + return f"{qualified}{{{attr_str}}}" + else: + raw = repr(obj) + if len(raw) > 1000: + raw = raw[:1000] + "..." + scrubbed = _MEMADDR_RE.sub(" at 0xMEMADDR", raw) + return f"{qualified}:{scrubbed}" + + +# --------------------------------------------------------------------------- +# Helper predicates +# --------------------------------------------------------------------------- + + +def _is_structure(obj: Any) -> bool: + """Return True if *obj* is a container type handled by _expand_structure.""" + return isinstance(obj, (list, tuple, dict, set, frozenset, Mapping)) + + +def _is_namedtuple(obj: Any) -> bool: + """Return True if *obj* is an instance of a namedtuple class.""" + if not isinstance(obj, tuple): + return False + obj_type = type(obj) + fields = getattr(obj_type, "_fields", None) + if fields is None: + return False + return isinstance(fields, tuple) and all(isinstance(f, str) for f in fields) diff --git a/src/orcapod/hashing/semantic_hashing/type_handler_registry.py b/src/orcapod/hashing/semantic_hashing/type_handler_registry.py new file mode 100644 index 00000000..690ec024 --- /dev/null +++ b/src/orcapod/hashing/semantic_hashing/type_handler_registry.py @@ -0,0 +1,260 @@ +""" +Type Handler Registry for the SemanticHasherProtocol system. + +Provides a registry through which TypeHandlerProtocol implementations can be +registered for specific Python types. Lookup is MRO-aware: if no handler +is registered for an exact type, the registry walks the MRO of the object's +class to find the nearest ancestor for which a handler has been registered. + +Usage +----- +# Register a handler for a specific type: +registry = TypeHandlerRegistry() +registry.register(Path, PathContentHandler()) + +# Or use the global default registry: +from orcapod.hashing.semantic_hashing.type_handler_registry import get_default_type_handler_registry +get_default_type_handler_registry().register(MyType, MyTypeHandler()) + +# Look up a handler (returns None if not found): +handler = registry.get_handler(some_object) +""" + +from __future__ import annotations + +import logging +import threading +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from orcapod.protocols.hashing_protocols import ( + ArrowHasherProtocol, + TypeHandlerProtocol, + ) + +logger = logging.getLogger(__name__) + + +class TypeHandlerRegistry: + """ + Registry mapping Python types to TypeHandlerProtocol instances. + + Lookup is MRO-aware: when no handler is registered for the exact type of + an object, the registry walks the object's MRO (most-derived first) until + it finds a match. This means a handler registered for a base class is + automatically inherited by all subclasses, unless a more specific handler + has been registered for the subclass. + + Thread safety + ------------- + Registration and lookup are protected by a reentrant lock so that the + global singleton can be safely used from multiple threads. + """ + + def __init__( + self, handlers: list[tuple[type, TypeHandlerProtocol]] | None = None + ) -> None: + """ + Args: + handlers: Optional list of ``(target_type, handler)`` pairs to + register at construction time. Designed for use with + ``parse_objectspec``: the JSON spec provides a list of + two-element arrays where the first element uses ``_type`` + to resolve a Python type and the second uses ``_class`` to + instantiate the handler. + """ + # Maps type -> handler; insertion order is preserved but lookup uses MRO. + self._handlers: dict[type, TypeHandlerProtocol] = {} + self._lock = threading.RLock() + if handlers: + for target_type, handler in handlers: + self.register(target_type, handler) + + # ------------------------------------------------------------------ + # Registration + # ------------------------------------------------------------------ + + def register(self, target_type: type, handler: TypeHandlerProtocol) -> None: + """ + Register a handler for a specific Python type. + + If a handler is already registered for *target_type*, it is silently + replaced by the new handler. + + Args: + target_type: The Python type (or class) for which the handler + should be used. Must be a ``type`` object. + handler: A TypeHandlerProtocol instance whose ``handle()`` method will + be called when an object of ``target_type`` (or a + subclass with no more specific handler) is encountered + during structure resolution. + + Raises: + TypeError: If ``target_type`` is not a ``type``. + """ + if not isinstance(target_type, type): + raise TypeError( + f"target_type must be a type/class, got {type(target_type)!r}" + ) + with self._lock: + existing = self._handlers.get(target_type) + if existing is not None and existing is not handler: + logger.debug( + "TypeHandlerRegistry: replacing existing handler for %s (%s -> %s)", + target_type.__name__, + type(existing).__name__, + type(handler).__name__, + ) + self._handlers[target_type] = handler + + def unregister(self, target_type: type) -> bool: + """ + Remove the handler registered for *target_type*, if any. + + Args: + target_type: The type whose handler should be removed. + + Returns: + True if a handler was removed, False if none was registered. + """ + with self._lock: + if target_type in self._handlers: + del self._handlers[target_type] + return True + return False + + # ------------------------------------------------------------------ + # Lookup + # ------------------------------------------------------------------ + + def get_handler(self, obj: Any) -> "TypeHandlerProtocol | None": + """ + Look up the handler for *obj* using MRO-aware resolution. + + The MRO of ``type(obj)`` is walked from most-derived to least-derived + (i.e. the object's own class first, then its bases). The first + match found in the registry is returned. + + Args: + obj: The object for which a handler is needed. + + Returns: + The registered TypeHandlerProtocol, or None if no handler is registered + for the object's type or any of its base classes. + """ + obj_type = type(obj) + with self._lock: + # Fast path: exact type match. + handler = self._handlers.get(obj_type) + if handler is not None: + return handler + + # Slow path: walk the MRO, skipping the type itself (already + # checked above) and skipping ``object`` as a last resort -- a + # handler registered for ``object`` would match everything. + for base in obj_type.__mro__[1:]: + handler = self._handlers.get(base) + if handler is not None: + logger.debug( + "TypeHandlerRegistry: resolved handler for %s via base %s", + obj_type.__name__, + base.__name__, + ) + return handler + + return None + + def get_handler_for_type(self, target_type: type) -> "TypeHandlerProtocol | None": + """ + Look up the handler for a *type object* (rather than an instance). + + Useful when the caller already has the type and wants to check + registration without constructing a dummy instance. + + Args: + target_type: The type to look up. + + Returns: + The registered TypeHandlerProtocol, or None. + """ + with self._lock: + handler = self._handlers.get(target_type) + if handler is not None: + return handler + for base in target_type.__mro__[1:]: + handler = self._handlers.get(base) + if handler is not None: + return handler + return None + + def has_handler(self, target_type: type) -> bool: + """ + Return True if a handler is registered for *target_type* or any of + its MRO ancestors. + + Args: + target_type: The type to check. + """ + return self.get_handler_for_type(target_type) is not None + + def registered_types(self) -> list[type]: + """ + Return a list of all directly-registered types (no MRO expansion). + + Returns: + A snapshot list of types that have explicit handler registrations. + """ + with self._lock: + return list(self._handlers.keys()) + + # ------------------------------------------------------------------ + # Dunder helpers + # ------------------------------------------------------------------ + + def __repr__(self) -> str: + with self._lock: + names = [t.__name__ for t in self._handlers] + return f"TypeHandlerRegistry(registered={names!r})" + + def __len__(self) -> int: + with self._lock: + return len(self._handlers) + + +# --------------------------------------------------------------------------- +# Pre-populated registry +# --------------------------------------------------------------------------- + + +def get_default_type_handler_registry() -> "TypeHandlerRegistry": + """ + Return the TypeHandlerRegistry from the default data context. + + This is a convenience wrapper; the registry is owned and versioned by the + active DataContext. Importing this function from + ``orcapod.hashing.defaults`` or ``orcapod.hashing`` is equivalent. + """ + from orcapod.hashing.defaults import ( + get_default_type_handler_registry as _get, + ) # stays in hashing/ + + return _get() + + +class BuiltinTypeHandlerRegistry(TypeHandlerRegistry): + """ + A TypeHandlerRegistry pre-populated with all built-in handlers. + + Constructed via the data context JSON spec so that the default registry + is versioned alongside the rest of the context components. The built-in + handlers are registered in ``__init__`` so that no separate population + step is required after construction. + """ + + def __init__(self, arrow_hasher: "ArrowHasherProtocol | None" = None) -> None: + super().__init__() + from orcapod.hashing.semantic_hashing.builtin_handlers import ( + register_builtin_handlers, + ) + + register_builtin_handlers(self, arrow_hasher=arrow_hasher) diff --git a/src/orcapod/hashing/semantic_type_hashers.py b/src/orcapod/hashing/semantic_type_hashers.py deleted file mode 100644 index 712d0194..00000000 --- a/src/orcapod/hashing/semantic_type_hashers.py +++ /dev/null @@ -1,100 +0,0 @@ -import hashlib -import os - -import pyarrow as pa - -from orcapod.protocols.hashing_protocols import ( - FileContentHasher, - SemanticTypeHasher, - StringCacher, -) - - -class PathHasher(SemanticTypeHasher): - """Hasher for Path semantic type columns - hashes file contents.""" - - def __init__( - self, - file_hasher: FileContentHasher, - handle_missing: str = "error", - string_cacher: StringCacher | None = None, - cache_key_prefix: str = "path_hasher", - ): - """ - Initialize PathHasher. - - Args: - chunk_size: Size of chunks to read files in bytes - handle_missing: How to handle missing files ('error', 'skip', 'null_hash') - """ - self.file_hasher = file_hasher - self.handle_missing = handle_missing - self.cacher = string_cacher - self.cache_key_prefix = cache_key_prefix - - def _hash_file_content(self, file_path: str) -> bytes: - """Hash the content of a single file""" - import os - - # if cacher exists, check if the hash is cached - if self.cacher: - cache_key = f"{self.cache_key_prefix}:{file_path}" - cached_hash_hex = self.cacher.get_cached(cache_key) - if cached_hash_hex is not None: - return bytes.fromhex(cached_hash_hex) - - try: - if not os.path.exists(file_path): - if self.handle_missing == "error": - raise FileNotFoundError(f"File not found: {file_path}") - elif self.handle_missing == "skip": - return hashlib.sha256(b"").digest() - elif self.handle_missing == "null_hash": - return hashlib.sha256(b"").digest() - - hashed_value = self.file_hasher.hash_file(file_path) - if self.cacher: - # Cache the computed hash hex - self.cacher.set_cached( - f"{self.cache_key_prefix}:{file_path}", hashed_value.to_hex() - ) - # TODO: make consistent use of bytes/string for hash - return hashed_value.digest - - except (IOError, OSError, PermissionError) as e: - if self.handle_missing == "error": - raise IOError(f"Cannot read file {file_path}: {e}") - else: # skip or null_hash - error_msg = f"" - return hashlib.sha256(error_msg.encode("utf-8")).digest() - - def hash_column(self, column: pa.Array) -> pa.Array: - """ - Replace path column with file content hashes. - Returns a new array where each path is replaced with its file content hash. - """ - - # Convert to python list for processing - paths = column.to_pylist() - - # Hash each file's content individually - content_hashes = [] - for path in paths: - if path is not None: - # Normalize path for consistency - normalized_path = os.path.normpath(str(path)) - file_content_hash = self._hash_file_content(normalized_path) - content_hashes.append(file_content_hash) - else: - content_hashes.append(None) # Preserve nulls - - # Return new array with content hashes instead of paths - return pa.array(content_hashes) - - def set_cacher(self, cacher: StringCacher) -> None: - """ - Add a string cacher for caching hash values. - This is a no-op for PathHasher since it hashes file contents directly. - """ - # PathHasher does not use string caching, so this is a no-op - self.cacher = cacher diff --git a/src/orcapod/hashing/string_cachers.py b/src/orcapod/hashing/string_cachers.py index caa6c93d..9575411f 100644 --- a/src/orcapod/hashing/string_cachers.py +++ b/src/orcapod/hashing/string_cachers.py @@ -6,7 +6,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any -from orcapod.protocols.hashing_protocols import StringCacher +from orcapod.protocols.hashing_protocols import StringCacherProtocol logger = logging.getLogger(__name__) @@ -14,14 +14,14 @@ import redis -class TransferCacher(StringCacher): +class TransferCacher(StringCacherProtocol): """ Takes two string cachers as source and destination. Everytime a cached value is retrieved from source, the value is also set in the destination cacher. This is useful for transferring cached values between different caching mechanisms. """ - def __init__(self, source: StringCacher, destination: StringCacher): + def __init__(self, source: StringCacherProtocol, destination: StringCacherProtocol): """ Initialize the TransferCacher. @@ -68,7 +68,7 @@ def clear_cache(self) -> None: self.destination.clear_cache() -class InMemoryCacher(StringCacher): +class InMemoryCacher(StringCacherProtocol): """Thread-safe in-memory LRU cache.""" def __init__(self, max_size: int | None = 1000): @@ -108,7 +108,7 @@ def clear_cache(self) -> None: self._access_order.clear() -class FileCacher(StringCacher): +class FileCacher(StringCacherProtocol): """File-based cacher with eventual consistency between memory and disk.""" def __init__( @@ -270,7 +270,7 @@ def force_sync(self) -> None: self._sync_to_file() -class SQLiteCacher(StringCacher): +class SQLiteCacher(StringCacherProtocol): """SQLite-based cacher with in-memory LRU and database persistence.""" def __init__( @@ -312,11 +312,11 @@ def _init_database(self) -> None: CREATE TABLE IF NOT EXISTS cache_entries ( key TEXT PRIMARY KEY, value TEXT NOT NULL, - last_accessed REAL DEFAULT (strftime('%f', 'now')) + last_accessed REAL DEFAULT (strftime('%s', 'now')) ) """) conn.execute(""" - CREATE INDEX IF NOT EXISTS idx_last_accessed + CREATE INDEX IF NOT EXISTS idx_last_accessed ON cache_entries(last_accessed) """) conn.commit() @@ -330,7 +330,7 @@ def _load_from_database(self) -> None: try: with sqlite3.connect(self.db_path) as conn: cursor = conn.execute(""" - SELECT key, value FROM cache_entries + SELECT key, value FROM cache_entries ORDER BY last_accessed DESC """) @@ -396,7 +396,7 @@ def _sync_to_database(self) -> None: conn.execute( """ INSERT OR REPLACE INTO cache_entries (key, value, last_accessed) - VALUES (?, ?, strftime('%f', 'now')) + VALUES (?, ?, strftime('%s', 'now')) """, (key, value), ) @@ -579,7 +579,7 @@ def __del__(self): pass # Avoid exceptions in destructor -class RedisCacher(StringCacher): +class RedisCacher(StringCacherProtocol): """Redis-based cacher with graceful failure handling.""" def __init__( diff --git a/src/orcapod/hashing/types.py b/src/orcapod/hashing/types.py deleted file mode 100644 index 027b9e28..00000000 --- a/src/orcapod/hashing/types.py +++ /dev/null @@ -1,178 +0,0 @@ -# """Hash strategy protocols for dependency injection.""" - -# from abc import ABC, abstractmethod -# from collections.abc import Callable -# from typing import Any, Protocol, runtime_checkable -# import uuid - -# from orcapod.types import PacketLike, PathLike, PathSet, TypeSpec - -# import pyarrow as pa - - -# @runtime_checkable -# class Identifiable(Protocol): -# """Protocol for objects that can provide an identity structure.""" - -# def identity_structure(self) -> Any: -# """ -# Return a structure that represents the identity of this object. - -# Returns: -# Any: A structure representing this object's content. -# Should be deterministic and include all identity-relevant data. -# Return None to indicate no custom identity is available. -# """ -# pass # pragma: no cover - - -# class ObjectHasher(ABC): -# """Abstract class for general object hashing.""" - -# # TODO: consider more explicitly stating types of objects accepted -# @abstractmethod -# def hash(self, obj: Any) -> bytes: -# """ -# Hash an object to a byte representation. - -# Args: -# obj (Any): The object to hash. - -# Returns: -# bytes: The byte representation of the hash. -# """ -# ... - -# @abstractmethod -# def get_hasher_id(self) -> str: -# """ -# Returns a unique identifier/name assigned to the hasher -# """ - -# def hash_to_hex( -# self, obj: Any, char_count: int | None = None, prefix_hasher_id: bool = False -# ) -> str: -# hash_bytes = self.hash(obj) -# hex_str = hash_bytes.hex() - -# # TODO: clean up this logic, as char_count handling is messy -# if char_count is not None: -# if char_count > len(hex_str): -# raise ValueError( -# f"Cannot truncate to {char_count} chars, hash only has {len(hex_str)}" -# ) -# hex_str = hex_str[:char_count] -# if prefix_hasher_id: -# hex_str = self.get_hasher_id() + "@" + hex_str -# return hex_str - -# def hash_to_int(self, obj: Any, hexdigits: int = 16) -> int: -# """ -# Hash an object to an integer. - -# Args: -# obj (Any): The object to hash. -# hexdigits (int): Number of hexadecimal digits to use for the hash. - -# Returns: -# int: The integer representation of the hash. -# """ -# hex_hash = self.hash_to_hex(obj, char_count=hexdigits) -# return int(hex_hash, 16) - -# def hash_to_uuid( -# self, obj: Any, namespace: uuid.UUID = uuid.NAMESPACE_OID -# ) -> uuid.UUID: -# """Convert hash to proper UUID5.""" -# return uuid.uuid5(namespace, self.hash(obj)) - - -# @runtime_checkable -# class FileContentHasher(Protocol): -# """Protocol for file-related hashing.""" - -# def hash_file(self, file_path: PathLike) -> bytes: ... - - -# @runtime_checkable -# class ArrowHasher(Protocol): -# """Protocol for hashing arrow packets.""" - -# def get_hasher_id(self) -> str: ... - -# def hash_table(self, table: pa.Table, prefix_hasher_id: bool = True) -> str: ... - - -# @runtime_checkable -# class StringCacher(Protocol): -# """Protocol for caching string key value pairs.""" - -# def get_cached(self, cache_key: str) -> str | None: ... -# def set_cached(self, cache_key: str, value: str) -> None: ... -# def clear_cache(self) -> None: ... - - -# # Function hasher protocol -# @runtime_checkable -# class FunctionInfoExtractor(Protocol): -# """Protocol for extracting function information.""" - -# def extract_function_info( -# self, -# func: Callable[..., Any], -# function_name: str | None = None, -# input_typespec: TypeSpec | None = None, -# output_typespec: TypeSpec | None = None, -# ) -> dict[str, Any]: ... - - -# class SemanticTypeHasher(Protocol): -# """Abstract base class for semantic type-specific hashers.""" - -# @abstractmethod -# def hash_column( -# self, -# column: pa.Array, -# ) -> pa.Array: -# """Hash a column with this semantic type and return the hash bytes.""" -# pass - -# @abstractmethod -# def set_cacher(self, cacher: StringCacher) -> None: -# """Add a string cacher for caching hash values.""" -# pass - - -# # ---------------Legacy implementations and protocols to be deprecated--------------------- - - -# @runtime_checkable -# class LegacyFileHasher(Protocol): -# """Protocol for file-related hashing.""" - -# def hash_file(self, file_path: PathLike) -> str: ... - - -# # Higher-level operations that compose file hashing -# @runtime_checkable -# class LegacyPathSetHasher(Protocol): -# """Protocol for hashing pathsets (files, directories, collections).""" - -# def hash_pathset(self, pathset: PathSet) -> str: ... - - -# @runtime_checkable -# class LegacyPacketHasher(Protocol): -# """Protocol for hashing packets.""" - -# def hash_packet(self, packet: PacketLike) -> str: ... - - -# # Combined interface for convenience (optional) -# @runtime_checkable -# class LegacyCompositeFileHasher( -# LegacyFileHasher, LegacyPathSetHasher, LegacyPacketHasher, Protocol -# ): -# """Combined interface for all file-related hashing operations.""" - -# pass diff --git a/src/orcapod/hashing/versioned_hashers.py b/src/orcapod/hashing/versioned_hashers.py new file mode 100644 index 00000000..8adce44d --- /dev/null +++ b/src/orcapod/hashing/versioned_hashers.py @@ -0,0 +1,149 @@ +""" +Versioned hasher factories for OrcaPod. + +This module is the single source of truth for which concrete hasher +implementations correspond to each versioned context. All code that +needs a "current" or "versioned" hasher should go through these factory +functions rather than constructing hashers directly, so that version +bumps happen in exactly one place. + +Functions +--------- +get_versioned_semantic_hasher() + Return the current-version SemanticHasherProtocol (the new content-based + recursive hasher that replaces BasicObjectHasher). + +get_versioned_semantic_arrow_hasher() + Return the current-version SemanticArrowHasher (Arrow table hasher + with semantic-type support). +""" + +from __future__ import annotations + +import logging +from typing import Any + +from orcapod.protocols import hashing_protocols as hp + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Version constants +# --------------------------------------------------------------------------- + +# The hasher_id embedded in every ContentHash produced by the current +# semantic hasher. Bump this string when the resolution/serialisation +# algorithm changes in a way that would alter hash outputs so that stored +# hashes can be distinguished from newly-computed ones. +_CURRENT_SEMANTIC_HASHER_ID = "semantic_v0.1" + +# The hasher_id for the Arrow hasher. +_CURRENT_ARROW_HASHER_ID = "arrow_v0.1" + + +# --------------------------------------------------------------------------- +# SemanticHasherProtocol factory +# --------------------------------------------------------------------------- + + +def get_versioned_semantic_hasher( + hasher_id: str = _CURRENT_SEMANTIC_HASHER_ID, + strict: bool = True, + type_handler_registry: "hp.TypeHandlerRegistry | None" = None, # type: ignore[name-defined] +) -> hp.SemanticHasherProtocol: + """ + Return a SemanticHasherProtocol configured for the current version. + + The returned hasher uses the global default TypeHandlerRegistry (which + is pre-populated with all built-in handlers) unless an explicit registry + is supplied. + + Parameters + ---------- + hasher_id: + Identifier embedded in every ContentHash produced by this hasher. + Defaults to the current version constant. Override only when + producing hashes that must be tagged with a specific version string. + strict: + When True (the default) the hasher raises TypeError on encountering + an object of an unhandled type. When False it falls back to a + best-effort string representation with a logged warning. + type_handler_registry: + Optional TypeHandlerRegistry to inject. When None the global + default registry is used (recommended for production code). + + Returns + ------- + SemanticHasherProtocol + A fully configured SemanticHasherProtocol instance. + """ + from orcapod.hashing.semantic_hashing.semantic_hasher import BaseSemanticHasher + + if type_handler_registry is None: + from orcapod.hashing.semantic_hashing.type_handler_registry import ( + get_default_type_handler_registry, + ) + + type_handler_registry = get_default_type_handler_registry() + + logger.debug( + "get_versioned_semantic_hasher: creating BaseSemanticHasher " + "(hasher_id=%r, strict=%r)", + hasher_id, + strict, + ) + return BaseSemanticHasher( + hasher_id=hasher_id, + type_handler_registry=type_handler_registry, + strict=strict, + ) + + +# --------------------------------------------------------------------------- +# SemanticArrowHasher factory +# --------------------------------------------------------------------------- + + +def get_versioned_semantic_arrow_hasher( + hasher_id: str = _CURRENT_ARROW_HASHER_ID, +) -> hp.ArrowHasherProtocol: + """ + Return a SemanticArrowHasher configured for the current version. + + The arrow hasher handles Arrow table / RecordBatch hashing with + semantic-type awareness (e.g. Path columns are hashed by file content). + + Parameters + ---------- + hasher_id: + Identifier embedded in every ContentHash produced by this hasher. + + Returns + ------- + ArrowHasherProtocol + A fully configured SemanticArrowHasher instance. + """ + from orcapod.hashing.arrow_hashers import SemanticArrowHasher + from orcapod.hashing.file_hashers import BasicFileHasher + from orcapod.semantic_types.semantic_registry import SemanticTypeRegistry + from orcapod.semantic_types.semantic_struct_converters import PathStructConverter + + # Build a default semantic registry populated with the standard converters. + # We use Any-typed locals here to side-step type-checker false positives + # that arise from the protocol definition of SemanticStructConverterProtocol having + # a slightly different hash_struct_dict signature than the concrete class. + registry: Any = SemanticTypeRegistry() + file_hasher = BasicFileHasher(algorithm="sha256") + path_converter: Any = PathStructConverter(file_hasher=file_hasher) + registry.register_converter("path", path_converter) + + logger.debug( + "get_versioned_semantic_arrow_hasher: creating SemanticArrowHasher " + "(hasher_id=%r)", + hasher_id, + ) + hasher: Any = SemanticArrowHasher( + hasher_id=hasher_id, + semantic_registry=registry, + ) + return hasher diff --git a/src/orcapod/hashing/visitors.py b/src/orcapod/hashing/visitors.py index e205a12d..dede8c85 100644 --- a/src/orcapod/hashing/visitors.py +++ b/src/orcapod/hashing/visitors.py @@ -8,16 +8,10 @@ """ from abc import ABC, abstractmethod -from typing import Any, TYPE_CHECKING -from orcapod.utils.lazy_module import LazyModule -from orcapod.semantic_types.semantic_registry import SemanticTypeRegistry - - -if TYPE_CHECKING: - import pyarrow as pa -else: - pa = LazyModule("pyarrow") +from typing import TYPE_CHECKING, Any +from orcapod.semantic_types.semantic_registry import SemanticTypeRegistry +from orcapod.utils.lazy_module import LazyModule if TYPE_CHECKING: import pyarrow as pa @@ -143,36 +137,6 @@ def _visit_list_elements( return pa.list_(new_element_type), processed_elements -class PassThroughVisitor(ArrowTypeDataVisitor): - """ - A visitor that passes through data unchanged. - - Useful as a base class or for testing the visitor pattern. - """ - - def visit_struct( - self, struct_type: "pa.StructType", data: dict | None - ) -> tuple["pa.DataType", Any]: - return self._visit_struct_fields(struct_type, data) - - def visit_list( - self, list_type: "pa.ListType", data: list | None - ) -> tuple["pa.DataType", Any]: - return self._visit_list_elements(list_type, data) - - def visit_map( - self, map_type: "pa.MapType", data: dict | None - ) -> tuple["pa.DataType", Any]: - # For simplicity, treat maps like structs for now - # TODO: Implement proper map handling if needed - return map_type, data - - def visit_primitive( - self, primitive_type: "pa.DataType", data: Any - ) -> tuple["pa.DataType", Any]: - return primitive_type, data - - class SemanticHashingError(Exception): """Exception raised when semantic hashing fails""" @@ -295,83 +259,3 @@ def _visit_struct_fields( self._current_field_path.pop() return pa.struct(new_fields), new_data - - -class ValidationVisitor(ArrowTypeDataVisitor): - """ - Example visitor for data validation. - - This demonstrates how the visitor pattern can be extended for other use cases. - """ - - def __init__(self): - self.errors: list[str] = [] - self._current_field_path: list[str] = [] - - def visit_struct( - self, struct_type: "pa.StructType", data: dict | None - ) -> tuple["pa.DataType", Any]: - if data is None: - return struct_type, None - - # Check for missing required fields - field_names = {field.name for field in struct_type} - data_keys = set(data.keys()) - missing_fields = field_names - data_keys - - if missing_fields: - field_path = ( - ".".join(self._current_field_path) - if self._current_field_path - else "" - ) - self.errors.append( - f"Missing required fields {missing_fields} at '{field_path}'" - ) - - return self._visit_struct_fields(struct_type, data) - - def visit_list( - self, list_type: "pa.ListType", data: list | None - ) -> tuple["pa.DataType", Any]: - if data is None: - return list_type, None - - self._current_field_path.append("[*]") - try: - return self._visit_list_elements(list_type, data) - finally: - self._current_field_path.pop() - - def visit_map( - self, map_type: "pa.MapType", data: dict | None - ) -> tuple["pa.DataType", Any]: - return map_type, data - - def visit_primitive( - self, primitive_type: "pa.DataType", data: Any - ) -> tuple["pa.DataType", Any]: - return primitive_type, data - - def _visit_struct_fields( - self, struct_type: "pa.StructType", data: dict | None - ) -> tuple["pa.StructType", dict]: - """Override to add field path tracking""" - if data is None: - return struct_type, None - - new_fields = [] - new_data = {} - - for field in struct_type: - self._current_field_path.append(field.name) - try: - field_data = data.get(field.name) - new_field_type, new_field_data = self.visit(field.type, field_data) - - new_fields.append(pa.field(field.name, new_field_type)) - new_data[field.name] = new_field_data - finally: - self._current_field_path.pop() - - return pa.struct(new_fields), new_data diff --git a/src/orcapod/nodes/__init__.py b/src/orcapod/nodes/__init__.py new file mode 100644 index 00000000..8786cd8d --- /dev/null +++ b/src/orcapod/nodes/__init__.py @@ -0,0 +1,4 @@ +"""Public re-export of orcapod.core.nodes.""" + +from orcapod.core.nodes import * # noqa: F401,F403 +from orcapod.core.nodes import __all__ diff --git a/src/orcapod/operators/__init__.py b/src/orcapod/operators/__init__.py new file mode 100644 index 00000000..8cf25afb --- /dev/null +++ b/src/orcapod/operators/__init__.py @@ -0,0 +1,4 @@ +"""Public re-export of orcapod.core.operators.""" + +from orcapod.core.operators import * # noqa: F401,F403 +from orcapod.core.operators import __all__ diff --git a/src/orcapod/pipeline/__init__.py b/src/orcapod/pipeline/__init__.py index 616846a0..67dddfe1 100644 --- a/src/orcapod/pipeline/__init__.py +++ b/src/orcapod/pipeline/__init__.py @@ -1,11 +1,15 @@ -# from .legacy_pipeline import Pipeline - -# __all__ = [ -# "Pipeline", -# ] - +from .async_orchestrator import AsyncPipelineOrchestrator from .graph import Pipeline +from .logging_observer import LoggingObserver, PacketLogger +from .serialization import LoadStatus, PIPELINE_FORMAT_VERSION +from .sync_orchestrator import SyncPipelineOrchestrator __all__ = [ + "AsyncPipelineOrchestrator", + "LoadStatus", + "LoggingObserver", + "PacketLogger", + "PIPELINE_FORMAT_VERSION", "Pipeline", + "SyncPipelineOrchestrator", ] diff --git a/src/orcapod/pipeline/async_orchestrator.py b/src/orcapod/pipeline/async_orchestrator.py new file mode 100644 index 00000000..c6e35d41 --- /dev/null +++ b/src/orcapod/pipeline/async_orchestrator.py @@ -0,0 +1,225 @@ +"""Async pipeline orchestrator for push-based channel execution. + +Walks a compiled pipeline's node graph and launches all nodes concurrently +via ``asyncio.TaskGroup``, wiring them together with bounded channels. +Uses TypeGuard dispatch with tightened per-type async_execute signatures. +""" + +from __future__ import annotations + +import asyncio +import uuid +import logging +from collections import defaultdict +from typing import TYPE_CHECKING, Any + +from orcapod.channels import BroadcastChannel, Channel +from orcapod.pipeline.result import OrchestratorResult +from orcapod.protocols.node_protocols import ( + is_function_node, + is_operator_node, + is_source_node, +) + +if TYPE_CHECKING: + import networkx as nx + + from orcapod.protocols.observability_protocols import ExecutionObserverProtocol + from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol + +logger = logging.getLogger(__name__) + + +class AsyncPipelineOrchestrator: + """Execute a compiled pipeline asynchronously using channels. + + After compilation, the orchestrator: + + 1. Walks the node graph in topological order. + 2. Creates bounded channels (or broadcast channels for fan-out) + between connected nodes. + 3. Launches every node's ``async_execute`` concurrently via + ``asyncio.TaskGroup``, using TypeGuard dispatch for per-type + signatures. + + Args: + observer: Optional execution observer for hooks. + buffer_size: Channel buffer size. Defaults to 64. + """ + + def __init__( + self, + observer: "ExecutionObserverProtocol | None" = None, + buffer_size: int = 64, + error_policy: str = "continue", + ) -> None: + self._observer = observer + self._buffer_size = buffer_size + self._error_policy = error_policy + + def run( + self, + graph: "nx.DiGraph", + materialize_results: bool = True, + run_id: str | None = None, + ) -> OrchestratorResult: + """Synchronous entry point — runs the async pipeline to completion. + + Args: + graph: A NetworkX DiGraph with GraphNode objects as vertices. + materialize_results: If True, collect all node outputs into + the result. If False, return empty node_outputs. + run_id: Optional run identifier. If not provided, a UUID is + generated automatically. + + Returns: + OrchestratorResult with node outputs. + """ + return asyncio.run(self._run_async(graph, materialize_results, run_id=run_id)) + + async def run_async( + self, + graph: "nx.DiGraph", + materialize_results: bool = True, + run_id: str | None = None, + ) -> OrchestratorResult: + """Async entry point for callers already inside an event loop. + + Args: + graph: A NetworkX DiGraph with GraphNode objects as vertices. + materialize_results: If True, collect all node outputs. + run_id: Optional run identifier. If not provided, a UUID is + generated automatically. + + Returns: + OrchestratorResult with node outputs. + """ + return await self._run_async(graph, materialize_results, run_id=run_id) + + async def _run_async( + self, + graph: "nx.DiGraph", + materialize_results: bool, + run_id: str | None = None, + ) -> OrchestratorResult: + """Core async logic: wire channels, launch tasks, collect results.""" + import networkx as nx + + effective_run_id = run_id or str(uuid.uuid4()) + if self._observer is not None: + self._observer.on_run_start(effective_run_id) + + topo_order = list(nx.topological_sort(graph)) + buf = self._buffer_size + + # Build edge maps + out_edges: dict[Any, list[Any]] = defaultdict(list) + in_edges: dict[Any, list[Any]] = defaultdict(list) + for upstream_node, downstream_node in graph.edges(): + out_edges[upstream_node].append(downstream_node) + in_edges[downstream_node].append(upstream_node) + + # Create channels for each edge + node_output_channels: dict[Any, Channel | BroadcastChannel] = {} + edge_readers: dict[tuple[Any, Any], Any] = {} + + for node, downstreams in out_edges.items(): + if len(downstreams) == 1: + ch = Channel(buffer_size=buf) + node_output_channels[node] = ch + edge_readers[(node, downstreams[0])] = ch.reader + else: + bch = BroadcastChannel(buffer_size=buf) + node_output_channels[node] = bch + for ds in downstreams: + edge_readers[(node, ds)] = bch.add_reader() + + # Terminal nodes need sink channels + terminal_channels: list[Channel] = [] + for node in topo_order: + if node not in node_output_channels: + ch = Channel(buffer_size=buf) + node_output_channels[node] = ch + terminal_channels.append(ch) + + # Result collection + collectors: dict[Any, list[tuple[TagProtocol, PacketProtocol]]] = {} + if materialize_results: + for node in topo_order: + collectors[node] = [] + + # Launch all nodes concurrently + async with asyncio.TaskGroup() as tg: + for node in topo_order: + writer = node_output_channels[node].writer + + if materialize_results: + collector = collectors[node] + writer = _CollectingWriter(writer, collector) + + if is_source_node(node): + tg.create_task( + node.async_execute(writer, observer=self._observer) + ) + elif is_function_node(node): + predecessors = in_edges.get(node, []) + if len(predecessors) != 1: + raise ValueError( + f"FunctionNode expects exactly 1 upstream, " + f"got {len(predecessors)}" + ) + input_reader = edge_readers[(predecessors[0], node)] + tg.create_task( + node.async_execute( + input_reader, writer, observer=self._observer + ) + ) + elif is_operator_node(node): + predecessors = in_edges.get(node, []) + # Sort by node.upstreams order for non-commutative operators + upstream_order = {id(s): i for i, s in enumerate(node.upstreams)} + sorted_preds = sorted( + predecessors, + key=lambda p: upstream_order.get(id(p), 0), + ) + input_readers = [ + edge_readers[(upstream, node)] + for upstream in sorted_preds + ] + tg.create_task( + node.async_execute( + input_readers, writer, observer=self._observer + ) + ) + else: + raise TypeError( + f"Unknown node type: {getattr(node, 'node_type', None)!r}" + ) + + # Drain terminal channels concurrently + for ch in terminal_channels: + tg.create_task(ch.reader.collect()) + + if self._observer is not None: + self._observer.on_run_end(effective_run_id) + return OrchestratorResult( + node_outputs=collectors if materialize_results else {} + ) + + +class _CollectingWriter: + """Wrapper that collects items while forwarding to real writer.""" + + def __init__(self, writer: Any, collector: list) -> None: + self._writer = writer + self._collector = collector + + async def send(self, item: Any) -> None: + self._collector.append(item) + await self._writer.send(item) + + async def close(self) -> None: + await self._writer.close() + + def __getattr__(self, name: str) -> Any: + return getattr(self._writer, name) diff --git a/src/orcapod/pipeline/graph.py b/src/orcapod/pipeline/graph.py index 45d83e0f..c5a9595e 100644 --- a/src/orcapod/pipeline/graph.py +++ b/src/orcapod/pipeline/graph.py @@ -1,16 +1,22 @@ -from orcapod.core.trackers import GraphTracker, Invocation -from orcapod.pipeline.nodes import KernelNode, PodNode -from orcapod.protocols.pipeline_protocols import Node -from orcapod import contexts -from orcapod.protocols import core_protocols as cp -from orcapod.protocols import database_protocols as dbp -from typing import Any, cast -from collections.abc import Collection +from __future__ import annotations + +import json +import logging import os import tempfile -import logging -import asyncio -from typing import TYPE_CHECKING +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from orcapod.core.nodes import ( + FunctionNode, + GraphNode, + OperatorNode, + SourceNode, +) +from orcapod.core.tracker import AutoRegisteringContextBasedTracker +from orcapod.protocols import core_protocols as cp +from orcapod.protocols import database_protocols as dbp +from orcapod.types import PipelineConfig from orcapod.utils.lazy_module import LazyModule if TYPE_CHECKING: @@ -21,33 +27,12 @@ logger = logging.getLogger(__name__) -def synchronous_run(async_func, *args, **kwargs): - """ - Use existing event loop if available. - - Pros: Reuses existing loop, more efficient - Cons: More complex, need to handle loop detection - """ - try: - # Check if we're already in an event loop - _ = asyncio.get_running_loop() - - def run_in_thread(): - return asyncio.run(async_func(*args, **kwargs)) - - import concurrent.futures - - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(run_in_thread) - return future.result() - except RuntimeError: - # No event loop running, safe to use asyncio.run() - return asyncio.run(async_func(*args, **kwargs)) - +# --------------------------------------------------------------------------- +# Visualization helper (unrelated to pipeline node types) +# --------------------------------------------------------------------------- - -class GraphNode: +class VizGraphNode: def __init__(self, label: str, id: int, kernel_type: str): self.label = label self.id = id @@ -57,7 +42,7 @@ def __hash__(self): return hash((self.id, self.kernel_type)) def __eq__(self, other): - if not isinstance(other, GraphNode): + if not isinstance(other, VizGraphNode): return NotImplemented return (self.id, self.kernel_type) == ( other.id, @@ -65,259 +50,927 @@ def __eq__(self, other): ) -class Pipeline(GraphTracker): - """ - Represents a pipeline in the system. - This class extends GraphTracker to manage the execution of kernels and pods in a pipeline. +# --------------------------------------------------------------------------- +# Pipeline +# --------------------------------------------------------------------------- + + +class Pipeline(AutoRegisteringContextBasedTracker): + """A persistent pipeline that records operator and function pod invocations. + + During the ``with`` block, operator and function pod invocations are + recorded into an internal graph. On context exit, ``compile()`` rewires + the graph into execution-ready nodes: + + - Leaf streams -> ``SourceNode`` (thin wrapper for graph vertex) + - Function pod invocations -> ``FunctionNode`` + - Operator invocations -> ``OperatorNode`` + + Source caching is not a pipeline concern -- sources that need caching + should be wrapped in a ``CachedSource`` before being used in the + pipeline. + + All persistent nodes share the same ``pipeline_database`` and use + ``pipeline_name`` as path prefix, scoping their cache tables. + + Parameters: + name: Pipeline name (string or tuple). Used as the path prefix for + all cache/pipeline paths within the databases. + pipeline_database: Database for pipeline records and operator caches. + function_database: Optional separate database for function pod result + caches. When ``None``, ``pipeline_database`` is used with a + ``_results`` subfolder under the pipeline name. + auto_compile: If ``True`` (default), ``compile()`` is called + automatically when the context manager exits. """ def __init__( self, name: str | tuple[str, ...], - pipeline_database: dbp.ArrowDatabase, - results_database: dbp.ArrowDatabase | None = None, - tracker_manager: cp.TrackerManager | None = None, - data_context: str | contexts.DataContext | None = None, + pipeline_database: dbp.ArrowDatabaseProtocol, + function_database: dbp.ArrowDatabaseProtocol | None = None, + tracker_manager: cp.TrackerManagerProtocol | None = None, auto_compile: bool = True, - ): - super().__init__(tracker_manager=tracker_manager, data_context=data_context) - if not isinstance(name, tuple): - name = (name,) - self.name = name - self.pipeline_store_path_prefix = self.name - self.results_store_path_prefix = () - if results_database is None: - if pipeline_database is None: - raise ValueError( - "Either pipeline_database or results_database must be provided" - ) - results_database = pipeline_database - self.results_store_path_prefix = self.name + ("_results",) - self.pipeline_database = pipeline_database - self.results_database = results_database - self._nodes: dict[str, Node] = {} - self.auto_compile = auto_compile - self._dirty = False - self._ordered_nodes = [] # Track order of invocations + ) -> None: + super().__init__(tracker_manager=tracker_manager) + self._node_lut: dict[str, GraphNode] = {} + self._upstreams: dict[str, cp.StreamProtocol] = {} + self._graph_edges: list[tuple[str, str]] = [] + self._hash_graph: "nx.DiGraph" = nx.DiGraph() + self._name = (name,) if isinstance(name, str) else tuple(name) + self._pipeline_database = pipeline_database + self._function_database = function_database + self._pipeline_path_prefix = self._name + self._nodes: dict[str, GraphNode] = {} + self._persistent_node_map: dict[str, GraphNode] = {} + self._node_graph: "nx.DiGraph | None" = None + self._auto_compile = auto_compile + self._compiled = False + + # ------------------------------------------------------------------ + # Recording (TrackerProtocol) + # ------------------------------------------------------------------ + + def record_function_pod_invocation( + self, + pod: cp.FunctionPodProtocol, + input_stream: cp.StreamProtocol, + label: str | None = None, + ) -> None: + input_stream_hash = input_stream.content_hash().to_string() + function_node = FunctionNode( + function_pod=pod, + input_stream=input_stream, + label=label, + ) + function_node_hash = function_node.content_hash().to_string() + self._node_lut[function_node_hash] = function_node + self._upstreams[input_stream_hash] = input_stream + self._graph_edges.append((input_stream_hash, function_node_hash)) + self._hash_graph.add_edge(input_stream_hash, function_node_hash) + if not self._hash_graph.nodes[function_node_hash].get("node_type"): + self._hash_graph.nodes[function_node_hash]["node_type"] = "function" + + def record_operator_pod_invocation( + self, + pod: cp.OperatorPodProtocol, + upstreams: tuple[cp.StreamProtocol, ...] = (), + label: str | None = None, + ) -> None: + operator_node = OperatorNode( + operator=pod, + input_streams=upstreams, + label=label, + ) + operator_node_hash = operator_node.content_hash().to_string() + self._node_lut[operator_node_hash] = operator_node + upstream_hashes = [stream.content_hash().to_string() for stream in upstreams] + for upstream_hash, upstream in zip(upstream_hashes, upstreams): + self._upstreams[upstream_hash] = upstream + self._graph_edges.append((upstream_hash, operator_node_hash)) + self._hash_graph.add_edge(upstream_hash, operator_node_hash) + if not self._hash_graph.nodes[operator_node_hash].get("node_type"): + self._hash_graph.nodes[operator_node_hash]["node_type"] = "operator" @property - def nodes(self) -> dict[str, Node]: - return self._nodes.copy() + def nodes(self) -> list[GraphNode]: + """Return the list of recorded (non-persistent) nodes.""" + return list(self._node_lut.values()) @property - def function_pods(self) -> dict[str, cp.Pod]: - return { - label: cast(cp.Pod, node) - for label, node in self._nodes.items() - if getattr(node, "kernel_type") == "function" - } + def graph(self) -> "nx.DiGraph": + """Directed graph of content-hash strings representing the accumulated + pipeline structure. Vertices are ``content_hash`` strings; node + attributes include ``node_type`` ("source" / "function" / "operator") + and, after ``compile()``, ``label`` and ``pipeline_hash``. + + The graph accumulates across multiple ``with`` blocks and is never + cleared by ``reset()``. + """ + return self._hash_graph + + def reset(self) -> None: + """Clear session-scoped recorded state (node LUT, upstreams, edge list). + + Note: ``_hash_graph`` is intentionally *not* cleared -- it accumulates + the pipeline structure across ``with`` blocks. + """ + self._node_lut.clear() + self._upstreams.clear() + self._graph_edges.clear() + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ @property - def source_pods(self) -> dict[str, cp.Source]: - return { - label: node - for label, node in self._nodes.items() - if getattr(node, "kernel_type") == "source" - } + def name(self) -> tuple[str, ...]: + return self._name @property - def operator_pods(self) -> dict[str, cp.Kernel]: - return { - label: node - for label, node in self._nodes.items() - if getattr(node, "kernel_type") == "operator" - } + def pipeline_database(self) -> dbp.ArrowDatabaseProtocol: + return self._pipeline_database + + @property + def function_database(self) -> dbp.ArrowDatabaseProtocol | None: + return self._function_database + + @property + def compiled_nodes(self) -> dict[str, GraphNode]: + """Return a copy of the compiled nodes dict.""" + return self._nodes.copy() + + # ------------------------------------------------------------------ + # Context manager + # ------------------------------------------------------------------ def __exit__(self, exc_type=None, exc_value=None, traceback=None): - """ - Exit the pipeline context, ensuring all nodes are properly closed. - """ super().__exit__(exc_type, exc_value, traceback) - if self.auto_compile: + if self._auto_compile: self.compile() - def flush(self) -> None: - self.pipeline_database.flush() - self.results_database.flush() + # ------------------------------------------------------------------ + # Compile + # ------------------------------------------------------------------ - def record_kernel_invocation( - self, - kernel: cp.Kernel, - upstreams: tuple[cp.Stream, ...], - label: str | None = None, - ) -> None: - super().record_kernel_invocation(kernel, upstreams, label) - self._dirty = True + def compile(self) -> None: + """Compile recorded invocations into execution-ready nodes. - def record_pod_invocation( - self, - pod: cp.Pod, - upstreams: tuple[cp.Stream, ...], - label: str | None = None, - ) -> None: - super().record_pod_invocation(pod, upstreams, label) - self._dirty = True + Walks the graph in topological order and: - def compile(self) -> None: - import networkx as nx + - Wraps leaf streams in ``SourceNode`` + - Rewires upstream references on recorded ``FunctionNode`` / + ``OperatorNode`` to point at persistent (compiled) nodes + - Attaches databases to function/operator nodes via + ``attach_databases()`` - name_candidates = {} + After compile, nodes are accessible by label as attributes on the + pipeline instance. + """ + from orcapod.core.nodes import ( + FunctionNode, + OperatorNode, + ) - invocation_to_stream_lut = {} - G = self.generate_graph() - node_graph = nx.DiGraph() - for invocation in nx.topological_sort(G): - input_streams = [ - invocation_to_stream_lut[parent] for parent in invocation.parents() - ] + G = nx.DiGraph() + for edge in self._graph_edges: + G.add_edge(*edge) - node = self.wrap_invocation(invocation, new_input_streams=input_streams) + # Seed from existing persistent nodes (incremental compile) + persistent_node_map: dict[str, GraphNode] = dict(self._persistent_node_map) + name_candidates: dict[str, list[GraphNode]] = {} - for parent in node.upstreams: - node_graph.add_edge(parent.source, node) + for node_hash in nx.topological_sort(G): + if node_hash in persistent_node_map: + # Already compiled — reuse, but track for label assignment + existing_node = persistent_node_map[node_hash] + name_candidates.setdefault(existing_node.label, []).append( + existing_node + ) + continue - invocation_to_stream_lut[invocation] = node() + if node_hash not in self._node_lut: + # -- Leaf stream: wrap in SourceNode -- + stream = self._upstreams[node_hash] + node = SourceNode(stream=stream) + persistent_node_map[node_hash] = node + else: + node = self._node_lut[node_hash] + + if isinstance(node, FunctionNode): + # Rewire input stream to persistent upstream + input_hash = node._input_stream.content_hash().to_string() + rewired_input = persistent_node_map[input_hash] + node.upstreams = (rewired_input,) + + # Determine result database and path prefix + if self._function_database is not None: + result_db = self._function_database + result_prefix = None + else: + result_db = self._pipeline_database + result_prefix = self._name + ("_results",) + + node.attach_databases( + pipeline_database=self._pipeline_database, + result_database=result_db, + result_path_prefix=result_prefix, + pipeline_path_prefix=self._pipeline_path_prefix, + ) + + elif isinstance(node, OperatorNode): + # Rewire all input streams to persistent upstreams + rewired_inputs = tuple( + persistent_node_map[s.content_hash().to_string()] + for s in node.upstreams + ) + node.upstreams = rewired_inputs + + node.attach_databases( + pipeline_database=self._pipeline_database, + pipeline_path_prefix=self._pipeline_path_prefix, + ) + + else: + raise TypeError( + f"Unknown node type in pipeline graph: {type(node)}" + ) + + persistent_node_map[node_hash] = node + + # Track all nodes for label assignment name_candidates.setdefault(node.label, []).append(node) - # visit through the name candidates and resolve any collisions + # Save persistent node map for incremental re-compile + self._persistent_node_map = persistent_node_map + + # Build node graph for run() ordering + self._node_graph = nx.DiGraph() + for upstream_hash, downstream_hash in self._graph_edges: + upstream_node = persistent_node_map.get(upstream_hash) + downstream_node = persistent_node_map.get(downstream_hash) + if upstream_node is not None and downstream_node is not None: + self._node_graph.add_edge(upstream_node, downstream_node) + # Add isolated nodes (sources with no downstream in edges) + for node in persistent_node_map.values(): + if node not in self._node_graph: + self._node_graph.add_node(node) + + # Enrich hash graph with compiled node metadata (label, pipeline_hash, node_type) + for node_hash, node in persistent_node_map.items(): + if node_hash not in self._hash_graph: + continue + attrs = self._hash_graph.nodes[node_hash] + if not attrs.get("node_type"): + if isinstance(node, SourceNode): + attrs["node_type"] = "source" + elif isinstance(node, FunctionNode): + attrs["node_type"] = "function" + elif isinstance(node, OperatorNode): + attrs["node_type"] = "operator" + if not attrs.get("label"): + computed = node.label or ( + node.computed_label() if hasattr(node, "computed_label") else None + ) + if computed: + attrs["label"] = computed + if not attrs.get("pipeline_hash"): + attrs["pipeline_hash"] = node.pipeline_hash().to_string() + + # Assign labels, disambiguating collisions by content hash + self._nodes.clear() for label, nodes in name_candidates.items(): if len(nodes) > 1: - # If there are multiple nodes with the same label, we need to resolve the collision - logger.info(f"Collision detected for label '{label}': {nodes}") - for i, node in enumerate(nodes, start=1): - self._nodes[f"{label}_{i}"] = node - node.label = f"{label}_{i}" + # Sort by content hash for deterministic disambiguation + sorted_nodes = sorted(nodes, key=lambda n: n.content_hash().to_string()) + for i, node in enumerate(sorted_nodes, start=1): + key = f"{label}_{i}" + self._nodes[key] = node + node._label = key else: self._nodes[label] = nodes[0] - nodes[0].label = label - - self.label_lut = {v: k for k, v in self._nodes.items()} - - self.graph = node_graph - def show_graph(self, **kwargs) -> None: - render_graph(self.graph, **kwargs) + self._compiled = True - def set_mode(self, mode: str) -> None: - if mode not in ("production", "development"): - raise ValueError("Mode must be either 'production' or 'development'") - for node in self._nodes.values(): - if hasattr(node, "set_mode"): - node.set_mode(mode) + # ------------------------------------------------------------------ + # Execution + # ------------------------------------------------------------------ def run( self, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - run_async: bool | None = None, + orchestrator=None, + config: PipelineConfig | None = None, + execution_engine: cp.PacketFunctionExecutorProtocol | None = None, + execution_engine_opts: "dict[str, Any] | None" = None, ) -> None: - """Execute the pipeline by running all nodes in the graph. - - This method traverses through all nodes in the graph and executes them sequentially - using the specified execution engine. After execution, flushes the pipeline. + """Execute all compiled nodes. Args: - execution_engine (dp.ExecutionEngine | None): The execution engine to use for running - the nodes. If None, creates a new default ExecutionEngine instance. - run_async (bool | None): Whether to run nodes asynchronously. If None, defaults to - the preferred mode based on the execution engine. + orchestrator: Optional orchestrator instance. When provided, + the orchestrator drives execution and nodes handle their + own persistence internally. When omitted, defaults to + ``SyncPipelineOrchestrator`` (sync mode) or + ``AsyncPipelineOrchestrator`` (async mode). + config: Pipeline configuration. When ``config.executor`` is + ``ExecutorType.ASYNC_CHANNELS``, the pipeline runs + asynchronously via the orchestrator. When ``config`` is + omitted and an ``execution_engine`` is provided, async mode + is used by default. Passing an explicit ``config`` always + takes priority — supply ``ExecutorType.SYNCHRONOUS`` to force + synchronous execution even when an engine is present. + execution_engine: Optional packet-function executor applied to + every function node before execution (e.g. a ``RayExecutor``). + Overrides ``config.execution_engine`` when both are provided. + execution_engine_opts: Resource/options dict forwarded to the + engine via ``with_options()`` (e.g. ``{"num_cpus": 4}``). + Overrides ``config.execution_engine_opts`` when both are + provided. + """ + from orcapod.types import ExecutorType, PipelineConfig - Returns: - None + explicit_config = config is not None + config = config or PipelineConfig() - Note: - Current implementation uses a simple traversal through all nodes. Future versions - may implement more efficient graph traversal algorithms. - """ - import networkx as nx - if run_async is True and (execution_engine is None or not execution_engine.supports_async): - raise ValueError( - "Cannot run asynchronously with an execution engine that does not support async." - ) + # Explicit kwargs take precedence over values baked into config. + effective_engine = ( + execution_engine + if execution_engine is not None + else config.execution_engine + ) + effective_opts = ( + execution_engine_opts + if execution_engine_opts is not None + else config.execution_engine_opts + ) - # if set to None, determine based on execution engine capabilities - if run_async is None: - run_async = execution_engine is not None and execution_engine.supports_async + if not self._compiled: + self.compile() - logger.info(f"Running pipeline with run_async={run_async}") + if effective_engine is not None: + self._apply_execution_engine(effective_engine, effective_opts) - for node in nx.topological_sort(self.graph): - if run_async: - synchronous_run( - node.run_async, - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, + if orchestrator is not None: + orchestrator.run(self._node_graph) + else: + # Default to async when an execution engine is provided, unless + # the caller explicitly supplied a config — in which case + # config.executor is authoritative and takes priority. + use_async = config.executor == ExecutorType.ASYNC_CHANNELS or ( + effective_engine is not None and not explicit_config + ) + if use_async: + from orcapod.pipeline.async_orchestrator import ( + AsyncPipelineOrchestrator, ) + + AsyncPipelineOrchestrator( + buffer_size=config.channel_buffer_size, + ).run(self._node_graph) else: - node.run( - execution_engine=execution_engine, - execution_engine_opts=execution_engine_opts, + from orcapod.pipeline.sync_orchestrator import ( + SyncPipelineOrchestrator, ) + SyncPipelineOrchestrator().run(self._node_graph) + self.flush() - def wrap_invocation( + def _apply_execution_engine( self, - invocation: Invocation, - new_input_streams: Collection[cp.Stream], - ) -> Node: - if invocation in self.invocation_to_pod_lut: - pod = self.invocation_to_pod_lut[invocation] - node = PodNode( - pod=pod, - input_streams=new_input_streams, - result_database=self.results_database, - record_path_prefix=self.results_store_path_prefix, - pipeline_database=self.pipeline_database, - pipeline_path_prefix=self.pipeline_store_path_prefix, - label=invocation.label, - kernel_type="function", + execution_engine: cp.PacketFunctionExecutorProtocol, + execution_engine_opts: dict[str, Any] | None, + ) -> None: + """Apply *execution_engine* to every ``FunctionNode`` in the pipeline. + + Each node receives its own executor instance via + ``engine.with_options(**opts)`` — even when *opts* is empty. + The executor's ``with_options`` implementation decides which + components to copy vs share (e.g. connection handles may be + shared while per-node state is copied). + + Args: + execution_engine: Executor to apply (must implement + ``PacketFunctionExecutorBase`` or at minimum expose + ``with_options``). + execution_engine_opts: Pipeline-level options dict, or + ``None`` for no defaults. + """ + assert self._node_graph is not None, ( + "_apply_execution_engine called before compile()" + ) + + opts = execution_engine_opts or {} + + for node in self._node_graph.nodes: + if not isinstance(node, FunctionNode): + continue + node.executor = execution_engine.with_options(**opts) + logger.debug( + "Applied execution engine %r to node %r (opts=%r)", + type(execution_engine).__name__, + node.label, + opts or None, ) - elif invocation in self.invocation_to_source_lut: - source = self.invocation_to_source_lut[invocation] - node = KernelNode( - kernel=source, - input_streams=new_input_streams, - pipeline_database=self.pipeline_database, - pipeline_path_prefix=self.pipeline_store_path_prefix, - label=invocation.label, - kernel_type="source", + + def flush(self) -> None: + """Flush all databases.""" + self._pipeline_database.flush() + if self._function_database is not None: + self._function_database.flush() + + # ------------------------------------------------------------------ + # Serialization + # ------------------------------------------------------------------ + + def save(self, path: str) -> None: + """Serialize the compiled pipeline to a JSON file. + + Args: + path: File path to write the JSON output to. + + Raises: + ValueError: If the pipeline has not been compiled. + """ + if not self._compiled: + raise ValueError( + "Pipeline is not compiled. Call compile() or use " + "auto_compile=True before saving." ) + + from orcapod.core.nodes import OperatorNode + from orcapod.pipeline.serialization import ( + PIPELINE_FORMAT_VERSION, + serialize_schema, + ) + + # -- Pipeline metadata -- + pipeline_meta: dict[str, Any] = { + "name": list(self._name), + "databases": { + "pipeline_database": self._pipeline_database.to_config(), + "function_database": ( + self._function_database.to_config() + if self._function_database is not None + else None + ), + }, + } + + # -- Build node descriptors -- + nodes: dict[str, dict[str, Any]] = {} + for content_hash_str, node in self._persistent_node_map.items(): + tag_schema, packet_schema = node.output_schema() + type_converter = node.data_context.type_converter + descriptor: dict[str, Any] = { + "node_type": node.node_type, + "label": node.label, + "content_hash": node.content_hash().to_string(), + "pipeline_hash": node.pipeline_hash().to_string(), + "data_context_key": node.data_context_key, + "output_schema": { + "tag": serialize_schema(tag_schema, type_converter), + "packet": serialize_schema(packet_schema, type_converter), + }, + } + + if isinstance(node, SourceNode): + descriptor.update(self._build_source_descriptor(node)) + elif isinstance(node, FunctionNode): + descriptor.update(self._build_function_descriptor(node)) + elif isinstance(node, OperatorNode): + descriptor.update(self._build_operator_descriptor(node)) + + nodes[content_hash_str] = descriptor + + # -- Edges -- + edges = [list(edge) for edge in self._graph_edges] + + # -- Assemble top-level structure -- + output = { + "orcapod_pipeline_version": PIPELINE_FORMAT_VERSION, + "pipeline": pipeline_meta, + "nodes": nodes, + "edges": edges, + } + + with open(path, "w") as f: + json.dump(output, f, indent=2) + + # Reconstructable source types: file-backed sources that can be + # rebuilt from config alone. + _RECONSTRUCTABLE_SOURCE_TYPES = frozenset({"csv", "delta_table", "cached"}) + + def _build_source_descriptor(self, node: SourceNode) -> dict[str, Any]: + """Build source-specific descriptor fields for a SourceNode. + + Args: + node: The SourceNode to describe. + + Returns: + Dict with source-specific fields. + """ + stream = node.stream + + # Determine if stream implements SourceProtocol and build descriptor accordingly + # TODO: revisit this logic + if isinstance(stream, cp.SourceProtocol): + config = stream.to_config() + stream_type = config.get("source_type", "stream") + source_config = config + reconstructable = stream_type in self._RECONSTRUCTABLE_SOURCE_TYPES else: - node = KernelNode( - kernel=invocation.kernel, - input_streams=new_input_streams, - pipeline_database=self.pipeline_database, - pipeline_path_prefix=self.pipeline_store_path_prefix, - label=invocation.label, - kernel_type="operator", + stream_type = "stream" + source_config = None + reconstructable = False + + source_id = getattr(stream, "source_id", None) + + return { + "stream_type": stream_type, + "source_id": source_id, + "reconstructable": reconstructable, + "source_config": source_config, + } + + def _build_function_descriptor(self, node: "FunctionNode") -> dict[str, Any]: + """Build function-specific descriptor fields for a FunctionNode. + + Args: + node: The FunctionNode to describe. + + Returns: + Dict with function-specific fields. + """ + return { + "function_pod": node._function_pod.to_config(), + "pipeline_path": list(node.pipeline_path), + "result_record_path": list(node._cached_function_pod.record_path), + } + + def _build_operator_descriptor(self, node: OperatorNode) -> dict[str, Any]: + """Build operator-specific descriptor fields for a OperatorNode. + + Args: + node: The OperatorNode to describe. + + Returns: + Dict with operator-specific fields. + """ + return { + "operator": node._operator.to_config(), + "cache_mode": node._cache_mode.name, + "pipeline_path": list(node.pipeline_path), + } + + @classmethod + def load(cls, path: str | Path, mode: str = "full") -> "Pipeline": + """Deserialize a pipeline from a JSON file. + + Reconstructs the pipeline graph from the serialized descriptor, + rebuilding nodes in topological order. The *mode* parameter + controls how aggressively live objects are reconstructed: + + - ``"full"``: attempt to reconstruct live sources, function pods, + and operators so the pipeline can be re-run. Falls back to + read-only per-node when reconstruction fails. + - ``"read_only"``: load metadata only; no live sources or + function pods are reconstructed. + + Args: + path: Path to the JSON file produced by :meth:`save`. + mode: ``"full"`` (default) or ``"read_only"``. + + Returns: + A compiled ``Pipeline`` instance. + + Raises: + ValueError: If the file's format version is unsupported. + """ + + from orcapod.pipeline.serialization import ( + SUPPORTED_FORMAT_VERSIONS, + LoadStatus, + resolve_database_from_config, + resolve_operator_from_config, + resolve_source_from_config, + ) + + path = Path(path) + with open(path) as f: + data = json.load(f) + + # 1. Validate version + version = data.get("orcapod_pipeline_version", "") + if version not in SUPPORTED_FORMAT_VERSIONS: + raise ValueError( + f"Unsupported pipeline format version {version!r}. " + f"Supported versions: {sorted(SUPPORTED_FORMAT_VERSIONS)}" ) - return node + + # 2. Reconstruct databases + pipeline_meta = data["pipeline"] + db_configs = pipeline_meta["databases"] + + pipeline_db = resolve_database_from_config(db_configs["pipeline_database"]) + function_db = ( + resolve_database_from_config(db_configs["function_database"]) + if db_configs.get("function_database") is not None + else None + ) + + # 3. Build edge graph and derive topological order + nodes_data = data["nodes"] + edges = data["edges"] + + edge_graph: nx.DiGraph = nx.DiGraph() + for upstream_hash, downstream_hash in edges: + edge_graph.add_edge(upstream_hash, downstream_hash) + # Add isolated nodes (nodes with no edges) + for node_hash in nodes_data: + if node_hash not in edge_graph: + edge_graph.add_node(node_hash) + + topo_order = list(nx.topological_sort(edge_graph)) + + # 4. Walk nodes in topological order, reconstruct each + reconstructed: dict[str, SourceNode | FunctionNode | OperatorNode] = {} + + # Build reverse edge map: downstream -> list of upstream hashes + upstream_map: dict[str, list[str]] = {} + for up_hash, down_hash in edges: + upstream_map.setdefault(down_hash, []).append(up_hash) + + for node_hash in topo_order: + descriptor = nodes_data.get(node_hash) + if descriptor is None: + continue + + node_type = descriptor.get("node_type") + + if node_type == "source": + node = cls._load_source_node( + descriptor, mode, resolve_source_from_config + ) + reconstructed[node_hash] = node + + elif node_type == "function": + # Determine upstream node + up_hashes = upstream_map.get(node_hash, []) + upstream_node = reconstructed.get(up_hashes[0]) if up_hashes else None + + # Check if upstream is usable for full mode + upstream_usable = ( + upstream_node is not None + and hasattr(upstream_node, "load_status") + and upstream_node.load_status == LoadStatus.FULL + ) + + # Build databases dict + result_db = function_db if function_db is not None else pipeline_db + dbs = {"pipeline": pipeline_db, "result": result_db} + + node = cls._load_function_node( + descriptor, mode, upstream_node, upstream_usable, dbs + ) + reconstructed[node_hash] = node + + elif node_type == "operator": + up_hashes = upstream_map.get(node_hash, []) + upstream_nodes = tuple( + reconstructed[h] for h in up_hashes if h in reconstructed + ) + + # Check if all upstreams are usable + all_upstreams_usable = ( + all( + hasattr(n, "load_status") and n.load_status == LoadStatus.FULL + for n in upstream_nodes + ) + if upstream_nodes + else False + ) + + dbs = {"pipeline": pipeline_db} + + node = cls._load_operator_node( + descriptor, + mode, + upstream_nodes, + all_upstreams_usable, + dbs, + resolve_operator_from_config, + ) + reconstructed[node_hash] = node + + # 5. Build Pipeline instance + name = tuple(pipeline_meta["name"]) + pipeline = cls( + name=name, + pipeline_database=pipeline_db, + function_database=function_db, + auto_compile=False, + ) + + # Populate persistent node map + pipeline._persistent_node_map = dict(reconstructed) + + # Populate _nodes (label -> node) for all labeled nodes. + # Unlike compile() which excludes source nodes from _nodes, + # loaded pipelines include them so users can inspect load_status + # and metadata for all nodes via attribute access. + pipeline._nodes = {} + for node_hash, node in reconstructed.items(): + label = node.label + if label: + pipeline._nodes[label] = node + + # Build node graph + pipeline._node_graph = nx.DiGraph() + for up_hash, down_hash in edges: + up_node = reconstructed.get(up_hash) + down_node = reconstructed.get(down_hash) + if up_node is not None and down_node is not None: + pipeline._node_graph.add_edge(up_node, down_node) + for node in reconstructed.values(): + if node not in pipeline._node_graph: + pipeline._node_graph.add_node(node) + + # Restore graph edges as content_hash string pairs + pipeline._graph_edges = [(up, down) for up, down in edges] + + # Rebuild _hash_graph + pipeline._hash_graph = nx.DiGraph() + for up_hash, down_hash in edges: + pipeline._hash_graph.add_edge(up_hash, down_hash) + for node_hash, node in reconstructed.items(): + if node_hash not in pipeline._hash_graph: + pipeline._hash_graph.add_node(node_hash) + attrs = pipeline._hash_graph.nodes[node_hash] + attrs["node_type"] = node.node_type + if node.label: + attrs["label"] = node.label + + pipeline._compiled = True + + return pipeline + + @staticmethod + def _load_source_node( + descriptor: dict[str, Any], + mode: str, + resolve_source_from_config: Any, + ) -> SourceNode: + """Reconstruct a SourceNode from a descriptor. + + Args: + descriptor: The serialized node descriptor. + mode: Load mode (``"full"`` or ``"read_only"``). + resolve_source_from_config: Callable to reconstruct a source. + + Returns: + A ``SourceNode`` instance. + """ + + reconstructable = descriptor.get("reconstructable", False) + source_config = descriptor.get("source_config") + + stream = None + if reconstructable and mode != "read_only" and source_config is not None: + try: + stream = resolve_source_from_config(source_config) + except Exception: + logger.warning( + "Failed to reconstruct source %r, falling back to read-only.", + descriptor.get("label"), + ) + stream = None + + return SourceNode.from_descriptor(descriptor, stream=stream, databases={}) + + @staticmethod + def _load_function_node( + descriptor: dict[str, Any], + mode: str, + upstream_node: Any | None, + upstream_usable: bool, + databases: dict[str, Any], + ) -> FunctionNode: + """Reconstruct a FunctionNode from a descriptor. + + Args: + descriptor: The serialized node descriptor. + mode: Load mode. + upstream_node: The reconstructed upstream node, or ``None``. + upstream_usable: Whether the upstream is in FULL mode. + databases: Database role mapping. + + Returns: + A ``FunctionNode`` instance. + """ + from orcapod.core.function_pod import FunctionPod + + if mode == "full": + if not upstream_usable: + logger.warning( + "Upstream for function node %r is not usable, " + "falling back to read-only.", + descriptor.get("label"), + ) + else: + try: + pod = FunctionPod.from_config(descriptor["function_pod"]) + return FunctionNode.from_descriptor( + descriptor, + function_pod=pod, + input_stream=upstream_node, + databases=databases, + ) + except Exception: + logger.warning( + "Failed to reconstruct function node %r, " + "falling back to read-only.", + descriptor.get("label"), + ) + + return FunctionNode.from_descriptor( + descriptor, + function_pod=None, + input_stream=None, + databases=databases, + ) + + @staticmethod + def _load_operator_node( + descriptor: dict[str, Any], + mode: str, + upstream_nodes: tuple, + all_upstreams_usable: bool, + databases: dict[str, Any], + resolve_operator_from_config: Any, + ) -> "OperatorNode": + """Reconstruct a OperatorNode from a descriptor. + + Args: + descriptor: The serialized node descriptor. + mode: Load mode. + upstream_nodes: Tuple of reconstructed upstream nodes. + all_upstreams_usable: Whether all upstreams are in FULL mode. + databases: Database role mapping. + resolve_operator_from_config: Callable to reconstruct an operator. + + Returns: + A ``OperatorNode`` instance. + """ + from orcapod.core.nodes import OperatorNode + + if mode != "read_only": + if not all_upstreams_usable: + logger.warning( + "Upstream(s) for operator node %r are not usable, " + "falling back to read-only.", + descriptor.get("label"), + ) + else: + try: + op = resolve_operator_from_config(descriptor["operator"]) + return OperatorNode.from_descriptor( + descriptor, + operator=op, + input_streams=upstream_nodes, + databases=databases, + ) + except Exception: + logger.warning( + "Failed to reconstruct operator node %r, " + "falling back to read-only.", + descriptor.get("label"), + ) + + return OperatorNode.from_descriptor( + descriptor, + operator=None, + input_streams=(), + databases=databases, + ) + + # ------------------------------------------------------------------ + # Node access by label + # ------------------------------------------------------------------ def __getattr__(self, item: str) -> Any: - """Allow direct access to pipeline attributes.""" - if item in self._nodes: - return self._nodes[item] + # Use __dict__ to avoid recursion during __init__ + nodes = self.__dict__.get("_nodes", {}) + if item in nodes: + return nodes[item] raise AttributeError(f"Pipeline has no attribute '{item}'") def __dir__(self) -> list[str]: - """Return a list of attributes and methods of the pipeline.""" return list(super().__dir__()) + list(self._nodes.keys()) - def rename(self, old_name: str, new_name: str) -> None: - """ - Rename a node in the pipeline. - This will update the label and the internal mapping. - """ - if old_name not in self._nodes: - raise KeyError(f"Node '{old_name}' does not exist in the pipeline.") - if new_name in self._nodes: - raise KeyError(f"Node '{new_name}' already exists in the pipeline.") - node = self._nodes[old_name] - del self._nodes[old_name] - node.label = new_name - self._nodes[new_name] = node - logger.info(f"Node '{old_name}' renamed to '{new_name}'") + +# =========================================================================== +# Graph Rendering Utilities +# =========================================================================== class GraphRenderer: @@ -340,8 +993,8 @@ class GraphRenderer: "dpi": 150, # HTML Label defaults "main_font_size": 14, # Main label font size - "type_font_size": 11, # Pod type font size (small) - "type_style": "normal", # Pod type text style + "type_font_size": 11, # PodProtocol type font size (small) + "type_style": "normal", # PodProtocol type text style } DEFAULT_STYLE_RULES = { @@ -586,8 +1239,8 @@ def render_graph( dot.render(name, format=format_type, cleanup=True) print(f"Graph saved to {output_path}") - import matplotlib.pyplot as plt import matplotlib.image as mpimg + import matplotlib.pyplot as plt if show: with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: @@ -596,7 +1249,6 @@ def render_graph( plt.figure(figsize=figsize, dpi=dpi) plt.imshow(img) plt.axis("off") - # plt.title("Example Graph") plt.tight_layout() plt.show() os.unlink(tmp.name) @@ -694,199 +1346,3 @@ def create_custom_rules( "type_font_color": kernel_type_fcolor, }, } - - -# import networkx as nx -# # import graphviz -# import matplotlib.pyplot as plt -# import matplotlib.image as mpimg -# import tempfile -# import os - - -# class GraphRenderer: -# """Simple renderer for NetworkX graphs using Graphviz DOT format""" - -# def __init__(self): -# """Initialize the renderer""" -# pass - -# def _sanitize_node_id(self, node_id: Any) -> str: -# """Convert node_id to a valid DOT identifier using hash""" -# return f"node_{hash(node_id)}" - -# def _get_node_label( -# self, node_id: Any, label_lut: dict[Any, str] | None = None -# ) -> str: -# """Get label for a node""" -# if label_lut and node_id in label_lut: -# return label_lut[node_id] -# return str(node_id) - -# def generate_dot( -# self, -# graph: "nx.DiGraph", -# label_lut: dict[Any, str] | None = None, -# rankdir: str = "TB", -# node_shape: str = "box", -# node_style: str = "filled", -# node_color: str = "lightblue", -# edge_color: str = "black", -# dpi: int = 150, -# ) -> str: -# """ -# Generate DOT syntax from NetworkX graph - -# Args: -# graph: NetworkX DiGraph to render -# label_lut: Optional dictionary mapping node_id -> display_label -# rankdir: Graph direction ('TB', 'BT', 'LR', 'RL') -# node_shape: Shape for all nodes -# node_style: Style for all nodes -# node_color: Fill color for all nodes -# edge_color: Color for all edges -# dpi: Resolution for rendered image (default 150) - -# Returns: -# DOT format string -# """ -# try: -# import graphviz -# except ImportError as e: -# raise ImportError( -# "Graphviz is not installed. Please install graphviz to render graph of the pipeline." -# ) from e - -# dot = graphviz.Digraph(comment="NetworkX Graph") - -# # Set graph attributes -# dot.attr(rankdir=rankdir, dpi=str(dpi)) -# dot.attr("node", shape=node_shape, style=node_style, fillcolor=node_color) -# dot.attr("edge", color=edge_color) - -# # Add nodes -# for node_id in graph.nodes(): -# sanitized_id = self._sanitize_node_id(node_id) -# label = self._get_node_label(node_id, label_lut) -# dot.node(sanitized_id, label=label) - -# # Add edges -# for source, target in graph.edges(): -# source_id = self._sanitize_node_id(source) -# target_id = self._sanitize_node_id(target) -# dot.edge(source_id, target_id) - -# return dot.source - -# def render_graph( -# self, -# graph: nx.DiGraph, -# label_lut: dict[Any, str] | None = None, -# show: bool = True, -# output_path: str | None = None, -# raw_output: bool = False, -# rankdir: str = "TB", -# figsize: tuple = (6, 4), -# dpi: int = 150, -# **style_kwargs, -# ) -> str | None: -# """ -# Render NetworkX graph using Graphviz - -# Args: -# graph: NetworkX DiGraph to render -# label_lut: Optional dictionary mapping node_id -> display_label -# show: Display the graph using matplotlib -# output_path: Save graph to file (e.g., 'graph.png', 'graph.pdf') -# raw_output: Return DOT syntax instead of rendering -# rankdir: Graph direction ('TB', 'BT', 'LR', 'RL') -# figsize: Figure size for matplotlib display -# dpi: Resolution for rendered image (default 150) -# **style_kwargs: Additional styling (node_color, edge_color, node_shape, etc.) - -# Returns: -# DOT syntax if raw_output=True, None otherwise -# """ -# try: -# import graphviz -# except ImportError as e: -# raise ImportError( -# "Graphviz is not installed. Please install graphviz to render graph of the pipeline." -# ) from e - -# if raw_output: -# return self.generate_dot(graph, label_lut, rankdir, dpi=dpi, **style_kwargs) - -# # Create Graphviz object -# dot = graphviz.Digraph(comment="NetworkX Graph") -# dot.attr(rankdir=rankdir, dpi=str(dpi)) - -# # Apply styling -# node_shape = style_kwargs.get("node_shape", "box") -# node_style = style_kwargs.get("node_style", "filled") -# node_color = style_kwargs.get("node_color", "lightblue") -# edge_color = style_kwargs.get("edge_color", "black") - -# dot.attr("node", shape=node_shape, style=node_style, fillcolor=node_color) -# dot.attr("edge", color=edge_color) - -# # Add nodes with labels -# for node_id in graph.nodes(): -# sanitized_id = self._sanitize_node_id(node_id) -# label = self._get_node_label(node_id, label_lut) -# dot.node(sanitized_id, label=label) - -# # Add edges -# for source, target in graph.edges(): -# source_id = self._sanitize_node_id(source) -# target_id = self._sanitize_node_id(target) -# dot.edge(source_id, target_id) - -# # Handle output -# if output_path: -# # Save to file -# name, ext = os.path.splitext(output_path) -# format_type = ext[1:] if ext else "png" -# dot.render(name, format=format_type, cleanup=True) -# print(f"Graph saved to {output_path}") - -# if show: -# # Display with matplotlib -# with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: -# dot.render(tmp.name[:-4], format="png", cleanup=True) - -# import matplotlib.pyplot as plt -# import matplotlib.image as mpimg - -# # Display with matplotlib -# img = mpimg.imread(tmp.name) -# plt.figure(figsize=figsize) -# plt.imshow(img) -# plt.axis("off") -# plt.title("Graph Visualization") -# plt.tight_layout() -# plt.show() - -# # Clean up -# os.unlink(tmp.name) - -# return None - - -# # Convenience function for quick rendering -# def render_graph( -# graph: nx.DiGraph, label_lut: dict[Any, str] | None = None, **kwargs -# ) -> str | None: -# """ -# Convenience function to quickly render a NetworkX graph - -# Args: -# graph: NetworkX DiGraph to render -# label_lut: Optional dictionary mapping node_id -> display_label -# **kwargs: All other arguments passed to GraphRenderer.render_graph() - -# Returns: -# DOT syntax if raw_output=True, None otherwise -# """ -# renderer = GraphRenderer() -# return renderer.render_graph(graph, label_lut, **kwargs) diff --git a/src/orcapod/pipeline/logging_capture.py b/src/orcapod/pipeline/logging_capture.py new file mode 100644 index 00000000..e2d8cee2 --- /dev/null +++ b/src/orcapod/pipeline/logging_capture.py @@ -0,0 +1,235 @@ +"""Capture infrastructure for observability logging. + +Provides context-variable-local capture of stdout, stderr, and Python logging +for use in FunctionNode execution. Thread-safe and asyncio-task-safe via +``contextvars.ContextVar`` — captures from concurrent packets never intermingle. + +CapturedLogs travel as part of the return type through the call chain +(``direct_call`` → ``call`` → ``process_packet`` → FunctionNode) so there +is no ContextVar side-channel for logs. Each executor's ``execute_callable`` +returns ``(raw_result, CapturedLogs)``, and ``direct_call`` returns +``(output_packet, CapturedLogs)`` — catching user-function exceptions +internally rather than re-raising. + +Typical usage +------------- +Call ``install_capture_streams()`` once when a logging Observer is created. +The executor or ``direct_call`` wraps function execution in +``LocalCaptureContext`` and returns CapturedLogs alongside the result:: + + result, captured = packet_function.call(packet) + pkt_logger.record(captured) +""" + +from __future__ import annotations + +import contextvars +import io +import logging +import sys +from dataclasses import dataclass +from typing import Any + + +# --------------------------------------------------------------------------- +# CapturedLogs +# --------------------------------------------------------------------------- + + +@dataclass +class CapturedLogs: + """I/O captured from a single packet function execution.""" + + stdout: str = "" + stderr: str = "" + python_logs: str = "" + traceback: str | None = None + success: bool = True + + +# --------------------------------------------------------------------------- +# Context variables +# --------------------------------------------------------------------------- +# Each asyncio task and thread gets its own copy of these variables, so +# captures from concurrent packets never intermingle. + +_stdout_capture: contextvars.ContextVar[io.StringIO | None] = contextvars.ContextVar( + "_stdout_capture", default=None +) +_stderr_capture: contextvars.ContextVar[io.StringIO | None] = contextvars.ContextVar( + "_stderr_capture", default=None +) +_log_capture: contextvars.ContextVar[list[str] | None] = contextvars.ContextVar( + "_log_capture", default=None +) + + +# --------------------------------------------------------------------------- +# ContextLocalTeeStream +# --------------------------------------------------------------------------- + + +class ContextLocalTeeStream: + """A stream that writes to the original *and* a per-context capture buffer. + + All writes go to *original* (terminal output is preserved) and also to a + ``StringIO`` buffer active for the current asyncio task / thread (selected + via ``capture_var``). Concurrent tasks each have their own buffer and do + not interfere with each other. + """ + + def __init__( + self, + original: Any, + capture_var: contextvars.ContextVar[io.StringIO | None], + ) -> None: + self._original = original + self._capture_var = capture_var + + def write(self, s: str) -> int: + buf = self._capture_var.get() + if buf is not None: + buf.write(s) + return self._original.write(s) + + def flush(self) -> None: + buf = self._capture_var.get() + if buf is not None: + buf.flush() + self._original.flush() + + def __getattr__(self, name: str) -> Any: + return getattr(self._original, name) + + +# --------------------------------------------------------------------------- +# ContextVarLoggingHandler +# --------------------------------------------------------------------------- + + +class ContextVarLoggingHandler(logging.Handler): + """A logging handler that captures records into a per-context buffer. + + When a capture buffer is active for the current context (asyncio task or + thread), log records are formatted and appended to it. When no buffer is + active the record is silently discarded (not duplicated to other handlers). + """ + + def emit(self, record: logging.LogRecord) -> None: + buf = _log_capture.get() + if buf is not None: + buf.append(self.format(record)) + + +# --------------------------------------------------------------------------- +# Global installation (idempotent) +# --------------------------------------------------------------------------- + +_installed = False +_logging_handler: ContextVarLoggingHandler | None = None + + +def install_capture_streams() -> None: + """Install tee streams and the logging handler globally. + + Idempotent — safe to call multiple times. Should be called once when a + concrete logging Observer is instantiated. + + After installation: + + * ``sys.stdout`` / ``sys.stderr`` tee writes to per-context buffers while + also forwarding to the original stream (terminal output preserved). + * The root logger gains a ``ContextVarLoggingHandler`` that captures + records to per-context buffers (covering Python ``logging`` calls). + + .. note:: + Subprocess and C-extension output bypasses Python's stream objects and + goes directly to file descriptors 1/2. For local execution these are + *not* captured (but are still visible in the terminal). Ray remote + execution uses fd-level capture via + ``RayExecutor._make_capture_wrapper``. + + The stream check runs on every call so that if something (e.g. a test + harness) replaces ``sys.stdout``/``sys.stderr`` between calls we + re-wrap the new stream. The logging handler is only added once. + """ + global _installed, _logging_handler + + # Always re-check in case sys.stdout/stderr was replaced (e.g. by pytest). + if not isinstance(sys.stdout, ContextLocalTeeStream): + sys.stdout = ContextLocalTeeStream(sys.stdout, _stdout_capture) + if not isinstance(sys.stderr, ContextLocalTeeStream): + sys.stderr = ContextLocalTeeStream(sys.stderr, _stderr_capture) + + if _installed: + return + + _logging_handler = ContextVarLoggingHandler() + _logging_handler.setFormatter( + logging.Formatter("%(levelname)s:%(name)s:%(message)s") + ) + logging.getLogger().addHandler(_logging_handler) + + _installed = True + + +# --------------------------------------------------------------------------- +# LocalCaptureContext +# --------------------------------------------------------------------------- + + +class LocalCaptureContext: + """Context manager that activates per-context capture for one packet. + + Requires ``install_capture_streams()`` to have been called; without it the + ContextVars are set but nothing tees into them, so captured strings will be + empty (acceptable when no logging Observer is configured). + + Example:: + + ctx = LocalCaptureContext() + try: + with ctx: + result = call_something() + except Exception: + captured = ctx.get_captured(success=False, tb=traceback.format_exc()) + else: + captured = ctx.get_captured(success=True) + """ + + def __init__(self) -> None: + self._stdout_buf = io.StringIO() + self._stderr_buf = io.StringIO() + self._log_buf: list[str] = [] + self._tokens: list[contextvars.Token] = [] + + def __enter__(self) -> "LocalCaptureContext": + self._tokens = [ + _stdout_capture.set(self._stdout_buf), + _stderr_capture.set(self._stderr_buf), + _log_capture.set(self._log_buf), + ] + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> bool: + for token, var in zip( + self._tokens, + [_stdout_capture, _stderr_capture, _log_capture], + ): + var.reset(token) + return False # do not suppress exceptions + + def get_captured( + self, + success: bool, + tb: str | None = None, + ) -> CapturedLogs: + """Return a :class:`CapturedLogs` from what was captured in this context.""" + return CapturedLogs( + stdout=self._stdout_buf.getvalue(), + stderr=self._stderr_buf.getvalue(), + python_logs="\n".join(self._log_buf) if self._log_buf else "", + traceback=tb, + success=success, + ) + diff --git a/src/orcapod/pipeline/logging_observer.py b/src/orcapod/pipeline/logging_observer.py new file mode 100644 index 00000000..6db61312 --- /dev/null +++ b/src/orcapod/pipeline/logging_observer.py @@ -0,0 +1,272 @@ +"""Concrete logging observer for orcapod pipelines. + +Provides :class:`LoggingObserver`, a drop-in observer that captures stdout, +stderr, Python logging, and tracebacks from every packet execution and writes +structured log rows to any :class:`~orcapod.protocols.database_protocols.ArrowDatabaseProtocol` +(in-memory, Delta Lake, etc.). + +Typical usage:: + + from orcapod.pipeline.logging_observer import LoggingObserver + from orcapod.pipeline import SyncPipelineOrchestrator + from orcapod.databases import InMemoryArrowDatabase + + obs = LoggingObserver(log_database=InMemoryArrowDatabase()) + pipeline.run(orchestrator=SyncPipelineOrchestrator(observer=obs)) + + # Inspect captured logs + logs = obs.get_logs() # pyarrow.Table + logs.to_pandas() # pandas DataFrame + +Log schema (fixed columns) +-------------------------- +.. list-table:: + :header-rows: 1 + + * - Column + - Type + - Description + * - ``log_id`` + - ``large_utf8`` + - UUID unique to this log entry + * - ``run_id`` + - ``large_utf8`` + - UUID of the pipeline run (from ``on_run_start``) + * - ``node_label`` + - ``large_utf8`` + - Label of the function node + * - ``stdout`` + - ``large_utf8`` + - Captured standard output + * - ``stderr`` + - ``large_utf8`` + - Captured standard error + * - ``python_logs`` + - ``large_utf8`` + - Python ``logging`` output captured during execution + * - ``traceback`` + - ``large_utf8`` + - Full traceback on failure; ``None`` on success + * - ``success`` + - ``bool_`` + - ``True`` if the packet function returned normally + * - ``timestamp`` + - ``large_utf8`` + - ISO-8601 UTC timestamp when ``record()`` was called + +In addition, each tag key from the packet's tag becomes a separate +``large_utf8`` column (queryable, not JSON-encoded). + +Log storage +----------- +Logs are stored at a pipeline-path-mirrored location: +``pipeline_path[:1] + ("logs",) + pipeline_path[1:]``. +Each function node gets its own log table. Use +``get_logs(pipeline_path=node.pipeline_path)`` to retrieve +node-specific logs. +""" + +from __future__ import annotations + +import logging +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any + +from uuid_utils import uuid7 + +from orcapod.pipeline.logging_capture import CapturedLogs, install_capture_streams + +if TYPE_CHECKING: + import pyarrow as pa + + from orcapod.protocols.database_protocols import ArrowDatabaseProtocol + +logger = logging.getLogger(__name__) + +# Default path (table name) within the database where log rows are stored. +DEFAULT_LOG_PATH: tuple[str, ...] = ("execution_logs",) + + +class PacketLogger: + """Context-bound logger created by :class:`LoggingObserver` per packet. + + Holds all context needed to write a structured log row + (run_id, node_label, tag data) so the caller only needs to pass the + :class:`~orcapod.pipeline.logging_capture.CapturedLogs` payload. + + Tag data is stored as individual queryable columns (not JSON) alongside + the fixed log columns. + + This class is not intended to be instantiated directly — use + :meth:`LoggingObserver.create_packet_logger` instead. + """ + + def __init__( + self, + db: "ArrowDatabaseProtocol", + log_path: tuple[str, ...], + run_id: str, + node_label: str, + tag_data: dict[str, Any], + ) -> None: + self._db = db + self._log_path = log_path + self._run_id = run_id + self._node_label = node_label + self._tag_data = tag_data + + def record(self, captured: CapturedLogs) -> None: + """Write one log row to the database.""" + import pyarrow as pa + + log_id = str(uuid7()) + timestamp = datetime.now(timezone.utc).isoformat() + + # Fixed columns + columns: dict[str, pa.Array] = { + "log_id": pa.array([log_id], type=pa.large_utf8()), + "run_id": pa.array([self._run_id], type=pa.large_utf8()), + "node_label": pa.array([self._node_label], type=pa.large_utf8()), + "stdout": pa.array([captured.stdout], type=pa.large_utf8()), + "stderr": pa.array([captured.stderr], type=pa.large_utf8()), + "python_logs": pa.array([captured.python_logs], type=pa.large_utf8()), + "traceback": pa.array([captured.traceback], type=pa.large_utf8()), + "success": pa.array([captured.success], type=pa.bool_()), + "timestamp": pa.array([timestamp], type=pa.large_utf8()), + } + + # Dynamic tag columns — each tag key becomes its own column + for key, value in self._tag_data.items(): + columns[key] = pa.array([str(value)], type=pa.large_utf8()) + + row = pa.table(columns) + try: + self._db.add_record(self._log_path, log_id, row, flush=True) + except Exception: + logger.exception( + "LoggingObserver: failed to write log row for node=%s", + self._node_label, + ) + + +class LoggingObserver: + """Concrete observer that writes packet execution logs to a database. + + Instantiate once, outside the pipeline, and pass to the orchestrator:: + + obs = LoggingObserver(log_database=InMemoryArrowDatabase()) + orch = SyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + # After the run, read back captured logs: + logs_table = obs.get_logs() # pyarrow.Table + + For async / Ray pipelines use :class:`~orcapod.pipeline.AsyncPipelineOrchestrator` + with the same observer:: + + orch = AsyncPipelineOrchestrator(observer=obs) + pipeline.run(orchestrator=orch) + + Args: + log_database: Any :class:`~orcapod.protocols.database_protocols.ArrowDatabaseProtocol` + instance — :class:`~orcapod.databases.InMemoryArrowDatabase`, + a Delta Lake database, etc. + log_path: Tuple of strings identifying the table within the database. + Defaults to ``("execution_logs",)``. + + Note: + Construction calls :func:`~orcapod.pipeline.logging_capture.install_capture_streams` + so that stdout/stderr tee-capture is active from the moment the observer + is created. + """ + + def __init__( + self, + log_database: "ArrowDatabaseProtocol", + log_path: tuple[str, ...] | None = None, + ) -> None: + self._db = log_database + self._log_path = log_path or DEFAULT_LOG_PATH + self._current_run_id: str = "" + # Activate tee-capture as soon as the observer is created. + install_capture_streams() + + # -- lifecycle hooks -- + + def on_run_start(self, run_id: str) -> None: + self._current_run_id = run_id + + def on_run_end(self, run_id: str) -> None: + pass + + def on_node_start(self, node: Any) -> None: + pass + + def on_node_end(self, node: Any) -> None: + pass + + def on_packet_start(self, node: Any, tag: Any, packet: Any) -> None: + pass + + def on_packet_end( + self, + node: Any, + tag: Any, + input_packet: Any, + output_packet: Any, + cached: bool, + ) -> None: + pass + + def on_packet_crash(self, node: Any, tag: Any, packet: Any, error: Exception) -> None: + pass + + def create_packet_logger( + self, + node: Any, + tag: Any, + packet: Any, + pipeline_path: tuple[str, ...] = (), + ) -> PacketLogger: + """Return a :class:`PacketLogger` bound to *node* + *tag* context. + + Log rows are stored at a pipeline-path-mirrored location: + ``pipeline_path[:1] + ("logs",) + pipeline_path[1:]``. This gives + each function node its own log table in the database. + """ + node_label = getattr(node, "label", None) or getattr(node, "node_type", "unknown") + tag_data = dict(tag) + + # Compute mirrored log path + if pipeline_path: + log_path = pipeline_path[:1] + ("logs",) + pipeline_path[1:] + else: + log_path = self._log_path + + return PacketLogger( + db=self._db, + log_path=log_path, + run_id=self._current_run_id, + node_label=node_label, + tag_data=tag_data, + ) + + # -- convenience -- + + def get_logs( + self, pipeline_path: tuple[str, ...] | None = None + ) -> "pa.Table | None": + """Read log rows from the database as a :class:`pyarrow.Table`. + + Args: + pipeline_path: If provided, reads logs for a specific node + (mirrored path). If ``None``, reads from the default + log path. + + Returns ``None`` if no logs have been written yet. + """ + if pipeline_path is not None: + log_path = pipeline_path[:1] + ("logs",) + pipeline_path[1:] + else: + log_path = self._log_path + return self._db.get_all_records(log_path) diff --git a/src/orcapod/pipeline/nodes.py b/src/orcapod/pipeline/nodes.py deleted file mode 100644 index af639714..00000000 --- a/src/orcapod/pipeline/nodes.py +++ /dev/null @@ -1,505 +0,0 @@ -from abc import abstractmethod -from orcapod.core.kernels import KernelStream, WrappedKernel -from orcapod.core.sources.base import InvocationBase -from orcapod.core.pods import CachedPod -from orcapod.protocols import core_protocols as cp, database_protocols as dbp -from orcapod.types import PythonSchema -from orcapod.utils.lazy_module import LazyModule -from typing import TYPE_CHECKING, Any -from orcapod.core.system_constants import constants -from orcapod.utils import arrow_utils -from collections.abc import Collection -from orcapod.core.streams import PodNodeStream - -if TYPE_CHECKING: - import pyarrow as pa - import polars as pl - import pandas as pd -else: - pa = LazyModule("pyarrow") - pl = LazyModule("polars") - pd = LazyModule("pandas") - - -class NodeBase( - InvocationBase, -): - """ - Mixin class for pipeline nodes - """ - - def __init__( - self, - input_streams: Collection[cp.Stream], - pipeline_database: dbp.ArrowDatabase, - pipeline_path_prefix: tuple[str, ...] = (), - kernel_type: str = "operator", - **kwargs, - ): - super().__init__(**kwargs) - self.kernel_type = kernel_type - self._cached_stream: KernelStream | None = None - self._input_streams = tuple(input_streams) - self._pipeline_path_prefix = pipeline_path_prefix - # compute invocation hash - note that empty () is passed into identity_structure to signify - # identity structure of invocation with no input streams - self.pipeline_node_hash = self.data_context.object_hasher.hash_object( - self.identity_structure(()) - ).to_string() - tag_types, packet_types = self.types(include_system_tags=True) - - self.tag_schema_hash = self.data_context.object_hasher.hash_object( - tag_types - ).to_string() - - self.packet_schema_hash = self.data_context.object_hasher.hash_object( - packet_types - ).to_string() - - self.pipeline_database = pipeline_database - - @property - def id(self) -> str: - return self.content_hash().to_string() - - @property - def upstreams(self) -> tuple[cp.Stream, ...]: - return self._input_streams - - def track_invocation(self, *streams: cp.Stream, label: str | None = None) -> None: - # Node invocation should not be tracked - return None - - @property - def contained_kernel(self) -> cp.Kernel: - raise NotImplementedError( - "This property should be implemented by subclasses to return the contained kernel." - ) - - @property - def reference(self) -> tuple[str, ...]: - return self.contained_kernel.reference - - @property - @abstractmethod - def pipeline_path(self) -> tuple[str, ...]: - """ - Return the path to the pipeline run records. - This is used to store the run-associated tag info. - """ - ... - - def validate_inputs(self, *streams: cp.Stream) -> None: - return - - # def forward(self, *streams: cp.Stream) -> cp.Stream: - # # TODO: re-evaluate the use here -- consider semi joining with input streams - # # super().validate_inputs(*self.input_streams) - # return super().forward(*self.upstreams) # type: ignore[return-value] - - def pre_kernel_processing(self, *streams: cp.Stream) -> tuple[cp.Stream, ...]: - return self.upstreams - - def kernel_output_types( - self, *streams: cp.Stream, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - """ - Return the output types of the node. - This is used to determine the types of the output streams. - """ - return self.contained_kernel.output_types( - *self.upstreams, include_system_tags=include_system_tags - ) - - def kernel_identity_structure( - self, streams: Collection[cp.Stream] | None = None - ) -> Any: - # construct identity structure from the node's information and the - return self.contained_kernel.identity_structure(self.upstreams) - - def get_all_records( - self, include_system_columns: bool = False - ) -> "pa.Table | None": - """ - Retrieve all records associated with the node. - If include_system_columns is True, system columns will be included in the result. - """ - raise NotImplementedError("This method should be implemented by subclasses.") - - def flush(self): - self.pipeline_database.flush() - - -class KernelNode(NodeBase, WrappedKernel): - """ - A node in the pipeline that represents a kernel. - This node can be used to execute the kernel and process data streams. - """ - - HASH_COLUMN_NAME = "_record_hash" - - def __init__( - self, - kernel: cp.Kernel, - input_streams: Collection[cp.Stream], - pipeline_database: dbp.ArrowDatabase, - pipeline_path_prefix: tuple[str, ...] = (), - **kwargs, - ) -> None: - super().__init__( - kernel=kernel, - input_streams=input_streams, - pipeline_database=pipeline_database, - pipeline_path_prefix=pipeline_path_prefix, - **kwargs, - ) - self.skip_recording = True - - @property - def contained_kernel(self) -> cp.Kernel: - return self.kernel - - def __repr__(self): - return f"KernelNode(kernel={self.kernel!r})" - - def __str__(self): - return f"KernelNode:{self.kernel!s}" - - def forward(self, *streams: cp.Stream) -> cp.Stream: - output_stream = super().forward(*streams) - - if not self.skip_recording: - self.record_pipeline_output(output_stream) - return output_stream - - def record_pipeline_output(self, output_stream: cp.Stream) -> None: - key_column_name = self.HASH_COLUMN_NAME - # FIXME: compute record id based on each record in its entirety - output_table = output_stream.as_table( - include_data_context=True, - include_system_tags=True, - include_source=True, - ) - # compute hash for output_table - # include system tags - columns_to_hash = ( - output_stream.tag_keys(include_system_tags=True) - + output_stream.packet_keys() - ) - - arrow_hasher = self.data_context.arrow_hasher - record_hashes = [] - table_to_hash = output_table.select(columns_to_hash) - - for record_batch in table_to_hash.to_batches(): - for i in range(len(record_batch)): - record_hashes.append( - arrow_hasher.hash_table(record_batch.slice(i, 1)).to_hex() - ) - # add the hash column - output_table = output_table.add_column( - 0, key_column_name, pa.array(record_hashes, type=pa.large_string()) - ) - - self.pipeline_database.add_records( - self.pipeline_path, - output_table, - record_id_column=key_column_name, - skip_duplicates=True, - ) - - @property - def pipeline_path(self) -> tuple[str, ...]: - """ - Return the path to the pipeline run records. - This is used to store the run-associated tag info. - """ - return ( - self._pipeline_path_prefix # pipeline ID - + self.reference # node ID - + ( - f"node:{self.pipeline_node_hash}", # pipeline node ID - f"packet:{self.packet_schema_hash}", # packet schema ID - f"tag:{self.tag_schema_hash}", # tag schema ID - ) - ) - - def get_all_records( - self, include_system_columns: bool = False - ) -> "pa.Table | None": - results = self.pipeline_database.get_all_records(self.pipeline_path) - - if results is None: - return None - - if not include_system_columns: - system_columns = [ - c - for c in results.column_names - if c.startswith(constants.META_PREFIX) - or c.startswith(constants.DATAGRAM_PREFIX) - ] - results = results.drop(system_columns) - - return results - - -class PodNode(NodeBase, CachedPod): - def __init__( - self, - pod: cp.Pod, - input_streams: Collection[cp.Stream], - pipeline_database: dbp.ArrowDatabase, - result_database: dbp.ArrowDatabase | None = None, - record_path_prefix: tuple[str, ...] = (), - pipeline_path_prefix: tuple[str, ...] = (), - **kwargs, - ) -> None: - super().__init__( - pod=pod, - result_database=result_database, - record_path_prefix=record_path_prefix, - input_streams=input_streams, - pipeline_database=pipeline_database, - pipeline_path_prefix=pipeline_path_prefix, - **kwargs, - ) - self._execution_engine_opts: dict[str, Any] = {} - - @property - def execution_engine_opts(self) -> dict[str, Any]: - return self._execution_engine_opts.copy() - - @execution_engine_opts.setter - def execution_engine_opts(self, opts: dict[str, Any]) -> None: - self._execution_engine_opts = opts - - def flush(self): - self.pipeline_database.flush() - if self.result_database is not None: - self.result_database.flush() - - @property - def contained_kernel(self) -> cp.Kernel: - return self.pod - - @property - def pipeline_path(self) -> tuple[str, ...]: - """ - Return the path to the pipeline run records. - This is used to store the run-associated tag info. - """ - return ( - self._pipeline_path_prefix # pipeline ID - + self.reference # node ID - + ( - f"node:{self.pipeline_node_hash}", # pipeline node ID - f"tag:{self.tag_schema_hash}", # tag schema ID - ) - ) - - def __repr__(self): - return f"PodNode(pod={self.pod!r})" - - def __str__(self): - return f"PodNode:{self.pod!s}" - - def call( - self, - tag: cp.Tag, - packet: cp.Packet, - record_id: str | None = None, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - skip_cache_lookup: bool = False, - skip_cache_insert: bool = False, - ) -> tuple[cp.Tag, cp.Packet | None]: - execution_engine_hash = execution_engine.name if execution_engine else "default" - if record_id is None: - record_id = self.get_record_id(packet, execution_engine_hash) - - combined_execution_engine_opts = self.execution_engine_opts - if execution_engine_opts is not None: - combined_execution_engine_opts.update(execution_engine_opts) - - - tag, output_packet = super().call( - tag, - packet, - record_id=record_id, - skip_cache_lookup=skip_cache_lookup, - skip_cache_insert=skip_cache_insert, - execution_engine=execution_engine, - execution_engine_opts=combined_execution_engine_opts, - ) - - # if output_packet is not None: - # retrieved = ( - # output_packet.get_meta_value(self.DATA_RETRIEVED_FLAG) is not None - # ) - # # add pipeline record if the output packet is not None - # # TODO: verify cache lookup logic - # self.add_pipeline_record( - # tag, - # packet, - # record_id, - # retrieved=retrieved, - # skip_cache_lookup=skip_cache_lookup, - # ) - return tag, output_packet - - async def async_call( - self, - tag: cp.Tag, - packet: cp.Packet, - record_id: str | None = None, - execution_engine: cp.ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - skip_cache_lookup: bool = False, - skip_cache_insert: bool = False, - ) -> tuple[cp.Tag, cp.Packet | None]: - execution_engine_hash = execution_engine.name if execution_engine else "default" - if record_id is None: - record_id = self.get_record_id(packet, execution_engine_hash) - - - combined_execution_engine_opts = self.execution_engine_opts - if execution_engine_opts is not None: - combined_execution_engine_opts.update(execution_engine_opts) - - - tag, output_packet = await super().async_call( - tag, - packet, - record_id=record_id, - skip_cache_lookup=skip_cache_lookup, - skip_cache_insert=skip_cache_insert, - execution_engine=execution_engine, - execution_engine_opts=combined_execution_engine_opts, - ) - - if output_packet is not None: - retrieved = ( - output_packet.get_meta_value(self.DATA_RETRIEVED_FLAG) is not None - ) - # add pipeline record if the output packet is not None - # TODO: verify cache lookup logic - self.add_pipeline_record( - tag, - packet, - record_id, - retrieved=retrieved, - skip_cache_lookup=skip_cache_lookup, - ) - return tag, output_packet - - def add_pipeline_record( - self, - tag: cp.Tag, - input_packet: cp.Packet, - packet_record_id: str, - retrieved: bool | None = None, - skip_cache_lookup: bool = False, - ) -> None: - # combine dp.Tag with packet content hash to compute entry hash - # TODO: add system tag columns - # TODO: consider using bytes instead of string representation - tag_with_hash = tag.as_table(include_system_tags=True).append_column( - constants.INPUT_PACKET_HASH, - pa.array([input_packet.content_hash().to_string()], type=pa.large_string()), - ) - - # unique entry ID is determined by the combination of tags, system_tags, and input_packet hash - entry_id = self.data_context.arrow_hasher.hash_table(tag_with_hash).to_string() - - # check presence of an existing entry with the same entry_id - existing_record = None - if not skip_cache_lookup: - existing_record = self.pipeline_database.get_record_by_id( - self.pipeline_path, - entry_id, - ) - - if existing_record is not None: - # if the record already exists, then skip - return - - # rename all keys to avoid potential collision with result columns - renamed_input_packet = input_packet.rename( - {k: f"_input_{k}" for k in input_packet.keys()} - ) - input_packet_info = ( - renamed_input_packet.as_table(include_source=True) - .append_column( - constants.PACKET_RECORD_ID, - pa.array([packet_record_id], type=pa.large_string()), - ) - .append_column( - f"{constants.META_PREFIX}input_packet{constants.CONTEXT_KEY}", - pa.array([input_packet.data_context_key], type=pa.large_string()), - ) - .append_column( - self.DATA_RETRIEVED_FLAG, - pa.array([retrieved], type=pa.bool_()), - ) - .drop_columns(list(renamed_input_packet.keys())) - ) - - combined_record = arrow_utils.hstack_tables( - tag.as_table(include_system_tags=True), input_packet_info - ) - - self.pipeline_database.add_record( - self.pipeline_path, - entry_id, - combined_record, - skip_duplicates=False, - ) - - def forward(self, *streams: cp.Stream) -> cp.Stream: - # TODO: re-evaluate the use here -- consider semi joining with input streams - # super().validate_inputs(*self.input_streams) - return PodNodeStream(self, *self.upstreams) # type: ignore[return-value] - - def get_all_records( - self, include_system_columns: bool = False - ) -> "pa.Table | None": - results = self.result_database.get_all_records( - self.record_path, record_id_column=constants.PACKET_RECORD_ID - ) - - if self.pipeline_database is None: - raise ValueError( - "Pipeline database is not configured, cannot retrieve tag info" - ) - taginfo = self.pipeline_database.get_all_records( - self.pipeline_path, - ) - - if results is None or taginfo is None: - return None - - # hack - use polars for join as it can deal with complex data type - # TODO: convert the entire load logic to use polars with lazy evaluation - - joined_info = ( - pl.DataFrame(taginfo) - .join(pl.DataFrame(results), on=constants.PACKET_RECORD_ID, how="inner") - .to_arrow() - ) - - # joined_info = taginfo.join( - # results, - # constants.PACKET_RECORD_ID, - # join_type="inner", - # ) - - if not include_system_columns: - system_columns = [ - c - for c in joined_info.column_names - if c.startswith(constants.META_PREFIX) - or c.startswith(constants.DATAGRAM_PREFIX) - ] - joined_info = joined_info.drop(system_columns) - return joined_info diff --git a/src/orcapod/pipeline/observer.py b/src/orcapod/pipeline/observer.py new file mode 100644 index 00000000..e9df178c --- /dev/null +++ b/src/orcapod/pipeline/observer.py @@ -0,0 +1,100 @@ +"""No-op implementations of the observability protocols. + +Provides :class:`NoOpLogger` and :class:`NoOpObserver` — the defaults used +when no observability is configured. Every method is a zero-cost no-op. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from orcapod.protocols.observability_protocols import ( # noqa: F401 (re-exported for convenience) + ExecutionObserverProtocol, + PacketExecutionLoggerProtocol, +) + +if TYPE_CHECKING: + from orcapod.core.nodes import GraphNode + from orcapod.pipeline.logging_capture import CapturedLogs + from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol + + +# --------------------------------------------------------------------------- +# NoOpLogger +# --------------------------------------------------------------------------- + + +class NoOpLogger: + """Logger that discards all captured output. + + Returned by :class:`NoOpObserver` when no logging sink is configured. + """ + + def record(self, captured: "CapturedLogs") -> None: + pass + + +# Singleton — NoOpLogger carries no state so one instance is enough. +_NOOP_LOGGER = NoOpLogger() + + +# --------------------------------------------------------------------------- +# NoOpObserver +# --------------------------------------------------------------------------- + + +class NoOpObserver: + """Observer that does nothing. + + Satisfies :class:`~orcapod.protocols.observability_protocols.ExecutionObserverProtocol` + and is the default when no observability is configured. + ``create_packet_logger`` returns the shared :data:`_NOOP_LOGGER` singleton. + """ + + def on_run_start(self, run_id: str) -> None: + pass + + def on_run_end(self, run_id: str) -> None: + pass + + def on_node_start(self, node: "GraphNode") -> None: + pass + + def on_node_end(self, node: "GraphNode") -> None: + pass + + def on_packet_start( + self, + node: "GraphNode", + tag: "TagProtocol", + packet: "PacketProtocol", + ) -> None: + pass + + def on_packet_end( + self, + node: "GraphNode", + tag: "TagProtocol", + input_packet: "PacketProtocol", + output_packet: "PacketProtocol | None", + cached: bool, + ) -> None: + pass + + def on_packet_crash( + self, + node: "GraphNode", + tag: "TagProtocol", + packet: "PacketProtocol", + error: Exception, + ) -> None: + pass + + def create_packet_logger( + self, + node: "GraphNode", + tag: "TagProtocol", + packet: "PacketProtocol", + pipeline_path: tuple[str, ...] = (), + ) -> NoOpLogger: + return _NOOP_LOGGER diff --git a/src/orcapod/pipeline/result.py b/src/orcapod/pipeline/result.py new file mode 100644 index 00000000..1709580e --- /dev/null +++ b/src/orcapod/pipeline/result.py @@ -0,0 +1,23 @@ +"""Result type returned by pipeline orchestrators.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol + + +@dataclass +class OrchestratorResult: + """Result of an orchestrator run. + + Attributes: + node_outputs: Mapping from graph node to its computed (tag, packet) + pairs. Empty when ``materialize_results=False``. + """ + + node_outputs: dict[Any, list[tuple["TagProtocol", "PacketProtocol"]]] = field( + default_factory=dict + ) diff --git a/src/orcapod/pipeline/serialization.py b/src/orcapod/pipeline/serialization.py new file mode 100644 index 00000000..99ea59f6 --- /dev/null +++ b/src/orcapod/pipeline/serialization.py @@ -0,0 +1,599 @@ +"""Pipeline serialization registries, helpers, and constants.""" + +from __future__ import annotations + +import logging +from enum import Enum +from typing import Any + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Format version +# --------------------------------------------------------------------------- + +PIPELINE_FORMAT_VERSION = "0.1.0" +SUPPORTED_FORMAT_VERSIONS = frozenset({"0.1.0"}) + +# --------------------------------------------------------------------------- +# LoadStatus +# --------------------------------------------------------------------------- + + +class LoadStatus(Enum): + """Status of a node after loading from a serialized pipeline.""" + + FULL = "full" + READ_ONLY = "read_only" + UNAVAILABLE = "unavailable" + + +# --------------------------------------------------------------------------- +# Registries +# --------------------------------------------------------------------------- + + +def _build_database_registry() -> dict[str, type]: + """Build the database type registry mapping type keys to classes. + + Returns: + Dict mapping type key strings to database classes. + """ + from orcapod.databases.delta_lake_databases import DeltaTableDatabase + from orcapod.databases.in_memory_databases import InMemoryArrowDatabase + from orcapod.databases.noop_database import NoOpArrowDatabase + + return { + "delta_table": DeltaTableDatabase, + "in_memory": InMemoryArrowDatabase, + "noop": NoOpArrowDatabase, + } + + +def _build_source_registry() -> dict[str, type]: + """Build the source type registry mapping type keys to classes. + + Returns: + Dict mapping type key strings to source classes. + """ + from orcapod.core.sources.arrow_table_source import ArrowTableSource + from orcapod.core.sources.cached_source import CachedSource + from orcapod.core.sources.csv_source import CSVSource + from orcapod.core.sources.data_frame_source import DataFrameSource + from orcapod.core.sources.delta_table_source import DeltaTableSource + from orcapod.core.sources.dict_source import DictSource + from orcapod.core.sources.list_source import ListSource + + return { + "csv": CSVSource, + "delta_table": DeltaTableSource, + "dict": DictSource, + "list": ListSource, + "data_frame": DataFrameSource, + "arrow_table": ArrowTableSource, + "cached": CachedSource, + } + + +def _build_operator_registry() -> dict[str, type]: + """Build the operator type registry mapping class names to classes. + + Returns: + Dict mapping class name strings to operator classes. + """ + from orcapod.core.operators import ( + Batch, + DropPacketColumns, + DropTagColumns, + Join, + MapPackets, + MapTags, + MergeJoin, + PolarsFilter, + SelectPacketColumns, + SelectTagColumns, + SemiJoin, + ) + + return { + "Join": Join, + "MergeJoin": MergeJoin, + "SemiJoin": SemiJoin, + "Batch": Batch, + "SelectTagColumns": SelectTagColumns, + "DropTagColumns": DropTagColumns, + "SelectPacketColumns": SelectPacketColumns, + "DropPacketColumns": DropPacketColumns, + "MapTags": MapTags, + "MapPackets": MapPackets, + "PolarsFilter": PolarsFilter, + } + + +def _build_packet_function_registry() -> dict[str, type]: + """Build the packet function type registry mapping type IDs to classes. + + Returns: + Dict mapping type ID strings to packet function classes. + """ + from orcapod.core.packet_function import PythonPacketFunction + + return { + "python.function.v0": PythonPacketFunction, + } + + +# Registries populated at module load time. +# Each registry maps a type key (or class name) to the corresponding class. +DATABASE_REGISTRY: dict[str, type] = _build_database_registry() +SOURCE_REGISTRY: dict[str, type] = _build_source_registry() +OPERATOR_REGISTRY: dict[str, type] = _build_operator_registry() +PACKET_FUNCTION_REGISTRY: dict[str, type] = _build_packet_function_registry() + + +def _ensure_registries() -> None: + """Ensure registries are populated. + + The registries are built at module import time, so this is a no-op in + normal use. It exists as a hook for tests or code that calls + ``register_*`` helpers before the first resolver call. + """ + # Registries are already populated at module load; nothing to do. + + +# --------------------------------------------------------------------------- +# Resolver helpers +# --------------------------------------------------------------------------- + + +def resolve_database_from_config(config: dict[str, Any]) -> Any: + """Reconstruct a database instance from a config dict. + + Args: + config: Dict with at least a ``"type"`` key matching a registered + database type. + + Returns: + A new database instance constructed from the config. + + Raises: + ValueError: If the ``"type"`` key is missing or unknown. + """ + _ensure_registries() + db_type = config.get("type") + if db_type not in DATABASE_REGISTRY: + raise ValueError( + f"Unknown database type: {db_type!r}. " + f"Known types: {sorted(DATABASE_REGISTRY.keys())}" + ) + if db_type == "in_memory": + logger.warning( + "Loading pipeline with in-memory database. Cached data from the " + "original run is not available — nodes will have UNAVAILABLE status." + ) + cls = DATABASE_REGISTRY[db_type] + return cls.from_config(config) + + +def resolve_operator_from_config(config: dict[str, Any]) -> Any: + """Reconstruct an operator instance from a config dict. + + Args: + config: Dict with at least a ``"class_name"`` key matching a registered + operator class. + + Returns: + A new operator instance constructed from the config. + + Raises: + ValueError: If the ``"class_name"`` key is missing or unknown. + """ + _ensure_registries() + class_name = config.get("class_name") + if class_name not in OPERATOR_REGISTRY: + raise ValueError( + f"Unknown operator: {class_name!r}. " + f"Known operators: {sorted(OPERATOR_REGISTRY.keys())}" + ) + cls = OPERATOR_REGISTRY[class_name] + return cls.from_config(config) + + +def resolve_packet_function_from_config(config: dict[str, Any]) -> Any: + """Reconstruct a packet function from a config dict. + + Args: + config: Dict with at least a ``"packet_function_type_id"`` key matching + a registered packet function type. + + Returns: + A new packet function instance constructed from the config. + + Raises: + ValueError: If the type ID is missing or unknown. + """ + _ensure_registries() + type_id = config.get("packet_function_type_id") + if type_id not in PACKET_FUNCTION_REGISTRY: + raise ValueError( + f"Unknown packet function type: {type_id!r}. " + f"Known types: {sorted(PACKET_FUNCTION_REGISTRY.keys())}" + ) + cls = PACKET_FUNCTION_REGISTRY[type_id] + return cls.from_config(config) + + +def resolve_source_from_config( + config: dict[str, Any], + *, + fallback_to_proxy: bool = False, +) -> Any: + """Reconstruct a source instance from a config dict. + + Args: + config: Dict with at least a ``"source_type"`` key matching a registered + source type. + fallback_to_proxy: If ``True`` and reconstruction fails, return a + ``SourceProxy`` preserving identity hashes and schemas from the + config. Requires the config to contain ``content_hash``, + ``pipeline_hash``, ``tag_schema``, and ``packet_schema`` fields + (as written by ``RootSource._identity_config()``). + + Returns: + A new source instance constructed from the config, or a ``SourceProxy`` + if reconstruction fails and *fallback_to_proxy* is ``True``. + + Raises: + ValueError: If the source type is missing or unknown. + Exception: Re-raised from ``from_config`` when *fallback_to_proxy* is + ``False`` and reconstruction fails. + """ + _ensure_registries() + source_type = config.get("source_type") + if source_type not in SOURCE_REGISTRY: + if fallback_to_proxy: + return _source_proxy_from_config(config) + raise ValueError( + f"Unknown source type: {source_type!r}. " + f"Known types: {sorted(SOURCE_REGISTRY.keys())}" + ) + cls = SOURCE_REGISTRY[source_type] + try: + return cls.from_config(config) + except Exception: + if fallback_to_proxy: + logger.warning( + "Could not reconstruct %s source; returning SourceProxy.", + source_type, + ) + return _source_proxy_from_config(config) + raise + + +def _source_proxy_from_config(config: dict[str, Any]) -> Any: + """Create a ``SourceProxy`` from identity fields in a source config. + + Args: + config: Source config dict containing ``content_hash``, + ``pipeline_hash``, ``tag_schema``, ``packet_schema``, and + ``source_id`` fields. + + Returns: + A ``SourceProxy`` preserving the original source's identity. + + Raises: + ValueError: If required identity fields are missing. + """ + from orcapod.core.sources.source_proxy import SourceProxy + from orcapod.types import Schema + + required = ("content_hash", "pipeline_hash", "tag_schema", "packet_schema") + missing = [k for k in required if k not in config] + if missing: + raise ValueError( + f"Cannot create SourceProxy: config is missing required identity " + f"fields: {missing}" + ) + + # Derive expected class name from source_type via the registry. + source_type = config.get("source_type") + expected_class_name: str | None = None + if source_type and source_type in SOURCE_REGISTRY: + expected_class_name = SOURCE_REGISTRY[source_type].__name__ + + tag_schema = Schema(deserialize_schema(config["tag_schema"])) + packet_schema = Schema(deserialize_schema(config["packet_schema"])) + + return SourceProxy( + source_id=config.get("source_id", "unknown"), + content_hash_str=config["content_hash"], + pipeline_hash_str=config["pipeline_hash"], + tag_schema=tag_schema, + packet_schema=packet_schema, + expected_class_name=expected_class_name, + source_config=config, + ) + + +# --------------------------------------------------------------------------- +# Registration helpers (extensibility) +# --------------------------------------------------------------------------- + + +def register_database(type_key: str, cls: type) -> None: + """Register a custom database implementation for deserialization. + + Args: + type_key: The string key to use in serialized configs. + cls: The database class to register. + """ + _ensure_registries() + DATABASE_REGISTRY[type_key] = cls + + +def register_source(type_key: str, cls: type) -> None: + """Register a custom source implementation for deserialization. + + Args: + type_key: The string key to use in serialized configs. + cls: The source class to register. + """ + _ensure_registries() + SOURCE_REGISTRY[type_key] = cls + + +def register_operator(class_name: str, cls: type) -> None: + """Register a custom operator implementation for deserialization. + + Args: + class_name: The class name string to use in serialized configs. + cls: The operator class to register. + """ + _ensure_registries() + OPERATOR_REGISTRY[class_name] = cls + + +def register_packet_function(type_id: str, cls: type) -> None: + """Register a custom packet function implementation for deserialization. + + Args: + type_id: The type ID string to use in serialized configs. + cls: The packet function class to register. + """ + _ensure_registries() + PACKET_FUNCTION_REGISTRY[type_id] = cls + + +# --------------------------------------------------------------------------- +# Schema serialization helpers +# --------------------------------------------------------------------------- + + +def serialize_schema(schema: Any, type_converter: Any | None = None) -> dict[str, str]: + """Convert a Schema mapping to JSON-serializable Arrow type strings. + + The result contains human-readable Arrow type strings for each field + (e.g. ``"int64"``, ``"large_string"``, ``"list"``). + These strings follow Arrow's canonical format and can be parsed back + by :func:`deserialize_schema` in any language that implements the + Arrow type grammar. + + Args: + schema: A Schema-like mapping from field name to data type. + type_converter: Optional type converter for Python→Arrow conversion. + When provided, Python types (e.g. ``int``, ``str``) are converted + to Arrow type strings (e.g. ``"int64"``, ``"large_string"``). + When ``None``, values are stringified directly. + + Returns: + A dict mapping field names to Arrow type string representations. + """ + if type_converter is not None: + result = {} + for k, v in schema.items(): + try: + arrow_type = type_converter.python_type_to_arrow_type(v) + result[k] = str(arrow_type) + except Exception: + result[k] = str(v) + return result + return {k: str(v) for k, v in schema.items()} + + +def deserialize_schema( + schema_dict: dict[str, str], + type_converter: Any | None = None, +) -> dict[str, Any]: + """Reconstruct a Python-type schema from Arrow type string values. + + Parses Arrow type strings (e.g. ``"int64"``, ``"list"``) + back into ``pa.DataType`` objects, then converts them to Python types + via the type converter. Falls back to raw strings for fields that + cannot be parsed. + + Args: + schema_dict: Dict mapping field names to Arrow type strings, as + produced by :func:`serialize_schema`. + type_converter: Optional type converter for Arrow→Python conversion. + When ``None``, the default data context's converter is used. + + Returns: + A dict mapping field names to Python types (or raw strings if + parsing fails). + """ + if type_converter is None: + from orcapod.contexts import resolve_context + + type_converter = resolve_context(None).type_converter + + result: dict[str, Any] = {} + for name, type_str in schema_dict.items(): + try: + arrow_type = parse_arrow_type_string(type_str) + result[name] = type_converter.arrow_type_to_python_type(arrow_type) + except Exception: + result[name] = type_str + return result + + +def parse_arrow_type_string(type_str: str) -> Any: + """Parse an Arrow type string into a ``pa.DataType``. + + Handles both primitive types (``"int64"``, ``"large_string"``) and + nested types (``"list"``, ``"struct"``, + ``"map"``). + + The grammar follows Arrow's canonical ``str(pa.DataType)`` output. + + Args: + type_str: Arrow type string to parse. + + Returns: + The corresponding ``pa.DataType``. + + Raises: + ValueError: If the type string cannot be parsed. + """ + import pyarrow as pa + + type_str = type_str.strip() + + # Primitive types — try direct lookup via pa.() + if "<" not in type_str: + factory = _ARROW_PRIMITIVE_TYPES.get(type_str) + if factory is not None: + return factory() + raise ValueError(f"Unknown Arrow type: {type_str!r}") + + # Nested types — parse the outer type and recurse + bracket_pos = type_str.index("<") + outer = type_str[:bracket_pos].strip() + inner = type_str[bracket_pos + 1 : -1].strip() # strip < and > + + if outer in ("list", "large_list"): + # "list" or "list>" + child_type_str = _strip_field_name(inner) + child_type = parse_arrow_type_string(child_type_str) + return ( + pa.large_list(child_type) if outer == "large_list" else pa.list_(child_type) + ) + + if outer == "struct": + # "struct" — split on top-level commas + fields = _split_struct_fields(inner) + pa_fields = [] + for field_str in fields: + colon_pos = field_str.index(":") + field_name = field_str[:colon_pos].strip() + field_type_str = field_str[colon_pos + 1 :].strip() + pa_fields.append( + pa.field(field_name, parse_arrow_type_string(field_type_str)) + ) + return pa.struct(pa_fields) + + if outer == "map": + # "map" — split on first top-level comma + parts = _split_top_level(inner, ",", max_splits=1) + if len(parts) != 2: + raise ValueError(f"Cannot parse map type: {type_str!r}") + key_type = parse_arrow_type_string(parts[0].strip()) + value_type = parse_arrow_type_string(parts[1].strip()) + return pa.map_(key_type, value_type) + + raise ValueError(f"Unknown nested Arrow type: {type_str!r}") + + +def _strip_field_name(s: str) -> str: + """Strip the ``"item: "`` or ``"field_name: "`` prefix from a child type string.""" + if ":" in s: + # Only strip if there's no nested '<' before the ':' + colon_pos = s.index(":") + bracket_pos = s.find("<") + if bracket_pos == -1 or colon_pos < bracket_pos: + return s[colon_pos + 1 :].strip() + return s + + +def _split_top_level(s: str, sep: str, max_splits: int = -1) -> list[str]: + """Split *s* on *sep*, but only at the top level (not inside ``<>``). + + Args: + s: String to split. + sep: Separator character. + max_splits: Maximum number of splits (-1 for unlimited). + + Returns: + List of substrings. + """ + parts: list[str] = [] + depth = 0 + current: list[str] = [] + for ch in s: + if ch == "<": + depth += 1 + elif ch == ">": + depth -= 1 + if ch == sep and depth == 0 and (max_splits == -1 or len(parts) < max_splits): + parts.append("".join(current)) + current = [] + continue + current.append(ch) + parts.append("".join(current)) + return parts + + +def _split_struct_fields(inner: str) -> list[str]: + """Split struct fields on top-level commas.""" + return _split_top_level(inner, ",") + + +def _build_arrow_primitive_types() -> dict[str, Any]: + """Build a mapping of Arrow type string names to factory callables.""" + import pyarrow as pa + + types = {} + for name in [ + "null", + "bool_", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float16", + "float32", + "float64", + "string", + "utf8", + "large_string", + "large_utf8", + "binary", + "large_binary", + "date32", + "date64", + "time32", + "time64", + "duration", + "timestamp", + ]: + factory = getattr(pa, name, None) + if factory is not None: + types[name] = factory + # Also map the str() output if it differs from the factory name + try: + canonical = str(factory()) + if canonical != name: + types[canonical] = factory + except Exception: + pass + # Common aliases + types["double"] = pa.float64 + types["float"] = pa.float32 + types["bool"] = pa.bool_ + return types + + +_ARROW_PRIMITIVE_TYPES: dict[str, Any] = _build_arrow_primitive_types() diff --git a/src/orcapod/pipeline/sync_orchestrator.py b/src/orcapod/pipeline/sync_orchestrator.py new file mode 100644 index 00000000..b3cf8d74 --- /dev/null +++ b/src/orcapod/pipeline/sync_orchestrator.py @@ -0,0 +1,228 @@ +"""Synchronous pipeline orchestrator. + +Walks a compiled pipeline's node graph topologically, delegating to each +node's ``execute()`` method with observer injection. +""" + +from __future__ import annotations + +import logging +import uuid +from typing import TYPE_CHECKING, Any + +from orcapod.pipeline.result import OrchestratorResult +from orcapod.protocols.node_protocols import ( + is_function_node, + is_operator_node, + is_source_node, +) + +if TYPE_CHECKING: + import networkx as nx + + from orcapod.protocols.observability_protocols import ExecutionObserverProtocol + from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol + +logger = logging.getLogger(__name__) + + +class SyncPipelineOrchestrator: + """Execute a compiled pipeline synchronously via node ``execute()`` methods. + + Walks the node graph in topological order. For each node, delegates + to ``node.execute(observer=...)`` which owns all per-packet logic, + cache lookups, and observer hooks internally. + + The orchestrator is responsible only for topological ordering, + buffer management, and stream materialization between nodes. + + Args: + observer: Optional execution observer forwarded to nodes. + """ + + def __init__( + self, + observer: "ExecutionObserverProtocol | None" = None, + error_policy: str = "continue", + ) -> None: + self._observer = observer + self._error_policy = error_policy + + def run( + self, + graph: "nx.DiGraph", + materialize_results: bool = True, + run_id: str | None = None, + ) -> OrchestratorResult: + """Execute the node graph synchronously. + + Args: + graph: A NetworkX DiGraph with GraphNode objects as vertices. + materialize_results: If True, keep all node outputs in memory + and return them. If False, discard buffers after downstream + consumption (only DB-persisted results survive). + run_id: Optional run identifier. If not provided, a UUID is + generated automatically. + + Returns: + OrchestratorResult with node outputs. + """ + import networkx as nx + + effective_run_id = run_id or str(uuid.uuid4()) + if self._observer is not None: + self._observer.on_run_start(effective_run_id) + + topo_order = list(nx.topological_sort(graph)) + buffers: dict[Any, list[tuple[TagProtocol, PacketProtocol]]] = {} + processed: set[Any] = set() + + for node in topo_order: + if is_source_node(node): + buffers[node] = node.execute(observer=self._observer) + elif is_function_node(node): + upstream_buf = self._gather_upstream(node, graph, buffers) + upstream_node = list(graph.predecessors(node))[0] + input_stream = self._materialize_as_stream(upstream_buf, upstream_node) + buffers[node] = node.execute( + input_stream, + observer=self._observer, + error_policy=self._error_policy, + ) + elif is_operator_node(node): + upstream_buffers = self._gather_upstream_multi(node, graph, buffers) + input_streams = [ + self._materialize_as_stream(buf, upstream_node) + for buf, upstream_node in upstream_buffers + ] + buffers[node] = node.execute(*input_streams, observer=self._observer) + else: + raise TypeError( + f"Unknown node type: {getattr(node, 'node_type', None)!r}" + ) + + processed.add(node) + + if not materialize_results: + self._gc_buffers(node, graph, buffers, processed) + + if not materialize_results: + buffers.clear() + + if self._observer is not None: + self._observer.on_run_end(effective_run_id) + return OrchestratorResult(node_outputs=buffers) + + @staticmethod + def _gather_upstream( + node: Any, graph: "nx.DiGraph", buffers: dict[Any, list[tuple[Any, Any]]] + ) -> list[tuple[Any, Any]]: + """Gather a single upstream buffer (for function nodes).""" + predecessors = list(graph.predecessors(node)) + if len(predecessors) != 1: + raise ValueError( + f"FunctionNode expects exactly 1 upstream, got {len(predecessors)}" + ) + return buffers[predecessors[0]] + + @staticmethod + def _gather_upstream_multi( + node: Any, graph: "nx.DiGraph", buffers: dict[Any, list[tuple[Any, Any]]] + ) -> list[tuple[list[tuple[Any, Any]], Any]]: + """Gather multiple upstream buffers with their nodes (for operators). + + Returns list of (buffer, upstream_node) tuples preserving the + order that matches the operator's input_streams order. + """ + predecessors = list(graph.predecessors(node)) + upstream_order = {id(upstream): i for i, upstream in enumerate(node.upstreams)} + sorted_preds = sorted( + predecessors, + key=lambda p: upstream_order.get(id(p), 0), + ) + return [(buffers[p], p) for p in sorted_preds] + + @staticmethod + def _materialize_as_stream(buf: list[tuple[Any, Any]], upstream_node: Any) -> Any: + """Wrap a (tag, packet) buffer as an ArrowTableStream. + + Uses the same column selection pattern as + ``StaticOutputOperatorPod._materialize_to_stream``: system_tags + for tags, source info for packets. + + Args: + buf: List of (tag, packet) tuples. + upstream_node: The node that produced this buffer (used to + determine tag column names). + + Returns: + An ArrowTableStream. + """ + from orcapod.core.streams.arrow_table_stream import ArrowTableStream + from orcapod.utils import arrow_utils + from orcapod.utils.lazy_module import LazyModule + + pa = LazyModule("pyarrow") + + if not buf: + # Build an empty stream with the correct schema from the upstream node + tag_schema, packet_schema = upstream_node.output_schema( + columns={"system_tags": True, "source": True} + ) + type_converter = upstream_node.data_context.type_converter + empty_fields = {} + for name, py_type in {**tag_schema, **packet_schema}.items(): + arrow_type = type_converter.python_type_to_arrow_type(py_type) + empty_fields[name] = pa.array([], type=arrow_type) + empty_table = pa.table(empty_fields) + tag_keys = upstream_node.keys()[0] + return ArrowTableStream( + empty_table, + tag_columns=tag_keys, + producer=upstream_node.producer, + upstreams=upstream_node.upstreams, + ) + + tag_tables = [tag.as_table(columns={"system_tags": True}) for tag, _ in buf] + packet_tables = [pkt.as_table(columns={"source": True}) for _, pkt in buf] + + combined_tags = pa.concat_tables(tag_tables) + combined_packets = pa.concat_tables(packet_tables) + + user_tag_keys = tuple(buf[0][0].keys()) + source_info = buf[0][1].source_info() + + full_table = arrow_utils.hstack_tables(combined_tags, combined_packets) + + # Pass the upstream node's producer and upstreams so the + # materialized stream inherits the correct identity_structure + # and pipeline_identity_structure (via StreamBase delegation). + # This ensures downstream operators produce correct system tag + # column names (which embed pipeline hashes of their inputs). + producer = upstream_node.producer + upstreams = upstream_node.upstreams + + return ArrowTableStream( + full_table, + tag_columns=user_tag_keys, + source_info=source_info, + producer=producer, + upstreams=upstreams, + ) + + @staticmethod + def _gc_buffers( + current_node: Any, + graph: "nx.DiGraph", + buffers: dict[Any, list[tuple[Any, Any]]], + processed: set[Any], + ) -> None: + """Discard buffers no longer needed by any unprocessed downstream.""" + for pred in graph.predecessors(current_node): + if pred not in buffers: + continue + all_successors_done = all( + succ in processed for succ in graph.successors(pred) + ) + if all_successors_done: + del buffers[pred] diff --git a/src/orcapod/protocols/__init__.py b/src/orcapod/protocols/__init__.py index e69de29b..7f8ba7c7 100644 --- a/src/orcapod/protocols/__init__.py +++ b/src/orcapod/protocols/__init__.py @@ -0,0 +1,4 @@ +from orcapod.protocols.observability_protocols import ( + ExecutionObserverProtocol, + PacketExecutionLoggerProtocol, +) diff --git a/src/orcapod/protocols/core_protocols/__init__.py b/src/orcapod/protocols/core_protocols/__init__.py index f9c711d4..c658e86f 100644 --- a/src/orcapod/protocols/core_protocols/__init__.py +++ b/src/orcapod/protocols/core_protocols/__init__.py @@ -1,24 +1,31 @@ -from .base import ExecutionEngine, PodFunction -from .datagrams import Datagram, Tag, Packet -from .streams import Stream, LiveStream -from .kernel import Kernel -from .pods import Pod, CachedPod -from .source import Source -from .trackers import Tracker, TrackerManager +from orcapod.types import ColumnConfig +from orcapod.protocols.hashing_protocols import PipelineElementProtocol +from .datagrams import DatagramProtocol, PacketProtocol, TagProtocol +from .executor import PacketFunctionExecutorProtocol, PythonFunctionExecutorProtocol +from .function_pod import FunctionPodProtocol +from .operator_pod import OperatorPodProtocol +from .packet_function import PacketFunctionProtocol +from .pod import ArgumentGroup, PodProtocol +from .sources import SourceProtocol +from .streams import StreamProtocol +from .trackers import TrackerProtocol, TrackerManagerProtocol __all__ = [ - "ExecutionEngine", - "PodFunction", - "Datagram", - "Tag", - "Packet", - "Stream", - "LiveStream", - "Kernel", - "Pod", - "CachedPod", - "Source", - "Tracker", - "TrackerManager", + "ColumnConfig", + "DatagramProtocol", + "TagProtocol", + "PacketProtocol", + "SourceProtocol", + "StreamProtocol", + "PodProtocol", + "ArgumentGroup", + "PipelineElementProtocol", + "FunctionPodProtocol", + "OperatorPodProtocol", + "PacketFunctionProtocol", + "PacketFunctionExecutorProtocol", + "PythonFunctionExecutorProtocol", + "TrackerProtocol", + "TrackerManagerProtocol", ] diff --git a/src/orcapod/protocols/core_protocols/base.py b/src/orcapod/protocols/core_protocols/base.py deleted file mode 100644 index 87d9a819..00000000 --- a/src/orcapod/protocols/core_protocols/base.py +++ /dev/null @@ -1,220 +0,0 @@ -from collections.abc import Callable -from typing import Any, Protocol, runtime_checkable -from orcapod.types import DataValue - - -@runtime_checkable -class ExecutionEngine(Protocol): - """ - Abstract execution backend responsible for running user functions. - - ExecutionEngine defines the minimal contract that any execution backend - must satisfy to be used by Orcapod. Concrete implementations may execute - work in the current process (synchronously), on background threads or - processes, or on remote/distributed systems (e.g., Ray, Dask, Slurm). - - Responsibilities - - Accept a Python callable plus arguments and execute it. - - Provide both a synchronous API (blocking) and an asynchronous API - (awaitable) with consistent error semantics. - - Surface the original exception from the user function without - wrapping where practical, while preserving traceback information. - - Be safe to construct/read concurrently from the pipeline orchestration. - - Contract - - Inputs: a Python callable and its positional/keyword arguments. - - Outputs: the callable's return value (or a coroutine result when awaited). - - Errors: exceptions raised by the callable must be propagated to the - caller of submit_sync/submit_async. - - Cancellation: implementations may optionally support task cancellation - in submit_async via standard asyncio cancellation; submit_sync is - expected to block until completion. - - Notes - - Serialization: Distributed engines may require the function and its - arguments to be serializable (pickle/cloudpickle). Local engines have - no such requirement beyond normal Python callability. - - Resource usage: Engines may schedule work with resource hints - (CPU/GPU/memory) outside this minimal protocol; higher-level APIs can - extend this interface if needed. - - Naming: ``name`` should be a short, human-friendly identifier such as - "local", "threadpool", "processpool", or "ray" and is used for logging - and diagnostics. - """ - @property - def supports_async(self) -> bool: - """Indicate whether this engine supports async execution.""" - ... - - @property - def name(self) -> str: - """Return a short, human-friendly identifier for the engine. - - Examples: "local", "threadpool", "processpool", "ray". - Used for logging, metrics, and debugging output. - """ - ... - - def submit_sync( - self, - func: Callable[..., Any], - /, - *, - fn_args: tuple[Any, ...] = (), - fn_kwargs: dict[str, Any] | None = None, - **engine_opts: Any, - ) -> Any: - """ - Execute a callable and return its result (blocking). - - This call is blocking. Engines may choose where/how the function - executes (same thread, worker thread/process, remote node), but the - call does not return until the work completes or fails. - - Parameters - - func: Python callable to execute. - - fn_args: Tuple of positional arguments to pass to ``func``. - - fn_kwargs: Mapping of keyword arguments to pass to ``func``. - - **engine_opts: Engine-specific options (e.g., resources, priority), - never forwarded to ``func``. - - Returns: - Any: The return value of ``func``. - - Raises: - Exception: Any exception raised by ``func`` must be propagated to - the caller. Engines should preserve the original traceback whenever - practical. - - Notes - - This API separates function inputs from engine configuration. - ``fn_args``/``fn_kwargs`` are always applied to ``func``; - ``engine_opts`` configures the engine and is never forwarded. - """ - ... - - async def submit_async( - self, - func: Callable[..., Any], - /, - *, - fn_args: tuple[Any, ...] = (), - fn_kwargs: dict[str, Any] | None = None, - **engine_opts: Any, - ) -> Any: - """ - Asynchronously execute a callable and return the result when awaited. - - The returned awaitable resolves to the callable's return value or - raises the callable's exception. Implementations should integrate with - asyncio semantics: if the awaiting task is cancelled, the engine may - attempt to cancel the underlying work when supported. - - Parameters - - func: Python callable to execute. - - fn_args: Tuple of positional arguments to pass to ``func``. - - fn_kwargs: Mapping of keyword arguments to pass to ``func``. - - **engine_opts: Engine-specific options (e.g., resources, priority), - never forwarded to ``func``. - - Returns: - Any: The return value of ``func`` when awaited. - - Raises: - asyncio.CancelledError: If the awaiting task is cancelled and the - implementation propagates cancellation. - Exception: Any exception raised by ``func`` must be propagated to - the awaiting caller, with traceback preserved where possible. - - Notes - - Mirrors the sync API: ``fn_args``/``fn_kwargs`` target ``func``; - ``engine_opts`` configures the engine and is never forwarded. - """ - ... - - # TODO: consider adding batch submission - - -@runtime_checkable -class PodFunction(Protocol): - """ - A function suitable for use in a FunctionPod. - - PodFunctions define the computational logic that operates on individual - packets within a Pod. They represent pure functions that transform - data values without side effects. - - These functions are designed to be: - - Stateless: No dependency on external state - - Deterministic: Same inputs always produce same outputs - - Serializable: Can be cached and distributed - - Type-safe: Clear input/output contracts - - PodFunctions accept named arguments corresponding to packet fields - and return transformed data values. - """ - - def __call__(self, **kwargs: DataValue) -> None | DataValue: - """ - Execute the pod function with the given arguments. - - The function receives packet data as named arguments and returns - either transformed data or None (for filtering operations). - - Args: - **kwargs: Named arguments mapping packet fields to data values - - Returns: - None: Filter out this packet (don't include in output) - DataValue: Single transformed value - - Raises: - TypeError: If required arguments are missing - ValueError: If argument values are invalid - """ - ... - - -@runtime_checkable -class Labelable(Protocol): - """ - Protocol for objects that can have a human-readable label. - - Labels provide meaningful names for objects in the computational graph, - making debugging, visualization, and monitoring much easier. They serve - as human-friendly identifiers that complement the technical identifiers - used internally. - - Labels are optional but highly recommended for: - - Debugging complex computational graphs - - Visualization and monitoring tools - - Error messages and logging - - User interfaces and dashboards - """ - - @property - def label(self) -> str: - """ - Return the human-readable label for this object. - - Labels should be descriptive and help users understand the purpose - or role of the object in the computational graph. - - Returns: - str: Human-readable label for this object - None: No label is set (will use default naming) - """ - ... - - @label.setter - def label(self, label: str | None) -> None: - """ - Set the human-readable label for this object. - - Labels should be descriptive and help users understand the purpose - or role of the object in the computational graph. - - Args: - value (str): Human-readable label for this object - """ - ... diff --git a/src/orcapod/protocols/core_protocols/datagrams.py b/src/orcapod/protocols/core_protocols/datagrams.py index a0f24d87..27ec274a 100644 --- a/src/orcapod/protocols/core_protocols/datagrams.py +++ b/src/orcapod/protocols/core_protocols/datagrams.py @@ -1,15 +1,26 @@ -from collections.abc import Collection, Iterator, Mapping -from typing import Any, Protocol, Self, TYPE_CHECKING, runtime_checkable -from orcapod.protocols.hashing_protocols import ContentIdentifiable -from orcapod.types import DataValue, PythonSchema - +from __future__ import annotations + +from collections.abc import Iterator, Mapping +from typing import ( + TYPE_CHECKING, + Any, + Protocol, + Self, + runtime_checkable, +) + +from orcapod.protocols.hashing_protocols import ( + ContentIdentifiableProtocol, + DataContextAwareProtocol, +) +from orcapod.types import ColumnConfig, DataValue, Schema if TYPE_CHECKING: import pyarrow as pa @runtime_checkable -class Datagram(ContentIdentifiable, Protocol): +class DatagramProtocol(ContentIdentifiableProtocol, DataContextAwareProtocol, Protocol): """ Protocol for immutable datagram containers in Orcapod. @@ -22,9 +33,9 @@ class Datagram(ContentIdentifiable, Protocol): - **Meta columns**: Internal system metadata with {constants.META_PREFIX} (typically '__') prefixes (e.g. __processed_at, etc.) - **Context column**: Data context information ({constants.CONTEXT_KEY}) - Derivative of datagram (such as Packet or Tag) will also include some specific columns pertinent to the function of the specialized datagram: - - **Source info columns**: Data provenance with {constants.SOURCE_PREFIX} ('_source_') prefixes (_source_user_id, etc.) used in Packet - - **System tags**: Internal tags for system use, typically prefixed with {constants.SYSTEM_TAG_PREFIX} ('_system_') (_system_created_at, etc.) used in Tag + Derivative of datagram (such as PacketProtocol or TagProtocol) will also include some specific columns pertinent to the function of the specialized datagram: + - **Source info columns**: Data provenance with {constants.SOURCE_PREFIX} ('_source_') prefixes (_source_user_id, etc.) used in PacketProtocol + - **System tags**: Internal tags for system use, typically prefixed with {constants.SYSTEM_TAG_PREFIX} ('_system_') (_system_created_at, etc.) used in TagProtocol All operations are by design immutable - methods return new datagram instances rather than modifying existing ones. @@ -35,20 +46,13 @@ class Datagram(ContentIdentifiable, Protocol): >>> table = datagram.as_table() """ - # 1. Core Properties (Identity & Structure) @property - def data_context_key(self) -> str: + def datagram_id(self) -> str: """ - Return the data context key for this datagram. - - This key identifies a collection of system components that collectively controls - how information is serialized, hashed and represented, including the semantic type registry, - arrow data hasher, and other contextual information. Same piece of information (that is two datagrams - with an identical *logical* content) may bear distinct internal representation if they are - represented under two distinct data context, as signified by distinct data context keys. + Return the UUID of this datagram. Returns: - str: Context key for proper datagram interpretation + UUID: The unique identifier for this instance of datagram. """ ... @@ -139,9 +143,9 @@ def get(self, key: str, default: DataValue = None) -> DataValue: # 3. Structural Information def keys( self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, ) -> tuple[str, ...]: """ Return tuple of column names. @@ -172,12 +176,12 @@ def keys( """ ... - def types( + def schema( self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - ) -> PythonSchema: + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> Schema: """ Return type specification mapping field names to Python types. @@ -202,10 +206,10 @@ def types( def arrow_schema( self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - ) -> "pa.Schema": + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> pa.Schema: """ Return PyArrow schema representation. @@ -233,9 +237,9 @@ def arrow_schema( # 4. Format Conversions (Export) def as_dict( self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, ) -> dict[str, DataValue]: """ Convert datagram to dictionary format. @@ -267,9 +271,9 @@ def as_dict( def as_table( self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, ) -> "pa.Table": """ Convert datagram to PyArrow Table format. @@ -301,12 +305,12 @@ def as_table( def as_arrow_compatible_dict( self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, ) -> dict[str, Any]: """ - Return dictionary with values optimized for Arrow table conversion. + Return a dictionary with values optimized for Arrow table conversion. This method returns a dictionary where values are in a form that can be efficiently converted to Arrow format using pa.Table.from_pylist(). @@ -328,7 +332,7 @@ def as_arrow_compatible_dict( include_context: Whether to include context key Returns: - Dictionary with values optimized for Arrow conversion + A dictionary with values optimized for Arrow table conversion. Example: # Efficient batch conversion pattern @@ -365,7 +369,7 @@ def get_meta_value(self, key: str, default: DataValue = None) -> DataValue: def with_meta_columns(self, **updates: DataValue) -> Self: """ - Create new datagram with updated meta columns. + Create a new datagram with updated meta columns. Adds or updates operational metadata while preserving all data columns. Keys are automatically prefixed with {orcapod.META_PREFIX} ('__') if needed. @@ -374,7 +378,7 @@ def with_meta_columns(self, **updates: DataValue) -> Self: **updates: Meta column updates as keyword arguments. Returns: - New datagram instance with updated meta columns. + A new datagram instance with updated meta columns. Example: >>> tracked = datagram.with_meta_columns( @@ -386,7 +390,7 @@ def with_meta_columns(self, **updates: DataValue) -> Self: def drop_meta_columns(self, *keys: str, ignore_missing: bool = False) -> Self: """ - Create new datagram with specified meta columns removed. + Create a new datagram with specified meta columns removed. Args: *keys: Meta column keys to remove (prefixes optional). @@ -394,10 +398,10 @@ def drop_meta_columns(self, *keys: str, ignore_missing: bool = False) -> Self: Returns: - New datagram instance without specified meta columns. + A new datagram instance without specified meta columns. Raises: - KeryError: If any specified meta column to drop doesn't exist and ignore_missing=False. + KeyError: If any specified meta column to drop doesn't exist and ignore_missing=False. Example: >>> cleaned = datagram.drop_meta_columns("old_source", "temp_debug") @@ -407,7 +411,7 @@ def drop_meta_columns(self, *keys: str, ignore_missing: bool = False) -> Self: # 6. Data Column Operations def select(self, *column_names: str) -> Self: """ - Create new datagram with only specified data columns. + Create a new datagram with only specified data columns. Args: *column_names: Data column names to keep. @@ -427,7 +431,7 @@ def select(self, *column_names: str) -> Self: def drop(self, *column_names: str, ignore_missing: bool = False) -> Self: """ - Create new datagram with specified data columns removed. Note that this does not + Create a new datagram with specified data columns removed. Note that this does not remove meta columns or context column. Refer to `drop_meta_columns()` for dropping specific meta columns. Context key column can never be dropped but a modified copy can be created with a different context key using `with_data_context()`. @@ -452,7 +456,7 @@ def rename( column_mapping: Mapping[str, str], ) -> Self: """ - Create new datagram with data columns renamed. + Create a new datagram with data columns renamed. Args: column_mapping: Mapping from old names to new names. @@ -470,7 +474,7 @@ def rename( def update(self, **updates: DataValue) -> Self: """ - Create new datagram with existing column values updated. + Create a new datagram with existing column values updated. Updates values in existing data columns. Will error if any specified column doesn't exist - use with_columns() to add new columns. @@ -498,7 +502,7 @@ def with_columns( **updates: DataValue, ) -> Self: """ - Create new datagram with additional data columns. + Create a new datagram with additional data columns. Adds new data columns to the datagram. Will error if any specified column already exists - use update() to modify existing columns. @@ -526,7 +530,7 @@ def with_columns( # 7. Context Operations def with_context_key(self, new_context_key: str) -> Self: """ - Create new datagram with different context key. + Create new datagram with a different context key. Changes the semantic interpretation context while preserving all data. The context key affects how columns are processed and converted. @@ -585,13 +589,13 @@ def __repr__(self) -> str: Shows the datagram type and comprehensive information for debugging. Returns: - Detailed representation with type and metadata information. + A detailed representation with type and metadata information. """ ... @runtime_checkable -class Tag(Datagram, Protocol): +class TagProtocol(DatagramProtocol, Protocol): """ Metadata associated with each data item in a stream. @@ -600,7 +604,7 @@ class Tag(Datagram, Protocol): helps with: - Data lineage tracking - Grouping and aggregation operations - - Temporal information (timestamps) + - TemporalProtocol information (timestamps) - Source identification - Processing context @@ -612,214 +616,6 @@ class Tag(Datagram, Protocol): - Quality indicators or confidence scores """ - def keys( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_system_tags: bool = False, - ) -> tuple[str, ...]: - """ - Return tuple of column names. - - Provides access to column names with filtering options for different - column types. Default returns only data column names. - - Args: - include_all_info: If True, include all available information. This option supersedes all other inclusion options. - include_meta_columns: Controls meta column inclusion. - - False: Return only data column names (default) - - True: Include all meta column names - - Collection[str]: Include meta columns matching these prefixes. If absent, - {orcapod.META_PREFIX} ('__') prefix is prepended to each key. - include_context: Whether to include context column. - include_source: Whether to include source info fields. - - - Returns: - Tuple of column names based on inclusion criteria. - - Example: - >>> datagram.keys() # Data columns only - ('user_id', 'name', 'email') - >>> datagram.keys(include_meta_columns=True) - ('user_id', 'name', 'email', f'{orcapod.META_PREFIX}processed_at', f'{orcapod.META_PREFIX}pipeline_version') - >>> datagram.keys(include_meta_columns=["pipeline"]) - ('user_id', 'name', 'email',f'{orcapod.META_PREFIX}pipeline_version') - >>> datagram.keys(include_context=True) - ('user_id', 'name', 'email', f'{orcapod.CONTEXT_KEY}') - """ - ... - - def types( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_system_tags: bool = False, - ) -> PythonSchema: - """ - Return type specification mapping field names to Python types. - - The TypeSpec enables type checking and validation throughout the system. - - Args: - include_all_info: If True, include all available information. This option supersedes all other inclusion options. - include_meta_columns: Controls meta column type inclusion. - - False: Exclude meta column types (default) - - True: Include all meta column types - - Collection[str]: Include meta column types matching these prefixes. If absent, - {orcapod.META_PREFIX} ('__') prefix is prepended to each key. - include_context: Whether to include context type. - include_source: Whether to include source info fields. - - Returns: - TypeSpec mapping field names to their Python types. - - Example: - >>> datagram.types() - {'user_id': , 'name': } - """ - ... - - def arrow_schema( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_system_tags: bool = False, - ) -> "pa.Schema": - """ - Return PyArrow schema representation. - - The schema provides structured field and type information for efficient - serialization and deserialization with PyArrow. - - Args: - include_all_info: If True, include all available information. This option supersedes all other inclusion options. - include_meta_columns: Controls meta column schema inclusion. - - False: Exclude meta columns (default) - - True: Include all meta columns - - Collection[str]: Include meta columns matching these prefixes. If absent, - {orcapod.META_PREFIX} ('__') prefix is prepended to each key. - include_context: Whether to include context column. - include_source: Whether to include source info fields. - - - Returns: - PyArrow Schema describing the datagram structure. - - Example: - >>> schema = datagram.arrow_schema() - >>> schema.names - ['user_id', 'name'] - """ - ... - - def as_dict( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_system_tags: bool = False, - ) -> dict[str, DataValue]: - """ - Convert datagram to dictionary format. - - Provides a simple key-value representation useful for debugging, - serialization, and interop with dict-based APIs. - - Args: - include_all_info: If True, include all available information. This option supersedes all other inclusion options. - include_meta_columns: Controls meta column inclusion. - - False: Exclude all meta columns (default) - - True: Include all meta columns - - Collection[str]: Include meta columns matching these prefixes. If absent, - {orcapod.META_PREFIX} ('__') prefix is prepended to each key. - include_context: Whether to include the context key. - include_source: Whether to include source info fields. - - - Returns: - Dictionary with requested columns as key-value pairs. - - Example: - >>> data = datagram.as_dict() # {'user_id': 123, 'name': 'Alice'} - >>> full_data = datagram.as_dict( - ... include_meta_columns=True, - ... include_context=True - ... ) - """ - ... - - def as_table( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_system_tags: bool = False, - ) -> "pa.Table": - """ - Convert datagram to PyArrow Table format. - - Provides a standardized columnar representation suitable for analysis, - processing, and interoperability with Arrow-based tools. - - Args: - include_all_info: If True, include all available information. This option supersedes all other inclusion options. - include_meta_columns: Controls meta column inclusion. - - False: Exclude all meta columns (default) - - True: Include all meta columns - - Collection[str]: Include meta columns matching these prefixes. If absent, - {orcapod.META_PREFIX} ('__') prefix is prepended to each key. - include_context: Whether to include the context column. - include_source: Whether to include source info columns in the schema. - - Returns: - PyArrow Table with requested columns. - - Example: - >>> table = datagram.as_table() # Data columns only - >>> full_table = datagram.as_table( - ... include_meta_columns=True, - ... include_context=True - ... ) - >>> filtered = datagram.as_table(include_meta_columns=["pipeline"]) # same as passing f"{orcapod.META_PREFIX}pipeline" - """ - ... - - # TODO: add this back - # def as_arrow_compatible_dict( - # self, - # include_all_info: bool = False, - # include_meta_columns: bool | Collection[str] = False, - # include_context: bool = False, - # include_source: bool = False, - # ) -> dict[str, Any]: - # """Extended version with source info support.""" - # ... - - def as_datagram( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_system_tags: bool = False, - ) -> Datagram: - """ - Convert the packet to a Datagram. - - Args: - include_meta_columns: Controls meta column inclusion. - - False: Exclude all meta columns (default) - - True: Include all meta columns - - Collection[str]: Include meta columns matching these prefixes. If absent, - {orcapod.META_PREFIX} ('__') prefix is prepended to each key. - - Returns: - Datagram: Datagram representation of packet data - """ - ... - def system_tags(self) -> dict[str, DataValue]: """ Return metadata about the packet's source/origin. @@ -832,13 +628,13 @@ def system_tags(self) -> dict[str, DataValue]: - Processing pipeline information Returns: - dict[str, str | None]: Source information for each data column as key-value pairs. + A dictionary with source information for each data column as key-value pairs. """ ... @runtime_checkable -class Packet(Datagram, Protocol): +class PacketProtocol(DatagramProtocol, Protocol): """ The actual data payload in a stream. @@ -846,223 +642,15 @@ class Packet(Datagram, Protocol): graph. Unlike Tags (which are metadata), Packets contain the actual information that computations operate on. - Packets extend Datagram with additional capabilities for: + Packets extend DatagramProtocol with additional capabilities for: - Source tracking and lineage - Content-based hashing for caching - Metadata inclusion for debugging - The distinction between Tag and Packet is crucial for understanding + The distinction between TagProtocol and PacketProtocol is crucial for understanding data flow: Tags provide context, Packets provide content. """ - def keys( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_source: bool = False, - ) -> tuple[str, ...]: - """ - Return tuple of column names. - - Provides access to column names with filtering options for different - column types. Default returns only data column names. - - Args: - include_all_info: If True, include all available information. This option supersedes all other inclusion options. - include_meta_columns: Controls meta column inclusion. - - False: Return only data column names (default) - - True: Include all meta column names - - Collection[str]: Include meta columns matching these prefixes. If absent, - {orcapod.META_PREFIX} ('__') prefix is prepended to each key. - include_context: Whether to include context column. - include_source: Whether to include source info fields. - - - Returns: - Tuple of column names based on inclusion criteria. - - Example: - >>> datagram.keys() # Data columns only - ('user_id', 'name', 'email') - >>> datagram.keys(include_meta_columns=True) - ('user_id', 'name', 'email', f'{orcapod.META_PREFIX}processed_at', f'{orcapod.META_PREFIX}pipeline_version') - >>> datagram.keys(include_meta_columns=["pipeline"]) - ('user_id', 'name', 'email',f'{orcapod.META_PREFIX}pipeline_version') - >>> datagram.keys(include_context=True) - ('user_id', 'name', 'email', f'{orcapod.CONTEXT_KEY}') - """ - ... - - def types( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_source: bool = False, - ) -> PythonSchema: - """ - Return type specification mapping field names to Python types. - - The TypeSpec enables type checking and validation throughout the system. - - Args: - include_all_info: If True, include all available information. This option supersedes all other inclusion options. - include_meta_columns: Controls meta column type inclusion. - - False: Exclude meta column types (default) - - True: Include all meta column types - - Collection[str]: Include meta column types matching these prefixes. If absent, - {orcapod.META_PREFIX} ('__') prefix is prepended to each key. - include_context: Whether to include context type. - include_source: Whether to include source info fields. - - Returns: - TypeSpec mapping field names to their Python types. - - Example: - >>> datagram.types() - {'user_id': , 'name': } - """ - ... - - def arrow_schema( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_source: bool = False, - ) -> "pa.Schema": - """ - Return PyArrow schema representation. - - The schema provides structured field and type information for efficient - serialization and deserialization with PyArrow. - - Args: - include_all_info: If True, include all available information. This option supersedes all other inclusion options. - include_meta_columns: Controls meta column schema inclusion. - - False: Exclude meta columns (default) - - True: Include all meta columns - - Collection[str]: Include meta columns matching these prefixes. If absent, - {orcapod.META_PREFIX} ('__') prefix is prepended to each key. - include_context: Whether to include context column. - include_source: Whether to include source info fields. - - - Returns: - PyArrow Schema describing the datagram structure. - - Example: - >>> schema = datagram.arrow_schema() - >>> schema.names - ['user_id', 'name'] - """ - ... - - def as_dict( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_source: bool = False, - ) -> dict[str, DataValue]: - """ - Convert datagram to dictionary format. - - Provides a simple key-value representation useful for debugging, - serialization, and interop with dict-based APIs. - - Args: - include_all_info: If True, include all available information. This option supersedes all other inclusion options. - include_meta_columns: Controls meta column inclusion. - - False: Exclude all meta columns (default) - - True: Include all meta columns - - Collection[str]: Include meta columns matching these prefixes. If absent, - {orcapod.META_PREFIX} ('__') prefix is prepended to each key. - include_context: Whether to include the context key. - include_source: Whether to include source info fields. - - - Returns: - Dictionary with requested columns as key-value pairs. - - Example: - >>> data = datagram.as_dict() # {'user_id': 123, 'name': 'Alice'} - >>> full_data = datagram.as_dict( - ... include_meta_columns=True, - ... include_context=True - ... ) - """ - ... - - def as_table( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_context: bool = False, - include_source: bool = False, - ) -> "pa.Table": - """ - Convert datagram to PyArrow Table format. - - Provides a standardized columnar representation suitable for analysis, - processing, and interoperability with Arrow-based tools. - - Args: - include_all_info: If True, include all available information. This option supersedes all other inclusion options. - include_meta_columns: Controls meta column inclusion. - - False: Exclude all meta columns (default) - - True: Include all meta columns - - Collection[str]: Include meta columns matching these prefixes. If absent, - {orcapod.META_PREFIX} ('__') prefix is prepended to each key. - include_context: Whether to include the context column. - include_source: Whether to include source info columns in the schema. - - Returns: - PyArrow Table with requested columns. - - Example: - >>> table = datagram.as_table() # Data columns only - >>> full_table = datagram.as_table( - ... include_meta_columns=True, - ... include_context=True - ... ) - >>> filtered = datagram.as_table(include_meta_columns=["pipeline"]) # same as passing f"{orcapod.META_PREFIX}pipeline" - """ - ... - - # TODO: add this back - # def as_arrow_compatible_dict( - # self, - # include_all_info: bool = False, - # include_meta_columns: bool | Collection[str] = False, - # include_context: bool = False, - # include_source: bool = False, - # ) -> dict[str, Any]: - # """Extended version with source info support.""" - # ... - - def as_datagram( - self, - include_all_info: bool = False, - include_meta_columns: bool | Collection[str] = False, - include_source: bool = False, - ) -> Datagram: - """ - Convert the packet to a Datagram. - - Args: - include_meta_columns: Controls meta column inclusion. - - False: Exclude all meta columns (default) - - True: Include all meta columns - - Collection[str]: Include meta columns matching these prefixes. If absent, - {orcapod.META_PREFIX} ('__') prefix is prepended to each key. - - Returns: - Datagram: Datagram representation of packet data - """ - ... - def source_info(self) -> dict[str, str | None]: """ Return metadata about the packet's source/origin. @@ -1084,7 +672,7 @@ def with_source_info( **source_info: str | None, ) -> Self: """ - Create new packet with updated source information. + Create a new packet with updated source information. Adds or updates source metadata for the packet. This is useful for tracking data provenance and lineage through the computational graph. @@ -1093,7 +681,7 @@ def with_source_info( **source_info: Source metadata as keyword arguments. Returns: - New packet instance with updated source information. + A new packet instance with updated source information. Example: >>> updated_packet = packet.with_source_info( diff --git a/src/orcapod/protocols/core_protocols/execution_engine.py b/src/orcapod/protocols/core_protocols/execution_engine.py new file mode 100644 index 00000000..b7e970cf --- /dev/null +++ b/src/orcapod/protocols/core_protocols/execution_engine.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Any, Protocol, runtime_checkable + + +@runtime_checkable +class ExecutionEngineProtocol(Protocol): + # canonical name for the execution engine -- used to label the execution information when saving + @property + def name(self) -> str: ... + + def submit_sync(self, function: Callable, *args, **kwargs) -> Any: + """ + Run the given function with the provided arguments. + This method should be implemented by the execution engine. + """ + ... + + async def submit_async(self, function: Callable, *args, **kwargs) -> Any: + """ + Asynchronously run the given function with the provided arguments. + This method should be implemented by the execution engine. + """ + ... diff --git a/src/orcapod/protocols/core_protocols/executor.py b/src/orcapod/protocols/core_protocols/executor.py new file mode 100644 index 00000000..b56d4734 --- /dev/null +++ b/src/orcapod/protocols/core_protocols/executor.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Protocol, Self, runtime_checkable + +from orcapod.protocols.core_protocols.datagrams import PacketProtocol + +if TYPE_CHECKING: + from orcapod.pipeline.logging_capture import CapturedLogs + from orcapod.protocols.core_protocols.packet_function import PacketFunctionProtocol + + +@runtime_checkable +class PacketFunctionExecutorProtocol(Protocol): + """Strategy for executing a packet function on a single packet. + + Executors decouple *what* a packet function computes from *where/how* it + runs. Each executor declares which ``packet_function_type_id`` values it + supports. + """ + + @property + def executor_type_id(self) -> str: + """Unique identifier for this executor type, e.g. ``'local'``, ``'ray.v0'``.""" + ... + + def supported_function_type_ids(self) -> frozenset[str]: + """Return the set of ``packet_function_type_id`` values this executor can handle. + + Return an empty frozenset to indicate support for *all* function types. + """ + ... + + def supports(self, packet_function_type_id: str) -> bool: + """Return ``True`` if this executor can run functions of the given type.""" + ... + + def execute( + self, + packet_function: PacketFunctionProtocol, + packet: PacketProtocol, + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + """Synchronously execute *packet_function* on *packet*. + + The executor should invoke ``packet_function.direct_call(packet)`` + in the appropriate execution environment and pass through its + ``(result, CapturedLogs)`` tuple. + """ + ... + + async def async_execute( + self, + packet_function: PacketFunctionProtocol, + packet: PacketProtocol, + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + """Asynchronous counterpart of ``execute``.""" + ... + + @property + def supports_concurrent_execution(self) -> bool: + """Whether this executor can meaningfully run multiple packets concurrently. + + When ``True``, iteration machinery may submit all packets via + ``async_execute`` concurrently and collect results before yielding. + """ + ... + + def with_options(self, **opts: Any) -> Self: + """Return a **new** executor instance configured with the given per-node options. + + Used by the pipeline to produce node-specific executor instances + (e.g. with different CPU/GPU allocations) from a shared base executor. + Implementations must always return a new instance, even when no + options change, so that executors are effectively immutable value + objects after construction. + """ + ... + + def get_execution_data(self) -> dict[str, Any]: + """Return metadata describing the execution environment. + + Stored alongside results for observability/provenance but does not + affect content or pipeline hashes. + """ + ... + + +@runtime_checkable +class PythonFunctionExecutorProtocol(PacketFunctionExecutorProtocol, Protocol): + """Executor protocol for Python callable-based packet functions. + + Extends ``PacketFunctionExecutorProtocol`` with callable-level + execution methods. The executor receives the raw Python function + and its keyword arguments directly — the packet function handles + packet construction/deconstruction around the executor call. + """ + + def execute_callable( + self, + fn: Callable[..., Any], + kwargs: dict[str, Any], + executor_options: dict[str, Any] | None = None, + ) -> "tuple[Any, CapturedLogs]": + """Synchronously execute *fn* with *kwargs*, capturing I/O. + + Args: + fn: The Python callable to execute. + kwargs: Keyword arguments to pass to *fn*. + executor_options: Optional per-call options (e.g. resource + overrides). + + Returns: + A ``(raw_result, CapturedLogs)`` tuple. ``raw_result`` is the + return value of *fn* (or ``None`` on failure). + ``CapturedLogs.success`` is ``False`` when the function raised; + the traceback is stored in ``CapturedLogs.traceback``. + """ + ... + + async def async_execute_callable( + self, + fn: Callable[..., Any], + kwargs: dict[str, Any], + executor_options: dict[str, Any] | None = None, + ) -> "tuple[Any, CapturedLogs]": + """Asynchronously execute *fn* with *kwargs*, capturing I/O. + + Args: + fn: The Python callable to execute. + kwargs: Keyword arguments to pass to *fn*. + executor_options: Optional per-call options. + + Returns: + A ``(raw_result, CapturedLogs)`` tuple. + """ + ... diff --git a/src/orcapod/protocols/core_protocols/function_pod.py b/src/orcapod/protocols/core_protocols/function_pod.py new file mode 100644 index 00000000..b56307b1 --- /dev/null +++ b/src/orcapod/protocols/core_protocols/function_pod.py @@ -0,0 +1,37 @@ +from typing import Any, Protocol, runtime_checkable + +from orcapod.protocols.core_protocols.datagrams import PacketProtocol, TagProtocol +from orcapod.protocols.core_protocols.packet_function import PacketFunctionProtocol +from orcapod.protocols.core_protocols.pod import PodProtocol +from orcapod.protocols.hashing_protocols import PipelineElementProtocol + + +@runtime_checkable +class FunctionPodProtocol(PodProtocol, PipelineElementProtocol, Protocol): + """ + PodProtocol based on PacketFunctionProtocol. + """ + + @property + def packet_function(self) -> PacketFunctionProtocol: + """ + The PacketFunctionProtocol that defines the computation for this FunctionPodProtocol. + """ + ... + + def process_packet( + self, tag: TagProtocol, packet: PacketProtocol + ) -> tuple[TagProtocol, PacketProtocol | None]: ... + + async def async_process_packet( + self, tag: TagProtocol, packet: PacketProtocol + ) -> tuple[TagProtocol, PacketProtocol | None]: ... + + def to_config(self) -> dict[str, Any]: + """Serialize this function pod to a JSON-compatible config dict.""" + ... + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "FunctionPodProtocol": + """Reconstruct a function pod from a config dict.""" + ... diff --git a/src/orcapod/protocols/core_protocols/kernel.py b/src/orcapod/protocols/core_protocols/kernel.py deleted file mode 100644 index 842d7af2..00000000 --- a/src/orcapod/protocols/core_protocols/kernel.py +++ /dev/null @@ -1,201 +0,0 @@ -from collections.abc import Collection -from datetime import datetime -from typing import Any, Protocol, runtime_checkable -from orcapod.protocols.hashing_protocols import ContentIdentifiable -from orcapod.types import PythonSchema -from orcapod.protocols.core_protocols.base import Labelable -from orcapod.protocols.core_protocols.streams import Stream, LiveStream - - -@runtime_checkable -class Kernel(ContentIdentifiable, Labelable, Protocol): - """ - The fundamental unit of computation in Orcapod. - - Kernels are the building blocks of computational graphs, transforming - zero, one, or more input streams into a single output stream. They - encapsulate computation logic while providing consistent interfaces - for validation, type checking, and execution. - - Key design principles: - - Immutable: Kernels don't change after creation - - Deterministic: Same inputs always produce same outputs - - Composable: Kernels can be chained and combined - - Trackable: All invocations are recorded for lineage - - Type-safe: Strong typing and validation throughout - - Execution modes: - - __call__(): Full-featured execution with tracking, returns LiveStream - - forward(): Pure computation without side effects, returns Stream - - The distinction between these modes enables both production use (with - full tracking) and testing/debugging (without side effects). - """ - - @property - def reference(self) -> tuple[str, ...]: - """ - Reference to the kernel - - The reference is used for caching/storage and tracking purposes. - As the name indicates, this is how data originating from the kernel will be referred to. - - - Returns: - tuple[str, ...]: Reference for this kernel - """ - ... - - @property - def data_context_key(self) -> str: - """ - Return the context key for this kernel's data processing. - - The context key is used to interpret how data columns should be - processed and converted. It provides semantic meaning to the data - being processed by this kernel. - - Returns: - str: Context key for this kernel's data processing - """ - ... - - @property - def last_modified(self) -> datetime | None: - """ - When the kernel was last modified. For most kernels, this is the timestamp - of the kernel creation. - """ - ... - - def __call__( - self, *streams: Stream, label: str | None = None, **kwargs - ) -> LiveStream: - """ - Main interface for kernel invocation with full tracking and guarantees. - - This is the primary way to invoke kernels in production. It provides - a complete execution pipeline: - 1. Validates input streams against kernel requirements - 2. Registers the invocation with the computational graph - 3. Calls forward() to perform the actual computation - 4. Ensures the result is a LiveStream that stays current - - The returned LiveStream automatically stays up-to-date with its - upstream dependencies, making it suitable for real-time processing - and reactive applications. - - Args: - *streams: Input streams to process (can be empty for source kernels) - label: Optional label for this invocation (overrides kernel.label) - **kwargs: Additional arguments for kernel configuration - - Returns: - LiveStream: Live stream that stays up-to-date with upstreams - - Raises: - ValidationError: If input streams are invalid for this kernel - TypeMismatchError: If stream types are incompatible - ValueError: If required arguments are missing - """ - ... - - def forward(self, *streams: Stream) -> Stream: - """ - Perform the actual computation without side effects. - - This method contains the core computation logic and should be - overridden by subclasses. It performs pure computation without: - - Registering with the computational graph - - Performing validation (caller's responsibility) - - Guaranteeing result type (may return static or live streams) - - The returned stream must be accurate at the time of invocation but - need not stay up-to-date with upstream changes. This makes forward() - suitable for: - - Testing and debugging - - Batch processing where currency isn't required - - Internal implementation details - - Args: - *streams: Input streams to process - - Returns: - Stream: Result of the computation (may be static or live) - """ - ... - - def output_types( - self, *streams: Stream, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - """ - Determine output types without triggering computation. - - This method performs type inference based on input stream types, - enabling efficient type checking and stream property queries. - It should be fast and not trigger any expensive computation. - - Used for: - - Pre-execution type validation - - Query planning and optimization - - Schema inference in complex pipelines - - IDE support and developer tooling - - Args: - *streams: Input streams to analyze - - Returns: - tuple[TypeSpec, TypeSpec]: (tag_types, packet_types) for output - - Raises: - ValidationError: If input types are incompatible - TypeError: If stream types cannot be processed - """ - ... - - def validate_inputs(self, *streams: Stream) -> None: - """ - Validate input streams, raising exceptions if incompatible. - - This method is called automatically by __call__ before computation - to provide fail-fast behavior. It should check: - - Number of input streams - - Stream types and schemas - - Any kernel-specific requirements - - Business logic constraints - - The goal is to catch errors early, before expensive computation - begins, and provide clear error messages for debugging. - - Args: - *streams: Input streams to validate - - Raises: - ValidationError: If streams are invalid for this kernel - TypeError: If stream types are incompatible - ValueError: If stream content violates business rules - """ - ... - - def identity_structure(self, streams: Collection[Stream] | None = None) -> Any: - """ - Generate a unique identity structure for this kernel and/or kernel invocation. - When invoked without streams, it should return a structure - that uniquely identifies the kernel itself (e.g., class name, parameters). - When invoked with streams, it should include the identity of the streams - to distinguish different invocations of the same kernel. - - This structure is used for: - - Caching and memoization - - Debugging and error reporting - - Tracking kernel invocations in computational graphs - - Args: - streams: Optional input streams for this invocation. If None, identity_structure is - based solely on the kernel. If streams are provided, they are included in the identity - to differentiate between different invocations of the same kernel. - - Returns: - Any: Unique identity structure (e.g., tuple of class name and stream identities) - """ - ... diff --git a/src/orcapod/protocols/core_protocols/labelable.py b/src/orcapod/protocols/core_protocols/labelable.py new file mode 100644 index 00000000..8e3de0d5 --- /dev/null +++ b/src/orcapod/protocols/core_protocols/labelable.py @@ -0,0 +1,47 @@ +from typing import Protocol, runtime_checkable + + +@runtime_checkable +class LabelableProtocol(Protocol): + """ + Protocol for objects that can have a human-readable label. + + Labels provide meaningful names for objects in the computational graph, + aiding in debugging, visualization, and monitoring. They serve as + human-friendly identifiers that complement the technical identifiers + used internally. + + Labels are optional but highly recommended for: + - Debugging complex computational graphs + - Visualization and monitoring tools + - Error messages and logging + - User interfaces and dashboards + + """ + + @property + def label(self) -> str: + """ + Return the human-readable label for this object. + + Labels should be descriptive and help users understand the purpose + or role of the object in the computational graph. + + Returns: + str: Human-readable label for this object + None: No label is set (will use default naming) + """ + ... + + @label.setter + def label(self, label: str | None) -> None: + """ + Set the human-readable label for this object. + + Labels should be descriptive and help users understand the purpose + or role of the object in the computational graph. + + Args: + value (str): Human-readable label for this object + """ + ... diff --git a/src/orcapod/protocols/core_protocols/operator_pod.py b/src/orcapod/protocols/core_protocols/operator_pod.py new file mode 100644 index 00000000..ce514b7c --- /dev/null +++ b/src/orcapod/protocols/core_protocols/operator_pod.py @@ -0,0 +1,22 @@ +from typing import Any, Protocol, runtime_checkable + +from orcapod.protocols.core_protocols.pod import PodProtocol + + +@runtime_checkable +class OperatorPodProtocol(PodProtocol, Protocol): + """ + PodProtocol that performs operations on streams. + + This is a base protocol for pods that perform operations on streams. + TODO: add a method to map out source relationship + """ + + def to_config(self) -> dict[str, Any]: + """Serialize this operator to a JSON-compatible config dict.""" + ... + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "OperatorPodProtocol": + """Reconstruct an operator from a config dict.""" + ... diff --git a/src/orcapod/protocols/core_protocols/packet_function.py b/src/orcapod/protocols/core_protocols/packet_function.py new file mode 100644 index 00000000..c806597a --- /dev/null +++ b/src/orcapod/protocols/core_protocols/packet_function.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +from orcapod.protocols.core_protocols.datagrams import PacketProtocol +from orcapod.protocols.core_protocols.executor import PacketFunctionExecutorProtocol +from orcapod.protocols.core_protocols.labelable import LabelableProtocol +from orcapod.protocols.hashing_protocols import ( + ContentIdentifiableProtocol, + PipelineElementProtocol, +) +from orcapod.types import Schema + +if TYPE_CHECKING: + from orcapod.pipeline.logging_capture import CapturedLogs + + +@runtime_checkable +class PacketFunctionProtocol( + ContentIdentifiableProtocol, PipelineElementProtocol, LabelableProtocol, Protocol +): + """Protocol for a packet-processing function. + + Processes individual packets with declared input/output schemas. + """ + + # ==================== Identity & Metadata ==================== + @property + def packet_function_type_id(self) -> str: + """How functions are defined and executed (e.g., python.function.v2)""" + ... + + @property + def canonical_function_name(self) -> str: + """Human-readable function identifier""" + ... + + @property + def major_version(self) -> int: + """Breaking changes increment this""" + ... + + @property + def minor_version_string(self) -> str: + """Flexible minor version (e.g., "1", "4.3rc", "apple")""" + ... + + @property + def input_packet_schema(self) -> Schema: + """Schema describing the input packets this function accepts.""" + ... + + @property + def output_packet_schema(self) -> Schema: + """Schema describing the output packets this function produces.""" + ... + + # ==================== Content-Addressable Identity ==================== + def get_function_variation_data(self) -> dict[str, Any]: + """Raw data defining function variation - system computes hash""" + ... + + def get_execution_data(self) -> dict[str, Any]: + """Raw data defining execution context - system computes hash""" + ... + + # ==================== Executor ==================== + + @property + def executor(self) -> PacketFunctionExecutorProtocol | None: + """The executor used to run this function, or ``None`` for direct execution.""" + ... + + @executor.setter + def executor(self, executor: PacketFunctionExecutorProtocol | None) -> None: + """Set or clear the executor.""" + ... + + # ==================== Execution ==================== + + def call( + self, + packet: PacketProtocol, + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + """Process a single packet, routing through the executor if one is set. + + Args: + packet: The data payload to process. + + Returns: + A ``(output_packet, captured_logs)`` tuple. ``output_packet`` + is ``None`` when the function filters the packet out or when + the execution failed (check ``captured_logs.success``). + """ + ... + + async def async_call( + self, + packet: PacketProtocol, + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + """Asynchronously process a single packet, routing through the executor if set. + + Args: + packet: The data payload to process. + + Returns: + A ``(output_packet, captured_logs)`` tuple. + """ + ... + + def direct_call( + self, + packet: PacketProtocol, + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + """Execute the function's native computation on *packet*. + + This is the method executors invoke, bypassing executor routing. + On user-function failure the exception is caught internally and + ``(None, captured_with_success=False)`` is returned — no re-raise. + + Args: + packet: The data payload to process. + + Returns: + A ``(output_packet, captured_logs)`` tuple. + """ + ... + + async def direct_async_call( + self, + packet: PacketProtocol, + ) -> "tuple[PacketProtocol | None, CapturedLogs]": + """Asynchronous counterpart of ``direct_call``.""" + ... + + # ==================== Serialization ==================== + + def to_config(self) -> dict[str, Any]: + """Serialize this packet function to a JSON-compatible config dict.""" + ... + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "PacketFunctionProtocol": + """Reconstruct a packet function from a config dict.""" + ... diff --git a/src/orcapod/protocols/core_protocols/pod.py b/src/orcapod/protocols/core_protocols/pod.py new file mode 100644 index 00000000..e4466f79 --- /dev/null +++ b/src/orcapod/protocols/core_protocols/pod.py @@ -0,0 +1,144 @@ +from __future__ import annotations + +from collections.abc import Collection +from typing import Any, Protocol, TypeAlias, runtime_checkable + +from orcapod.protocols.core_protocols.streams import StreamProtocol +from orcapod.protocols.core_protocols.traceable import TraceableProtocol +from orcapod.types import ColumnConfig, Schema + +# Core recursive types +ArgumentGroup: TypeAlias = "SymmetricGroup | OrderedGroup | StreamProtocol" + +SymmetricGroup: TypeAlias = frozenset[ArgumentGroup] # Order-independent +OrderedGroup: TypeAlias = tuple[ArgumentGroup, ...] # Order-dependent + + +@runtime_checkable +class PodProtocol(TraceableProtocol, Protocol): + """ + The fundamental unit of computation in Orcapod. + + Pods are the building blocks of computational graphs, transforming + zero, one, or more input streams into a single output stream. They + encapsulate computation logic while providing consistent interfaces + for validation, type checking, and execution. + + Key design principles: + - Immutable: Pods don't change after creation + - Composable: Pods can be chained and combined + - Type-safe: Strong typing and validation throughout + + + The distinction between these modes enables both production use (with + full tracking) and testing/debugging (without side effects). + """ + + @property + def uri(self) -> tuple[str, ...]: + """ + Unique identifier for the pod + + The URI is used for caching/storage and tracking purposes. + As the name indicates, this is how data originating from the pod will be referred to. + + Returns: + tuple[str, ...]: URI for this pod + """ + ... + + def validate_inputs(self, *streams: StreamProtocol) -> None: + """ + Validate input streams, raising exceptions if invalid. + + Should check: + - Number of input streams + - StreamProtocol types and schemas + - Kernel-specific requirements + - Business logic constraints + + Args: + *streams: Input streams to validate + + Raises: + PodInputValidationError: If inputs are invalid + """ + ... + + def argument_symmetry(self, streams: Collection[StreamProtocol]) -> ArgumentGroup: + """ + Describe symmetry/ordering constraints on input arguments. + + Returns a structure encoding which arguments can be reordered: + - SymmetricGroup (frozenset): Arguments commute (order doesn't matter) + - OrderedGroup (tuple): Arguments have fixed positions + - Nesting expresses partial symmetry + + Examples: + Full symmetry (Join): + return frozenset([a, b, c]) + + No symmetry (Concatenate): + return (a, b, c) + + Partial symmetry: + return (frozenset([a, b]), c) + # a,b are interchangeable, c has fixed position + """ + ... + + def output_schema( + self, + *streams: StreamProtocol, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: + """ + Determine output schemas without triggering computation. + + This method performs type inference based on input stream types, + enabling efficient type checking and stream property queries. + It should be fast and not trigger any expensive computation. + + Used for: + - Pre-execution type validation + - Query planning and optimization + - Schema inference in complex pipelines + - IDE support and developer tooling + + Args: + *streams: Input streams to analyze + + Returns: + tuple[TypeSpec, TypeSpec]: (tag_types, packet_types) for output + + Raises: + ValidationError: If input types are incompatible + TypeError: If stream types cannot be processed + """ + ... + + def process( + self, *streams: StreamProtocol, label: str | None = None + ) -> StreamProtocol: + """ + Executes the computation on zero or more input streams. + This method contains the core computation logic and should be + overridden by subclasses. It performs pure computation without: + - Performing validation (caller's responsibility) + - Guaranteeing result type (may return static or live streams) + + The returned stream must be accurate at the time of invocation but + need not stay up-to-date with upstream changes. This makes forward() + suitable for: + - Testing and debugging + - Batch processing where currency isn't required + - Internal implementation details + + Args: + *streams: Input streams to process + + Returns: + StreamProtocol: Result of the computation (may be static or live) + """ + ... diff --git a/src/orcapod/protocols/core_protocols/pods.py b/src/orcapod/protocols/core_protocols/pods.py deleted file mode 100644 index 616fd793..00000000 --- a/src/orcapod/protocols/core_protocols/pods.py +++ /dev/null @@ -1,232 +0,0 @@ -from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable - -from orcapod.protocols.core_protocols.base import ExecutionEngine -from orcapod.protocols.core_protocols.datagrams import Packet, Tag -from orcapod.protocols.core_protocols.kernel import Kernel -from orcapod.types import PythonSchema - -if TYPE_CHECKING: - import pyarrow as pa - - -@runtime_checkable -class Pod(Kernel, Protocol): - """ - Specialized kernel for packet-level processing with advanced caching. - - Pods represent a different computational model from regular kernels: - - Process data one packet at a time (enabling fine-grained parallelism) - - Support just-in-time evaluation (computation deferred until needed) - - Provide stricter type contracts (clear input/output schemas) - - Enable advanced caching strategies (packet-level caching) - - The Pod abstraction is ideal for: - - Expensive computations that benefit from caching - - Operations that can be parallelized at the packet level - - Transformations with strict type contracts - - Processing that needs to be deferred until access time - - Functions that operate on individual data items - - Pods use a different execution model where computation is deferred - until results are actually needed, enabling efficient resource usage - and fine-grained caching. - """ - - @property - def version(self) -> str: ... - - def get_record_id(self, packet: Packet, execution_engine_hash: str) -> str: ... - - @property - def tiered_pod_id(self) -> dict[str, str]: - """ - Return a dictionary representation of the tiered pod's unique identifier. - The key is supposed to be ordered from least to most specific, allowing - for hierarchical identification of the pod. - - This is primarily used for tiered memoization/caching strategies. - - Returns: - dict[str, str]: Dictionary representation of the pod's ID - """ - ... - - def input_packet_types(self) -> PythonSchema: - """ - TypeSpec for input packets that this Pod can process. - - Defines the exact schema that input packets must conform to. - Pods are typically much stricter about input types than regular - kernels, requiring precise type matching for their packet-level - processing functions. - - This specification is used for: - - Runtime type validation - - Compile-time type checking - - Schema inference and documentation - - Input validation and error reporting - - Returns: - TypeSpec: Dictionary mapping field names to required packet types - """ - ... - - def output_packet_types(self) -> PythonSchema: - """ - TypeSpec for output packets that this Pod produces. - - Defines the schema of packets that will be produced by this Pod. - This is typically determined by the Pod's computational function - and is used for: - - Type checking downstream kernels - - Schema inference in complex pipelines - - Query planning and optimization - - Documentation and developer tooling - - Returns: - TypeSpec: Dictionary mapping field names to output packet types - """ - ... - - async def async_call( - self, - tag: Tag, - packet: Packet, - record_id: str | None = None, - execution_engine: ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> tuple[Tag, Packet | None]: ... - - def call( - self, - tag: Tag, - packet: Packet, - record_id: str | None = None, - execution_engine: ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> tuple[Tag, Packet | None]: - """ - Process a single packet with its associated tag. - - This is the core method that defines the Pod's computational behavior. - It processes one (tag, packet) pair at a time, enabling: - - Fine-grained caching at the packet level - - Parallelization opportunities - - Just-in-time evaluation - - Filtering operations (by returning None) - - The method signature supports: - - Tag transformation (modify metadata) - - Packet transformation (modify content) - - Filtering (return None to exclude packet) - - Pass-through (return inputs unchanged) - - Args: - tag: Metadata associated with the packet - packet: The data payload to process - - Returns: - tuple[Tag, Packet | None]: - - Tag: Output tag (may be modified from input) - - Packet: Processed packet, or None to filter it out - - Raises: - TypeError: If packet doesn't match input_packet_types - ValueError: If packet data is invalid for processing - """ - ... - - -@runtime_checkable -class CachedPod(Pod, Protocol): - async def async_call( - self, - tag: Tag, - packet: Packet, - record_id: str | None = None, - execution_engine: ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - skip_cache_lookup: bool = False, - skip_cache_insert: bool = False, - ) -> tuple[Tag, Packet | None]: ... - - def set_mode(self, mode: str) -> None: ... - - @property - def mode(self) -> str: ... - - # @mode.setter - # def mode(self, value: str) -> None: ... - - def call( - self, - tag: Tag, - packet: Packet, - record_id: str | None = None, - execution_engine: ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - skip_cache_lookup: bool = False, - skip_cache_insert: bool = False, - ) -> tuple[Tag, Packet | None]: - """ - Process a single packet with its associated tag. - - This is the core method that defines the Pod's computational behavior. - It processes one (tag, packet) pair at a time, enabling: - - Fine-grained caching at the packet level - - Parallelization opportunities - - Just-in-time evaluation - - Filtering operations (by returning None) - - The method signature supports: - - Tag transformation (modify metadata) - - Packet transformation (modify content) - - Filtering (return None to exclude packet) - - Pass-through (return inputs unchanged) - - Args: - tag: Metadata associated with the packet - packet: The data payload to process - - Returns: - tuple[Tag, Packet | None]: - - Tag: Output tag (may be modified from input) - - Packet: Processed packet, or None to filter it out - - Raises: - TypeError: If packet doesn't match input_packet_types - ValueError: If packet data is invalid for processing - """ - ... - - def get_cached_output_for_packet(self, input_packet: Packet) -> Packet | None: - """ - Retrieve the cached output packet for a given input packet. - - Args: - input_packet: The input packet to look up in the cache - - Returns: - Packet | None: The cached output packet, or None if not found - """ - ... - - def get_all_cached_outputs( - self, include_system_columns: bool = False - ) -> "pa.Table | None": - """ - Retrieve all packets processed by this Pod. - - This method returns a table containing all packets processed by the Pod, - including metadata and system columns if requested. It is useful for: - - Debugging and analysis - - Auditing and data lineage tracking - - Performance monitoring - - Args: - include_system_columns: Whether to include system columns in the output - - Returns: - pa.Table | None: A table containing all processed records, or None if no records are available - """ - ... diff --git a/src/orcapod/protocols/core_protocols/source.py b/src/orcapod/protocols/core_protocols/source.py deleted file mode 100644 index e94f3367..00000000 --- a/src/orcapod/protocols/core_protocols/source.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Protocol, runtime_checkable - -from orcapod.protocols.core_protocols.kernel import Kernel -from orcapod.protocols.core_protocols.streams import Stream - - -@runtime_checkable -class Source(Kernel, Stream, Protocol): - """ - Entry point for data into the computational graph. - - Sources are special objects that serve dual roles: - - As Kernels: Can be invoked to produce streams - - As Streams: Directly provide data without upstream dependencies - - Sources represent the roots of computational graphs and typically - interface with external data sources. They bridge the gap between - the outside world and the Orcapod computational model. - - Common source types: - - File readers (CSV, JSON, Parquet, etc.) - - Database connections and queries - - API endpoints and web services - - Generated data sources (synthetic data) - - Manual data input and user interfaces - - Message queues and event streams - - Sources have unique properties: - - No upstream dependencies (upstreams is empty) - - Can be both invoked and iterated - - Serve as the starting point for data lineage - - May have their own refresh/update mechanisms - """ diff --git a/src/orcapod/protocols/core_protocols/sources.py b/src/orcapod/protocols/core_protocols/sources.py new file mode 100644 index 00000000..32117c25 --- /dev/null +++ b/src/orcapod/protocols/core_protocols/sources.py @@ -0,0 +1,31 @@ +from typing import Any, Protocol, runtime_checkable + +from orcapod.protocols.core_protocols.streams import StreamProtocol + + +@runtime_checkable +class SourceProtocol(StreamProtocol, Protocol): + """ + Protocol for root sources — streams with no upstream dependencies that + expose provenance identity and optional field resolution. + + A SourceProtocol is a StreamProtocol where: + - ``source`` is always ``None`` (no upstream pod) + - ``upstreams`` is always empty + - ``source_id`` provides a canonical name for registry and provenance + - ``resolve_field`` enables lookup of individual field values by record id + """ + + @property + def source_id(self) -> str: ... + + def resolve_field(self, record_id: str, field_name: str) -> Any: ... + + def to_config(self) -> dict[str, Any]: + """Serialize source configuration to a JSON-compatible dict.""" + ... + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "SourceProtocol": + """Reconstruct a source instance from a config dict.""" + ... diff --git a/src/orcapod/protocols/core_protocols/streams.py b/src/orcapod/protocols/core_protocols/streams.py index 0cd3fb40..e4568f31 100644 --- a/src/orcapod/protocols/core_protocols/streams.py +++ b/src/orcapod/protocols/core_protocols/streams.py @@ -1,25 +1,25 @@ -from collections.abc import Collection, Iterator, Mapping -from datetime import datetime +from collections.abc import AsyncIterator, Collection, Iterator, Mapping from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable -from orcapod.protocols.core_protocols.base import ExecutionEngine, Labelable -from orcapod.protocols.core_protocols.datagrams import Packet, Tag -from orcapod.protocols.hashing_protocols import ContentIdentifiable -from orcapod.types import PythonSchema +from orcapod.protocols.core_protocols.datagrams import PacketProtocol, TagProtocol +from orcapod.protocols.core_protocols.traceable import TraceableProtocol +from orcapod.protocols.hashing_protocols import PipelineElementProtocol +from orcapod.types import ColumnConfig, Schema if TYPE_CHECKING: + import pandas as pd import polars as pl import pyarrow as pa - import pandas as pd - from orcapod.protocols.core_protocols.kernel import Kernel + + from orcapod.protocols.core_protocols.pod import PodProtocol @runtime_checkable -class Stream(ContentIdentifiable, Labelable, Protocol): +class StreamProtocol(TraceableProtocol, PipelineElementProtocol, Protocol): """ Base protocol for all streams in Orcapod. - Streams represent sequences of (Tag, Packet) pairs flowing through the + Streams represent sequences of (TagProtocol, PacketProtocol) pairs flowing through the computational graph. They are the fundamental data structure connecting kernels and carrying both data and metadata. @@ -35,89 +35,43 @@ class Stream(ContentIdentifiable, Labelable, Protocol): - Conversion to common formats (tables, dictionaries) """ - @property - def substream_identities(self) -> tuple[str, ...]: - """ - Unique identifiers for sub-streams within this stream. - - This property provides a way to identify and differentiate - sub-streams that may be part of a larger stream. It is useful - for tracking and managing complex data flows. - - Returns: - tuple[str, ...]: Unique identifiers for each sub-stream - """ - ... + # TODO: add substream system @property - def execution_engine(self) -> ExecutionEngine | None: - """ - The execution engine attached to this stream. By default, the stream - will use this execution engine whenever it needs to perform computation. - None means the stream is not attached to any execution engine and will default - to running natively. + def producer(self) -> "PodProtocol | None": """ - - @execution_engine.setter - def execution_engine(self, engine: ExecutionEngine | None) -> None: - """ - Set the execution engine for this stream. - - This allows the stream to use a specific execution engine for - computation, enabling optimized execution strategies and resource - management. - - Args: - engine: The execution engine to attach to this stream - """ - ... - - def get_substream(self, substream_id: str) -> "Stream": - """ - Retrieve a specific sub-stream by its identifier. - - This method allows access to individual sub-streams within the - main stream, enabling focused operations on specific data segments. - - Args: - substream_id: Unique identifier for the desired sub-stream. - - Returns: - Stream: The requested sub-stream if it exists - """ - ... - - @property - def source(self) -> "Kernel | None": - """ - The kernel that produced this stream. + The pod that produced this stream, if any. This provides lineage information for tracking data flow through the computational graph. Root streams (like file sources) may - have no source kernel. + have no source pod. Returns: - Kernel: The source kernel that created this stream - None: This is a root stream with no source kernel + PodProtocol: The source pod that created this stream + None: This is a root stream with no source pod """ ... @property - def upstreams(self) -> tuple["Stream", ...]: + def upstreams(self) -> tuple["StreamProtocol", ...]: """ Input streams used to produce this stream. These are the streams that were provided as input to the source - kernel when this stream was created. Used for dependency tracking - and cache invalidation. + pod when this stream was created. Used for dependency tracking + and cache invalidation. Note that `source` must be checked for + upstreams to be meaningfully inspected. Returns: - tuple[Stream, ...]: Upstream dependency streams (empty for sources) + tuple[StreamProtocol, ...]: Upstream dependency streams (empty for sources) """ ... def keys( - self, include_system_tags: bool = False + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, ) -> tuple[tuple[str, ...], tuple[str, ...]]: """ Available keys/fields in the stream content. @@ -134,38 +88,12 @@ def keys( """ ... - def tag_keys(self, include_system_tags: bool = False) -> tuple[str, ...]: - """ - Return the keys used for the tag in the pipeline run records. - This is used to store the run-associated tag info. - """ - ... - - def packet_keys(self) -> tuple[str, ...]: - """ - Return the keys used for the packet in the pipeline run records. - This is used to store the run-associated packet info. - """ - ... - - def types( - self, include_system_tags: bool = False - ) -> tuple[PythonSchema, PythonSchema]: - """ - Type specifications for the stream content. - - Returns the type schema for both tags and packets in this stream. - This information is used for: - - Type checking and validation - - Schema inference and planning - - Compatibility checking between kernels - - Returns: - tuple[TypeSpec, TypeSpec]: (tag_types, packet_types) - """ - ... - - def tag_types(self, include_system_tags: bool = False) -> PythonSchema: + def output_schema( + self, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> tuple[Schema, Schema]: """ Type specifications for the stream content. @@ -176,136 +104,64 @@ def tag_types(self, include_system_tags: bool = False) -> PythonSchema: - Compatibility checking between kernels Returns: - tuple[TypeSpec, TypeSpec]: (tag_types, packet_types) + tuple[Schema, Schema]: (tag_types, packet_types) """ ... - def packet_types(self) -> PythonSchema: ... - - @property - def last_modified(self) -> datetime | None: + def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: """ - When the stream's content was last modified. - - This property is crucial for caching decisions and dependency tracking: - - datetime: Content was last modified at this time (cacheable) - - None: Content is never stable, always recompute (some dynamic streams) - - Both static and live streams typically return datetime values, but - live streams update this timestamp whenever their content changes. - - Returns: - datetime: Timestamp of last modification for most streams - None: Stream content is never stable (some special dynamic streams) - """ - ... + Generates explicit iterator over (tag, packet) pairs in the stream. - @property - def is_current(self) -> bool: - """ - Whether the stream is up-to-date with its dependencies. - - A stream is current if its content reflects the latest state of its - source kernel and upstream streams. This is used for cache validation - and determining when refresh is needed. - - For live streams, this should always return True since they stay - current automatically. For static streams, this indicates whether - the cached content is still valid. - - Returns: - bool: True if stream is up-to-date, False if refresh needed - """ - ... - - def __iter__(self) -> Iterator[tuple[Tag, Packet]]: - """ - Iterate over (tag, packet) pairs in the stream. - - This is the primary way to access stream data. The behavior depends - on the stream type: - - Static streams: Return cached/precomputed data - - Live streams: May trigger computation and always reflect current state + Note that multiple invocation of `iter_packets` may not always + return an identical iterator. Yields: - tuple[Tag, Packet]: Sequential (tag, packet) pairs + tuple[TagProtocol, PacketProtocol]: Sequential (tag, packet) pairs """ ... - def iter_packets( - self, - execution_engine: ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> Iterator[tuple[Tag, Packet]]: + def async_iter_packets(self) -> AsyncIterator[tuple[TagProtocol, PacketProtocol]]: """ - Alias for __iter__ for explicit packet iteration. - - Provides a more explicit method name when the intent is to iterate - over packets specifically, improving code readability. + Generates asynchronous iterator over (tag, packet) pairs in the stream. - This method must return an immutable iterator -- that is, the returned iterator - should not change and must consistently return identical tag,packet pairs across - multiple iterations of the iterator. - - Note that this is NOT to mean that multiple invocation of `iter_packets` must always - return an identical iterator. The iterator returned by `iter_packets` may change - between invocations, but the iterator itself must not change. Consequently, it should be understood - that the returned iterators may be a burden on memory if the stream is large or infinite. + Note that multiple invocation of `async_iter_packets` may not always + return an identical iterator. Yields: - tuple[Tag, Packet]: Sequential (tag, packet) pairs - """ - ... - - def run( - self, - *args: Any, - execution_engine: ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - **kwargs: Any, - ) -> None: - """ - Execute the stream using the provided execution engine. - - This method triggers computation of the stream content based on its - source kernel and upstream streams. It returns a new stream instance - containing the computed (tag, packet) pairs. - - Args: - execution_engine: The execution engine to use for computation + tuple[tagProtocol, PacketProtcol]: Asynchrnous sequential (tag, packet) pairs """ ... - async def run_async( + def as_table( self, - *args: Any, - execution_engine: ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - **kwargs: Any, - ) -> None: + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> "pa.Table": """ - Asynchronously execute the stream using the provided execution engine. + Convert the entire stream to a PyArrow Table. - This method triggers computation of the stream content based on its - source kernel and upstream streams. It returns a new stream instance - containing the computed (tag, packet) pairs. + Materializes all (tag, packet) pairs into a single table for + analysis and processing. This operation may be expensive for + large streams or live streams that need computation. - Args: - execution_engine: The execution engine to use for computation + If include_content_hash is True, an additional column called "_content_hash" + containing the content hash of each packet is included. If include_content_hash + is a string, it is used as the name of the content hash column. + Returns: + pa.Table: Complete stream data as a PyArrow Table """ ... + +class StreamWithOperationsProtocol(StreamProtocol, Protocol): def as_df( self, - include_data_context: bool = False, - include_source: bool = False, - include_system_tags: bool = False, - include_content_hash: bool | str = False, - sort_by_tags: bool = True, - execution_engine: ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, ) -> "pl.DataFrame": """ Convert the entire stream to a Polars DataFrame. @@ -314,13 +170,9 @@ def as_df( def as_lazy_frame( self, - include_data_context: bool = False, - include_source: bool = False, - include_system_tags: bool = False, - include_content_hash: bool | str = False, - sort_by_tags: bool = True, - execution_engine: ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, ) -> "pl.LazyFrame": """ Load the entire stream to a Polars LazyFrame. @@ -329,58 +181,29 @@ def as_lazy_frame( def as_polars_df( self, - include_data_context: bool = False, - include_source: bool = False, - include_system_tags: bool = False, - include_content_hash: bool | str = False, - sort_by_tags: bool = True, - execution_engine: ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> "pl.DataFrame": ... + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> "pl.DataFrame": + """ + Convert the entire stream to a Polars DataFrame. + """ + ... def as_pandas_df( self, - include_data_context: bool = False, - include_source: bool = False, - include_system_tags: bool = False, - include_content_hash: bool | str = False, - sort_by_tags: bool = True, - index_by_tags: bool = True, - execution_engine: ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> "pd.DataFrame": ... - - def as_table( - self, - include_data_context: bool = False, - include_source: bool = False, - include_system_tags: bool = False, - include_content_hash: bool | str = False, - sort_by_tags: bool = True, - execution_engine: ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> "pa.Table": + *, + columns: ColumnConfig | dict[str, Any] | None = None, + all_info: bool = False, + ) -> "pd.DataFrame": """ - Convert the entire stream to a PyArrow Table. - - Materializes all (tag, packet) pairs into a single table for - analysis and processing. This operation may be expensive for - large streams or live streams that need computation. - - If include_content_hash is True, an additional column called "_content_hash" - containing the content hash of each packet is included. If include_content_hash - is a string, it is used as the name of the content hash column. - - Returns: - pa.Table: Complete stream data as a PyArrow Table + Convert the entire stream to a Pandas DataFrame. """ ... def flow( self, - execution_engine: ExecutionEngine | None = None, - execution_engine_opts: dict[str, Any] | None = None, - ) -> Collection[tuple[Tag, Packet]]: + ) -> Collection[tuple[TagProtocol, PacketProtocol]]: """ Return the entire stream as a collection of (tag, packet) pairs. @@ -394,7 +217,9 @@ def flow( """ ... - def join(self, other_stream: "Stream", label: str | None = None) -> "Stream": + def join( + self, other_stream: "StreamProtocol", label: str | None = None + ) -> "StreamProtocol": """ Join this stream with another stream. @@ -410,7 +235,9 @@ def join(self, other_stream: "Stream", label: str | None = None) -> "Stream": """ ... - def semi_join(self, other_stream: "Stream", label: str | None = None) -> "Stream": + def semi_join( + self, other_stream: "StreamProtocol", label: str | None = None + ) -> "StreamProtocol": """ Perform a semi-join with another stream. @@ -431,7 +258,7 @@ def map_tags( name_map: Mapping[str, str], drop_unmapped: bool = True, label: str | None = None, - ) -> "Stream": + ) -> "StreamProtocol": """ Map tag names in this stream to new names based on the provided mapping. """ @@ -442,7 +269,7 @@ def map_packets( name_map: Mapping[str, str], drop_unmapped: bool = True, label: str | None = None, - ) -> "Stream": + ) -> "StreamProtocol": """ Map packet names in this stream to new names based on the provided mapping. """ @@ -454,14 +281,14 @@ def polars_filter( constraint_map: Mapping[str, Any] | None = None, label: str | None = None, **constraints: Any, - ) -> "Stream": ... + ) -> "StreamProtocol": ... def select_tag_columns( self, tag_columns: str | Collection[str], strict: bool = True, label: str | None = None, - ) -> "Stream": + ) -> "StreamProtocol": """ Select the specified tag columns from the stream. A ValueError is raised if one or more specified tag columns do not exist in the stream unless strict = False. @@ -473,7 +300,7 @@ def select_packet_columns( packet_columns: str | Collection[str], strict: bool = True, label: str | None = None, - ) -> "Stream": + ) -> "StreamProtocol": """ Select the specified tag columns from the stream. A ValueError is raised if one or more specified tag columns do not exist in the stream unless strict = False. @@ -485,7 +312,7 @@ def drop_tag_columns( tag_columns: str | Collection[str], strict: bool = True, label: str | None = None, - ) -> "Stream": + ) -> "StreamProtocol": """ Drop the specified tag columns from the stream. A ValueError is raised if one or more specified tag columns do not exist in the stream unless strict = False. @@ -498,7 +325,7 @@ def drop_packet_columns( packet_columns: str | Collection[str], strict: bool = True, label: str | None = None, - ) -> "Stream": + ) -> "StreamProtocol": """ Drop the specified packet columns from the stream. A ValueError is raised if one or more specified packet columns do not exist in the stream unless strict = False. @@ -510,7 +337,7 @@ def batch( batch_size: int = 0, drop_partial_batch: bool = False, label: str | None = None, - ) -> "Stream": + ) -> "StreamProtocol": """ Batch the stream into groups of the specified size. @@ -528,76 +355,3 @@ def batch( Self: New stream containing batched (tag, packet) pairs. """ ... - - -@runtime_checkable -class LiveStream(Stream, Protocol): - """ - A stream that automatically stays up-to-date with its upstream dependencies. - - LiveStream extends the base Stream protocol with capabilities for "up-to-date" - data flow and reactive computation. Unlike static streams which represent - snapshots, LiveStreams provide the guarantee that their content always - reflects the current state of their dependencies. - - Key characteristics: - - Automatically refresh the stream if changes in the upstreams are detected - - Track last_modified timestamp when content changes - - Support manual refresh triggering and invalidation - - By design, LiveStream would return True for is_current except when auto-update fails. - - LiveStreams are always returned by Kernel.__call__() methods, ensuring - that normal kernel usage produces live, up-to-date results. - - Caching behavior: - - last_modified updates whenever content changes - - Can be cached based on dependency timestamps - - Invalidation happens automatically when upstreams change - - Use cases: - - Real-time data processing pipelines - - Reactive user interfaces - - Monitoring and alerting systems - - Dynamic dashboard updates - - Any scenario requiring current data - """ - - def refresh(self, force: bool = False) -> bool: - """ - Manually trigger a refresh of this stream's content. - - Forces the stream to check its upstream dependencies and update - its content if necessary. This is useful when: - - You want to ensure the latest data before a critical operation - - You need to force computation at a specific time - - You're debugging data flow issues - - You want to pre-compute results for performance - Args: - force: If True, always refresh even if the stream is current. - If False, only refresh if the stream is not current. - - Returns: - bool: True if the stream was refreshed, False if it was already current. - Note: LiveStream refreshes automatically on access, so this - method may be a no-op for some implementations. However, it's - always safe to call if you need to control when the cache is refreshed. - """ - ... - - def invalidate(self) -> None: - """ - Mark this stream as invalid, forcing a refresh on next access. - - This method is typically called when: - - Upstream dependencies have changed - - The source kernel has been modified - - External data sources have been updated - - Manual cache invalidation is needed - - The stream will automatically refresh its content the next time - it's accessed (via iteration, as_table(), etc.). - - This is more efficient than immediate refresh when you know the - data will be accessed later. - """ - ... diff --git a/src/orcapod/protocols/core_protocols/temporal.py b/src/orcapod/protocols/core_protocols/temporal.py new file mode 100644 index 00000000..a0697ac8 --- /dev/null +++ b/src/orcapod/protocols/core_protocols/temporal.py @@ -0,0 +1,25 @@ +from datetime import datetime +from typing import Protocol, runtime_checkable + + +@runtime_checkable +class TemporalProtocol(Protocol): + """ + Protocol for objects that track temporal state. + + Objects implementing TemporalProtocol carries a computed property to + report when their content was last modified, enabling time-sensitive + actions such as cache invalidation, incremental processing, and + dependency staleness tracking. + """ + + @property + def last_modified(self) -> datetime | None: + """ + When this object's content was last modified. + + Returns: + datetime: Content last modified timestamp (timezone-aware) + None: Modification time unknown (assume always changed) + """ + ... diff --git a/src/orcapod/protocols/core_protocols/traceable.py b/src/orcapod/protocols/core_protocols/traceable.py new file mode 100644 index 00000000..589ce58e --- /dev/null +++ b/src/orcapod/protocols/core_protocols/traceable.py @@ -0,0 +1,22 @@ +from typing import Protocol + +from orcapod.protocols.core_protocols.labelable import LabelableProtocol +from orcapod.protocols.core_protocols.temporal import TemporalProtocol +from orcapod.protocols.hashing_protocols import ( + ContentIdentifiableProtocol, + DataContextAwareProtocol, +) + + +class TraceableProtocol( + DataContextAwareProtocol, + ContentIdentifiableProtocol, + LabelableProtocol, + TemporalProtocol, + Protocol, +): + """ + Base protocol for objects that can be traced. + """ + + pass diff --git a/src/orcapod/protocols/core_protocols/trackers.py b/src/orcapod/protocols/core_protocols/trackers.py index 7bc9a1e3..787fbfd2 100644 --- a/src/orcapod/protocols/core_protocols/trackers.py +++ b/src/orcapod/protocols/core_protocols/trackers.py @@ -1,13 +1,13 @@ -from typing import Protocol, runtime_checkable from contextlib import AbstractContextManager -from orcapod.protocols.core_protocols.kernel import Kernel -from orcapod.protocols.core_protocols.pods import Pod -from orcapod.protocols.core_protocols.source import Source -from orcapod.protocols.core_protocols.streams import Stream +from typing import Protocol, runtime_checkable + +from orcapod.protocols.core_protocols.function_pod import FunctionPodProtocol +from orcapod.protocols.core_protocols.operator_pod import OperatorPodProtocol +from orcapod.protocols.core_protocols.streams import StreamProtocol @runtime_checkable -class Tracker(Protocol): +class TrackerProtocol(Protocol): """ Records kernel invocations and stream creation for computational graph tracking. @@ -49,68 +49,57 @@ def is_active(self) -> bool: """ ... - def record_kernel_invocation( - self, kernel: Kernel, upstreams: tuple[Stream, ...], label: str | None = None + def record_operator_pod_invocation( + self, + pod: OperatorPodProtocol, + upstreams: tuple[StreamProtocol, ...] = (), + label: str | None = None, ) -> None: """ - Record a kernel invocation in the computational graph. + Record an operator pod invocation in the computational graph. - This method is called whenever a kernel is invoked. The tracker + This method is called whenever a pod is invoked. The tracker should record: - - The kernel and its properties - - The input streams that were used as input + - The pod and its properties + - The input streams that were used as input. If no streams are provided, the pod is considered a source pod. - Timing and performance information - Any relevant metadata Args: - kernel: The kernel that was invoked + pod: The pod that was invoked upstreams: The input streams used for this invocation """ ... - def record_source_invocation( - self, source: Source, label: str | None = None - ) -> None: - """ - Record a source invocation in the computational graph. - - This method is called whenever a source is invoked. The tracker - should record: - - The source and its properties - - Timing and performance information - - Any relevant metadata - - Args: - source: The source that was invoked - """ - ... - - def record_pod_invocation( - self, pod: Pod, upstreams: tuple[Stream, ...], label: str | None = None + def record_function_pod_invocation( + self, + pod: FunctionPodProtocol, + input_stream: StreamProtocol, + label: str | None = None, ) -> None: """ - Record a pod invocation in the computational graph. + Record a function pod invocation in the computational graph. - This method is called whenever a pod is invoked. The tracker + This method is called whenever a function pod is invoked. The tracker should record: - - The pod and its properties - - The upstream streams that were used as input + - The function pod and its properties + - The input stream that was used as input. If no streams are provided, the pod is considered a source pod. - Timing and performance information - Any relevant metadata Args: - pod: The pod that was invoked - upstreams: The input streams used for this invocation + pod: The function pod that was invoked + input_stream: The input stream used for this invocation """ ... @runtime_checkable -class TrackerManager(Protocol): +class TrackerManagerProtocol(Protocol): """ Manages multiple trackers and coordinates their activity. - The TrackerManager provides a centralized way to: + The TrackerManagerProtocol provides a centralized way to: - Register and manage multiple trackers - Coordinate recording across all active trackers - Provide a single interface for graph recording @@ -123,7 +112,7 @@ class TrackerManager(Protocol): - Performance optimization (selective tracking) """ - def get_active_trackers(self) -> list[Tracker]: + def get_active_trackers(self) -> list[TrackerProtocol]: """ Get all currently active trackers. @@ -131,11 +120,11 @@ def get_active_trackers(self) -> list[Tracker]: providing the list of trackers that will receive recording events. Returns: - list[Tracker]: List of trackers that are currently recording + list[TrackerProtocol]: List of trackers that are currently recording """ ... - def register_tracker(self, tracker: Tracker) -> None: + def register_tracker(self, tracker: TrackerProtocol) -> None: """ Register a new tracker in the system. @@ -148,7 +137,7 @@ def register_tracker(self, tracker: Tracker) -> None: """ ... - def deregister_tracker(self, tracker: Tracker) -> None: + def deregister_tracker(self, tracker: TrackerProtocol) -> None: """ Remove a tracker from the system. @@ -163,50 +152,43 @@ def deregister_tracker(self, tracker: Tracker) -> None: """ ... - def record_kernel_invocation( - self, kernel: Kernel, upstreams: tuple[Stream, ...], label: str | None = None + def record_operator_pod_invocation( + self, + pod: OperatorPodProtocol, + upstreams: tuple[StreamProtocol, ...] = (), + label: str | None = None, ) -> None: """ - Record a stream in all active trackers. + Record operator pod invocation in all active trackers. - This method broadcasts the stream recording to all currently + This method broadcasts the operator pod invocation recording to all currently active and registered trackers. It provides a single point of entry for recording events, simplifying kernel implementations. Args: - stream: The stream to record in all active trackers - """ - ... - - def record_source_invocation( - self, source: Source, label: str | None = None - ) -> None: - """ - Record a source invocation in the computational graph. - - This method is called whenever a source is invoked. The tracker - should record: - - The source and its properties - - Timing and performance information - - Any relevant metadata - - Args: - source: The source that was invoked + pod: The operator pod to record in all active trackers + upstreams: The upstream streams to record in all active trackers + label: The label to associate with the recording """ ... - def record_pod_invocation( - self, pod: Pod, upstreams: tuple[Stream, ...], label: str | None = None + def record_function_pod_invocation( + self, + pod: FunctionPodProtocol, + input_stream: StreamProtocol, + label: str | None = None, ) -> None: """ - Record a stream in all active trackers. + Record a function pod invocation in all active trackers. - This method broadcasts the stream recording to all currently` + This method broadcasts the function pod invocation recording to all currently active and registered trackers. It provides a single point of entry for recording events, simplifying kernel implementations. Args: - stream: The stream to record in all active trackers + pod: The function pod to record in all active trackers + input_stream: The input stream to record in all active trackers + label: The label to associate with the recording """ ... diff --git a/src/orcapod/protocols/database_protocols.py b/src/orcapod/protocols/database_protocols.py index 1bf9eac8..9af76082 100644 --- a/src/orcapod/protocols/database_protocols.py +++ b/src/orcapod/protocols/database_protocols.py @@ -1,11 +1,12 @@ -from typing import Any, Protocol, TYPE_CHECKING +from typing import Any, Protocol, TYPE_CHECKING, runtime_checkable from collections.abc import Collection, Mapping if TYPE_CHECKING: import pyarrow as pa -class ArrowDatabase(Protocol): +@runtime_checkable +class ArrowDatabaseProtocol(Protocol): def add_record( self, record_path: tuple[str, ...], @@ -60,8 +61,21 @@ def flush(self) -> None: """Flush any buffered writes to the underlying storage.""" ... + def to_config(self) -> "dict[str, Any]": + """Serialize database configuration to a JSON-compatible dict. -class MetadataCapable(Protocol): + The returned dict must include a ``"type"`` key identifying the + database implementation (e.g., ``"delta_table"``, ``"in_memory"``). + """ + ... + + @classmethod + def from_config(cls, config: "dict[str, Any]") -> "ArrowDatabaseProtocol": + """Reconstruct a database instance from a config dict.""" + ... + + +class MetadataCapableProtocol(Protocol): def set_metadata( self, record_path: tuple[str, ...], @@ -82,7 +96,9 @@ def validate_metadata( ) -> Collection[str]: ... -class ArrowDatabaseWithMetadata(ArrowDatabase, MetadataCapable, Protocol): - """A protocol that combines ArrowDatabase with metadata capabilities.""" +class ArrowDatabaseWithMetadataProtocol( + ArrowDatabaseProtocol, MetadataCapableProtocol, Protocol +): + """A protocol that combines ArrowDatabaseProtocol with metadata capabilities.""" pass diff --git a/src/orcapod/protocols/hashing_protocols.py b/src/orcapod/protocols/hashing_protocols.py index 10719af7..264c4f1f 100644 --- a/src/orcapod/protocols/hashing_protocols.py +++ b/src/orcapod/protocols/hashing_protocols.py @@ -1,172 +1,239 @@ """Hash strategy protocols for dependency injection.""" -import uuid +from __future__ import annotations + from collections.abc import Callable -from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable -from orcapod.types import PathLike, PythonSchema +from orcapod.types import ContentHash, PathLike, Schema if TYPE_CHECKING: import pyarrow as pa -@dataclass(frozen=True, slots=True) -class ContentHash: - method: str - digest: bytes +@runtime_checkable +class DataContextAwareProtocol(Protocol): + """Protocol for objects aware of their data context.""" - # TODO: make the default char count configurable - def to_hex(self, char_count: int | None = 20) -> str: - """Convert digest to hex string, optionally truncated.""" - hex_str = self.digest.hex() - return hex_str[:char_count] if char_count else hex_str + @property + def data_context_key(self) -> str: + """ + Return the data context key associated with this object. - def to_int(self, hexdigits: int = 20) -> int: + Returns: + str: The data context key """ - Convert digest to integer representation. + ... - Args: - hexdigits: Number of hex digits to use (truncates if needed) - Returns: - Integer representation of the hash +@runtime_checkable +class PipelineElementProtocol(Protocol): + """ + Protocol for objects that have a stable identity as an element in a + pipeline graph — determined by schema and upstream topology, not by + data content. + + This is a parallel identity chain to ContentIdentifiableProtocol. + Where content identity captures the precise, data-inclusive identity of + an object, pipeline identity captures only what is structurally meaningful + for pipeline database path scoping: the schemas and the recursive topology + of the upstream computation. + + The base case (RootSource) returns a hash of (tag_schema, packet_schema). + Every other element recurses through the pipeline_hash() of its upstream + inputs, with the hash values themselves (ContentHash objects) used as + terminal leaves so no special hasher mode is required. + + Two sources with identical schemas processed through the same function pod + graph will produce the same pipeline_hash() at every downstream node, + enabling automatic multi-source table sharing in the pipeline database. + """ + + def pipeline_identity_structure(self) -> Any: + """ + Return a structure representing this element's pipeline identity. + + At source nodes (base case): return (tag_schema, packet_schema). + At all other nodes: return a structure containing references to + upstream pipeline elements and/or packet functions as raw objects. + The pipeline resolver threaded through pipeline_hash() ensures that + PipelineElementProtocol objects are resolved via pipeline_hash() and + other ContentIdentifiable objects via content_hash(), both using the + same hasher throughout the computation. """ - hex_str = self.to_hex()[:hexdigits] - return int(hex_str, 16) + ... - def to_uuid(self, namespace: uuid.UUID = uuid.NAMESPACE_OID) -> uuid.UUID: + def pipeline_hash(self, hasher=None) -> ContentHash: """ - Convert digest to UUID format. + Return the pipeline-level hash of this element, computed from + pipeline_identity_structure() and cached by hasher_id. Args: - namespace: UUID namespace for uuid5 generation - - Returns: - UUID derived from this hash + hasher: Optional semantic hasher to use. When omitted, resolved + from the element's data_context. """ - # Using uuid5 with the hex string ensures deterministic UUIDs - return uuid.uuid5(namespace, self.to_hex()) - - def to_base64(self) -> str: - """Convert digest to base64 string.""" - import base64 - - return base64.b64encode(self.digest).decode("ascii") - - def to_string(self, prefix_method: bool = True) -> str: - """Convert digest to a string representation.""" - if prefix_method: - return f"{self.method}:{self.to_hex()}" - return self.to_hex() - - def __str__(self) -> str: - return self.to_string() - - @classmethod - def from_string(cls, hash_string: str) -> "ContentHash": - """Parse 'method:hex_digest' format.""" - method, hex_digest = hash_string.split(":", 1) - return cls(method, bytes.fromhex(hex_digest)) - - def display_name(self, length: int = 8) -> str: - """Return human-friendly display like 'arrow_v2.1:1a2b3c4d'.""" - return f"{self.method}:{self.to_hex(length)}" + ... @runtime_checkable -class ContentIdentifiable(Protocol): - """Protocol for objects that can provide an identity structure.""" +class ContentIdentifiableProtocol(Protocol): + """ + Protocol for objects that can express their semantic identity as a plain + Python structure. + + This is the only method a class needs to implement to participate in the + content-based hashing system. The returned structure is recursively + resolved by the SemanticHasherProtocol -- any nested ContentIdentifiableProtocol objects + within the structure will themselves be expanded and hashed, producing a + Merkle-tree-like composition of hashes. + + The method should return a deterministic structure whose value depends + only on the semantic content of the object -- not on memory addresses, + object IDs, or other incidental runtime state. + """ def identity_structure(self) -> Any: """ - Return a structure that represents the identity of this object. + Return a structure that represents the semantic identity of this object. + + The returned value may be any Python object: + - Primitives (str, int, float, bool, None) are used as-is. + - Collections (list, dict, set, tuple) are recursively traversed. + - Nested ContentIdentifiableProtocol objects are recursively resolved by + the SemanticHasherProtocol: their identity structure is hashed to a + ContentHash hex token, which is then embedded in place of the + object in the parent structure. + - Any type that has a registered TypeHandlerProtocol in the + SemanticHasherProtocol's registry is handled by that handler. Returns: - Any: A structure representing this object's content. + Any: A structure representing this object's semantic content. Should be deterministic and include all identity-relevant data. - Return None to indicate no custom identity is available. """ ... - def content_hash(self) -> ContentHash: + def content_hash(self, hasher: SemanticHasherProtocol | None = None) -> ContentHash: """ - Compute a hash based on the content of this object. + Returns the content hash. - Returns: - bytes: A byte representation of the hash based on the content. - If no identity structure is provided, return None. + Args: + hasher: Optional semantic hasher to use for the entire recursive + computation. When omitted, resolved from the object's + data_context (or injected hasher for mixin-based objects). + The same hasher propagates to all nested ContentIdentifiable + objects, ensuring one consistent context per computation. """ ... - def __eq__(self, other: object) -> bool: + +class TypeHandlerProtocol(Protocol): + """ + Protocol for type-specific serialization handlers used by SemanticHasherProtocol. + + A TypeHandlerProtocol converts a specific Python type into a value that + ``hash_object`` can process. Handlers are registered with a + TypeHandlerRegistry and looked up via MRO-aware resolution. + + The returned value is passed directly back to ``hash_object``, so it may + be anything that ``hash_object`` understands: + + - A primitive (None, bool, int, float, str) -- hashed directly. + - A structure (list, tuple, dict, set, frozenset) -- expanded and hashed. + - A ContentHash -- treated as a terminal; returned as-is without + re-hashing. Use this when the handler has already computed the + definitive hash of the object (e.g. hashing a file's content). + - A ContentIdentifiableProtocol -- its identity_structure() will be called. + - Another registered type -- dispatched through the registry. + """ + + def handle(self, obj: Any, hasher: "SemanticHasherProtocol") -> Any: """ - Equality check that compares the identity structures of two objects. + Convert *obj* into a value that ``hash_object`` can process. Args: - other (object): The object to compare with. + obj: The object to handle. + hasher: The SemanticHasherProtocol, available if the handler needs to + hash sub-objects explicitly via ``hasher.hash_object()``. Returns: - bool: True if the identity structures are equal, False otherwise. + Any value accepted by ``hash_object``: a primitive, structure, + ContentHash, ContentIdentifiableProtocol, or another registered type. """ ... - def __hash__(self) -> int: - """ - Hash implementation that uses the identity structure if provided, - otherwise falls back to the default hash. - Returns: - int: A hash value based on either content or identity. - """ - ... +class SemanticHasherProtocol(Protocol): + """ + Protocol for the semantic content-based hasher. + + ``hash_object(obj)`` is the single recursive entry point. It produces a + ContentHash for any Python object using the following dispatch: + + - ContentHash → terminal; returned as-is + - Primitive → JSON-serialised and hashed directly + - Structure → structurally expanded (type-tagged), then hashed + - Handler match → handler.handle() returns a new value; recurse + - ContentIdentifiableProtocol→ identity_structure() returns a value; recurse + - Unknown → TypeError (strict) or best-effort string (lenient) + Containers are type-tagged before hashing so that list, tuple, dict, set, + and namedtuple produce distinct hashes even when their elements are equal. -class ObjectHasher(Protocol): - """Protocol for general object hashing.""" + Unknown types raise TypeError by default (strict mode). Set + strict=False on construction to fall back to a best-effort string + representation with a warning instead. + """ - # TODO: consider more explicitly stating types of objects accepted - def hash_object(self, obj: Any) -> ContentHash: + def hash_object( + self, + obj: Any, + resolver: Callable[[Any], ContentHash] | None = None, + ) -> ContentHash: """ - Hash an object to a byte representation. Object hasher must be - able to handle ContentIdentifiable objects to hash them based on their - identity structure. If compressed=True, the content identifiable object - is immediately replaced with its compressed string identity and used in the - computation of containing identity structure. + Hash *obj* based on its semantic content. Args: - obj (Any): The object to hash. + obj: The object to hash. + resolver: Optional callable invoked for any ContentIdentifiable + object encountered during hashing. When provided it overrides + the default obj.content_hash() call, allowing the caller to + control which identity chain is used and to propagate a + consistent hasher through the full recursive computation. Returns: - bytes: The byte representation of the hash. + ContentHash: Stable, content-based hash of the object. """ ... @property def hasher_id(self) -> str: """ - Returns a unique identifier/name assigned to the hasher + Returns a unique identifier/name for this hasher instance. + + The hasher_id is embedded in every ContentHash produced by this + hasher, allowing hashes from different versions or configurations + to be distinguished. """ ... -class FileContentHasher(Protocol): +class FileContentHasherProtocol(Protocol): """Protocol for file-related hashing.""" def hash_file(self, file_path: PathLike) -> ContentHash: ... -class ArrowHasher(Protocol): +@runtime_checkable +class ArrowHasherProtocol(Protocol): """Protocol for hashing arrow packets.""" - def get_hasher_id(self) -> str: ... + @property + def hasher_id(self) -> str: ... - def hash_table( - self, table: "pa.Table | pa.RecordBatch", prefix_hasher_id: bool = True - ) -> ContentHash: ... + def hash_table(self, table: "pa.Table | pa.RecordBatch") -> ContentHash: ... -class StringCacher(Protocol): +class StringCacherProtocol(Protocol): """Protocol for caching string key value pairs.""" def get_cached(self, cache_key: str) -> str | None: ... @@ -174,21 +241,21 @@ def set_cached(self, cache_key: str, value: str) -> None: ... def clear_cache(self) -> None: ... -class FunctionInfoExtractor(Protocol): +class FunctionInfoExtractorProtocol(Protocol): """Protocol for extracting function information.""" def extract_function_info( self, func: Callable[..., Any], function_name: str | None = None, - input_typespec: PythonSchema | None = None, - output_typespec: PythonSchema | None = None, + input_typespec: Schema | None = None, + output_typespec: Schema | None = None, exclude_function_signature: bool = False, exclude_function_body: bool = False, ) -> dict[str, Any]: ... -class SemanticTypeHasher(Protocol): +class SemanticTypeHasherProtocol(Protocol): """Abstract base class for semantic type-specific hashers.""" @property @@ -203,6 +270,6 @@ def hash_column( """Hash a column with this semantic type and return the hash bytes an an array""" ... - def set_cacher(self, cacher: StringCacher) -> None: + def set_cacher(self, cacher: StringCacherProtocol) -> None: """Add a string cacher for caching hash values.""" ... diff --git a/src/orcapod/protocols/legacy_data_protocols.py b/src/orcapod/protocols/legacy_data_protocols.py deleted file mode 100644 index 53a86576..00000000 --- a/src/orcapod/protocols/legacy_data_protocols.py +++ /dev/null @@ -1,2278 +0,0 @@ -# from collections.abc import Collection, Iterator, Mapping, Callable -# from datetime import datetime -# from typing import Any, ContextManager, Protocol, Self, TYPE_CHECKING, runtime_checkable -# from orcapod.protocols.hashing_protocols import ContentIdentifiable, ContentHash -# from orcapod.types import DataValue, TypeSpec - - -# if TYPE_CHECKING: -# import pyarrow as pa -# import polars as pl -# import pandas as pd - - -# @runtime_checkable -# class ExecutionEngine(Protocol): -# @property -# def name(self) -> str: ... - -# def submit_sync(self, function: Callable, *args, **kwargs) -> Any: -# """ -# Run the given function with the provided arguments. -# This method should be implemented by the execution engine. -# """ -# ... - -# async def submit_async(self, function: Callable, *args, **kwargs) -> Any: -# """ -# Asynchronously run the given function with the provided arguments. -# This method should be implemented by the execution engine. -# """ -# ... - -# # TODO: consider adding batch submission - - -# @runtime_checkable -# class Datagram(ContentIdentifiable, Protocol): -# """ -# Protocol for immutable datagram containers in Orcapod. - -# Datagrams are the fundamental units of data that flow through the system. -# They provide a unified interface for data access, conversion, and manipulation, -# ensuring consistent behavior across different storage backends (dict, Arrow table, etc.). - -# Each datagram contains: -# - **Data columns**: The primary business data (user_id, name, etc.) -# - **Meta columns**: Internal system metadata with {orcapod.META_PREFIX} (typically '__') prefixes (e.g. __processed_at, etc.) -# - **Context column**: Data context information ({orcapod.CONTEXT_KEY}) - -# Derivative of datagram (such as Packet or Tag) will also include some specific columns pertinent to the function of the specialized datagram: -# - **Source info columns**: Data provenance with {orcapod.SOURCE_PREFIX} ('_source_') prefixes (_source_user_id, etc.) used in Packet -# - **System tags**: Internal tags for system use, typically prefixed with {orcapod.SYSTEM_TAG_PREFIX} ('_system_') (_system_created_at, etc.) used in Tag - -# All operations are by design immutable - methods return new datagram instances rather than -# modifying existing ones. - -# Example: -# >>> datagram = DictDatagram({"user_id": 123, "name": "Alice"}) -# >>> updated = datagram.update(name="Alice Smith") -# >>> filtered = datagram.select("user_id", "name") -# >>> table = datagram.as_table() -# """ - -# # 1. Core Properties (Identity & Structure) -# @property -# def data_context_key(self) -> str: -# """ -# Return the data context key for this datagram. - -# This key identifies a collection of system components that collectively controls -# how information is serialized, hashed and represented, including the semantic type registry, -# arrow data hasher, and other contextual information. Same piece of information (that is two datagrams -# with an identical *logical* content) may bear distinct internal representation if they are -# represented under two distinct data context, as signified by distinct data context keys. - -# Returns: -# str: Context key for proper datagram interpretation -# """ -# ... - -# @property -# def meta_columns(self) -> tuple[str, ...]: -# """Return tuple of meta column names (with {orcapod.META_PREFIX} ('__') prefix).""" -# ... - -# # 2. Dict-like Interface (Data Access) -# def __getitem__(self, key: str) -> DataValue: -# """ -# Get data column value by key. - -# Provides dict-like access to data columns only. Meta columns -# are not accessible through this method (use `get_meta_value()` instead). - -# Args: -# key: Data column name. - -# Returns: -# The value stored in the specified data column. - -# Raises: -# KeyError: If the column doesn't exist in data columns. - -# Example: -# >>> datagram["user_id"] -# 123 -# >>> datagram["name"] -# 'Alice' -# """ -# ... - -# def __contains__(self, key: str) -> bool: -# """ -# Check if data column exists. - -# Args: -# key: Column name to check. - -# Returns: -# True if column exists in data columns, False otherwise. - -# Example: -# >>> "user_id" in datagram -# True -# >>> "nonexistent" in datagram -# False -# """ -# ... - -# def __iter__(self) -> Iterator[str]: -# """ -# Iterate over data column names. - -# Provides for-loop support over column names, enabling natural iteration -# patterns without requiring conversion to dict. - -# Yields: -# Data column names in no particular order. - -# Example: -# >>> for column in datagram: -# ... value = datagram[column] -# ... print(f"{column}: {value}") -# """ -# ... - -# def get(self, key: str, default: DataValue = None) -> DataValue: -# """ -# Get data column value with default fallback. - -# Args: -# key: Data column name. -# default: Value to return if column doesn't exist. - -# Returns: -# Column value if exists, otherwise the default value. - -# Example: -# >>> datagram.get("user_id") -# 123 -# >>> datagram.get("missing", "default") -# 'default' -# """ -# ... - -# # 3. Structural Information -# def keys( -# self, -# include_all_info: bool = False, -# include_meta_columns: bool | Collection[str] = False, -# include_context: bool = False, -# ) -> tuple[str, ...]: -# """ -# Return tuple of column names. - -# Provides access to column names with filtering options for different -# column types. Default returns only data column names. - -# Args: -# include_meta_columns: Controls meta column inclusion. -# - False: Return only data column names (default) -# - True: Include all meta column names -# - Collection[str]: Include meta columns matching these prefixes. If absent, -# {orcapod.META_PREFIX} ('__') prefix is prepended to each key. -# include_context: Whether to include context column. - -# Returns: -# Tuple of column names based on inclusion criteria. - -# Example: -# >>> datagram.keys() # Data columns only -# ('user_id', 'name', 'email') -# >>> datagram.keys(include_meta_columns=True) -# ('user_id', 'name', 'email', f'{orcapod.META_PREFIX}processed_at', f'{orcapod.META_PREFIX}pipeline_version') -# >>> datagram.keys(include_meta_columns=["pipeline"]) -# ('user_id', 'name', 'email',f'{orcapod.META_PREFIX}pipeline_version') -# >>> datagram.keys(include_context=True) -# ('user_id', 'name', 'email', f'{orcapod.CONTEXT_KEY}') -# """ -# ... - -# def types( -# self, -# include_all_info: bool = False, -# include_meta_columns: bool | Collection[str] = False, -# include_context: bool = False, -# ) -> TypeSpec: -# """ -# Return type specification mapping field names to Python types. - -# The TypeSpec enables type checking and validation throughout the system. - -# Args: -# include_meta_columns: Controls meta column type inclusion. -# - False: Exclude meta column types (default) -# - True: Include all meta column types -# - Collection[str]: Include meta column types matching these prefixes. If absent, -# {orcapod.META_PREFIX} ('__') prefix is prepended to each key. -# include_context: Whether to include context type. - -# Returns: -# TypeSpec mapping field names to their Python types. - -# Example: -# >>> datagram.types() -# {'user_id': , 'name': } -# """ -# ... - -# def arrow_schema( -# self, -# include_all_info: bool = False, -# include_meta_columns: bool | Collection[str] = False, -# include_context: bool = False, -# ) -> "pa.Schema": -# """ -# Return PyArrow schema representation. - -# The schema provides structured field and type information for efficient -# serialization and deserialization with PyArrow. - -# Args: -# include_meta_columns: Controls meta column schema inclusion. -# - False: Exclude meta columns (default) -# - True: Include all meta columns -# - Collection[str]: Include meta columns matching these prefixes. If absent, -# {orcapod.META_PREFIX} ('__') prefix is prepended to each key. -# include_context: Whether to include context column. - -# Returns: -# PyArrow Schema describing the datagram structure. - -# Example: -# >>> schema = datagram.arrow_schema() -# >>> schema.names -# ['user_id', 'name'] -# """ -# ... - -# # 4. Format Conversions (Export) -# def as_dict( -# self, -# include_all_info: bool = False, -# include_meta_columns: bool | Collection[str] = False, -# include_context: bool = False, -# ) -> dict[str, DataValue]: -# """ -# Convert datagram to dictionary format. - -# Provides a simple key-value representation useful for debugging, -# serialization, and interop with dict-based APIs. - -# Args: -# include_meta_columns: Controls meta column inclusion. -# - False: Exclude all meta columns (default) -# - True: Include all meta columns -# - Collection[str]: Include meta columns matching these prefixes. If absent, -# {orcapod.META_PREFIX} ('__') prefix is prepended to each key. -# include_context: Whether to include the context key. -# include_all_info: If True, include all available information. This option supersedes all other inclusion options. - - -# Returns: -# Dictionary with requested columns as key-value pairs. - -# Example: -# >>> data = datagram.as_dict() # {'user_id': 123, 'name': 'Alice'} -# >>> full_data = datagram.as_dict( -# ... include_meta_columns=True, -# ... include_context=True -# ... ) -# """ -# ... - -# def as_table( -# self, -# include_all_info: bool = False, -# include_meta_columns: bool | Collection[str] = False, -# include_context: bool = False, -# ) -> "pa.Table": -# """ -# Convert datagram to PyArrow Table format. - -# Provides a standardized columnar representation suitable for analysis, -# processing, and interoperability with Arrow-based tools. - -# Args: -# include_meta_columns: Controls meta column inclusion. -# - False: Exclude all meta columns (default) -# - True: Include all meta columns -# - Collection[str]: Include meta columns matching these prefixes. If absent, -# {orcapod.META_PREFIX} ('__') prefix is prepended to each key. -# include_context: Whether to include the context column. -# include_all_info: If True, include all available information. This option supersedes all other inclusion options. - -# Returns: -# PyArrow Table with requested columns. - -# Example: -# >>> table = datagram.as_table() # Data columns only -# >>> full_table = datagram.as_table( -# ... include_meta_columns=True, -# ... include_context=True -# ... ) -# >>> filtered = datagram.as_table(include_meta_columns=["pipeline"]) # same as passing f"{orcapod.META_PREFIX}pipeline" -# """ -# ... - -# def as_arrow_compatible_dict( -# self, -# include_all_info: bool = False, -# include_meta_columns: bool | Collection[str] = False, -# include_context: bool = False, -# ) -> dict[str, Any]: -# """ -# Return dictionary with values optimized for Arrow table conversion. - -# This method returns a dictionary where values are in a form that can be -# efficiently converted to Arrow format using pa.Table.from_pylist(). - -# The key insight is that this avoids the expensive as_table() → concat pattern -# by providing values that are "Arrow-ready" while remaining in dict format -# for efficient batching. - -# Implementation note: This may involve format conversions (e.g., Path objects -# to strings, datetime objects to ISO strings, etc.) to ensure compatibility -# with Arrow's expected input formats. - -# Arrow table that results from pa.Table.from_pylist on the output of this should be accompanied -# with arrow_schema(...) with the same argument options to ensure that the schema matches the table. - -# Args: -# include_all_info: Include all available information -# include_meta_columns: Controls meta column inclusion -# include_context: Whether to include context key - -# Returns: -# Dictionary with values optimized for Arrow conversion - -# Example: -# # Efficient batch conversion pattern -# arrow_dicts = [datagram.as_arrow_compatible_dict() for datagram in datagrams] -# schema = datagrams[0].arrow_schema() -# table = pa.Table.from_pylist(arrow_dicts, schema=schema) -# """ -# ... - -# # 5. Meta Column Operations -# def get_meta_value(self, key: str, default: DataValue = None) -> DataValue: -# """ -# Get meta column value with optional default. - -# Meta columns store operational metadata and use {orcapod.META_PREFIX} ('__') prefixes. -# This method handles both prefixed and unprefixed key formats. - -# Args: -# key: Meta column key (with or without {orcapod.META_PREFIX} ('__') prefix). -# default: Value to return if meta column doesn't exist. - -# Returns: -# Meta column value if exists, otherwise the default value. - -# Example: -# >>> datagram.get_meta_value("pipeline_version") # Auto-prefixed -# 'v2.1.0' -# >>> datagram.get_meta_value("__pipeline_version") # Already prefixed -# 'v2.1.0' -# >>> datagram.get_meta_value("missing", "default") -# 'default' -# """ -# ... - -# def with_meta_columns(self, **updates: DataValue) -> Self: -# """ -# Create new datagram with updated meta columns. - -# Adds or updates operational metadata while preserving all data columns. -# Keys are automatically prefixed with {orcapod.META_PREFIX} ('__') if needed. - -# Args: -# **updates: Meta column updates as keyword arguments. - -# Returns: -# New datagram instance with updated meta columns. - -# Example: -# >>> tracked = datagram.with_meta_columns( -# ... processed_by="pipeline_v2", -# ... timestamp="2024-01-15T10:30:00Z" -# ... ) -# """ -# ... - -# def drop_meta_columns(self, *keys: str, ignore_missing: bool = False) -> Self: -# """ -# Create new datagram with specified meta columns removed. - -# Args: -# *keys: Meta column keys to remove (prefixes optional). -# ignore_missing: If True, ignore missing columns without raising an error. - - -# Returns: -# New datagram instance without specified meta columns. - -# Raises: -# KeryError: If any specified meta column to drop doesn't exist and ignore_missing=False. - -# Example: -# >>> cleaned = datagram.drop_meta_columns("old_source", "temp_debug") -# """ -# ... - -# # 6. Data Column Operations -# def select(self, *column_names: str) -> Self: -# """ -# Create new datagram with only specified data columns. - -# Args: -# *column_names: Data column names to keep. - - -# Returns: -# New datagram instance with only specified data columns. All other columns including -# meta columns and context are preserved. - -# Raises: -# KeyError: If any specified column doesn't exist. - -# Example: -# >>> subset = datagram.select("user_id", "name", "email") -# """ -# ... - -# def drop(self, *column_names: str, ignore_missing: bool = False) -> Self: -# """ -# Create new datagram with specified data columns removed. Note that this does not -# remove meta columns or context column. Refer to `drop_meta_columns()` for dropping -# specific meta columns. Context key column can never be dropped but a modified copy -# can be created with a different context key using `with_data_context()`. - -# Args: -# *column_names: Data column names to remove. -# ignore_missing: If True, ignore missing columns without raising an error. - -# Returns: -# New datagram instance without specified data columns. - -# Raises: -# KeryError: If any specified column to drop doesn't exist and ignore_missing=False. - -# Example: -# >>> filtered = datagram.drop("temp_field", "debug_info") -# """ -# ... - -# def rename( -# self, -# column_mapping: Mapping[str, str], -# ) -> Self: -# """ -# Create new datagram with data columns renamed. - -# Args: -# column_mapping: Mapping from old names to new names. - -# Returns: -# New datagram instance with renamed data columns. - -# Example: -# >>> renamed = datagram.rename( -# ... {"old_id": "user_id", "old_name": "full_name"}, -# ... column_types={"user_id": int} -# ... ) -# """ -# ... - -# def update(self, **updates: DataValue) -> Self: -# """ -# Create new datagram with existing column values updated. - -# Updates values in existing data columns. Will error if any specified -# column doesn't exist - use with_columns() to add new columns. - -# Args: -# **updates: Column names and their new values. - -# Returns: -# New datagram instance with updated values. - -# Raises: -# KeyError: If any specified column doesn't exist. - -# Example: -# >>> updated = datagram.update( -# ... file_path="/new/absolute/path.txt", -# ... status="processed" -# ... ) -# """ -# ... - -# def with_columns( -# self, -# column_types: Mapping[str, type] | None = None, -# **updates: DataValue, -# ) -> Self: -# """ -# Create new datagram with additional data columns. - -# Adds new data columns to the datagram. Will error if any specified -# column already exists - use update() to modify existing columns. - -# Args: -# column_types: Optional type specifications for new columns. If not provided, the column type is -# inferred from the provided values. If value is None, the column type defaults to `str`. -# **kwargs: New columns as keyword arguments. - -# Returns: -# New datagram instance with additional data columns. - -# Raises: -# ValueError: If any specified column already exists. - -# Example: -# >>> expanded = datagram.with_columns( -# ... status="active", -# ... score=95.5, -# ... column_types={"score": float} -# ... ) -# """ -# ... - -# # 7. Context Operations -# def with_context_key(self, new_context_key: str) -> Self: -# """ -# Create new datagram with different context key. - -# Changes the semantic interpretation context while preserving all data. -# The context key affects how columns are processed and converted. - -# Args: -# new_context_key: New context key string. - -# Returns: -# New datagram instance with updated context key. - -# Note: -# How the context is interpreted depends on the datagram implementation. -# Semantic processing may be rebuilt for the new context. - -# Example: -# >>> financial_datagram = datagram.with_context_key("financial_v1") -# """ -# ... - -# # 8. Utility Operations -# def copy(self) -> Self: -# """ -# Create a shallow copy of the datagram. - -# Returns a new datagram instance with the same data and cached values. -# This is more efficient than reconstructing from scratch when you need -# an identical datagram instance. - -# Returns: -# New datagram instance with copied data and caches. - -# Example: -# >>> copied = datagram.copy() -# >>> copied is datagram # False - different instance -# False -# """ -# ... - -# # 9. String Representations -# def __str__(self) -> str: -# """ -# Return user-friendly string representation. - -# Shows the datagram as a simple dictionary for user-facing output, -# messages, and logging. Only includes data columns for clean output. - -# Returns: -# Dictionary-style string representation of data columns only. -# """ -# ... - -# def __repr__(self) -> str: -# """ -# Return detailed string representation for debugging. - -# Shows the datagram type and comprehensive information for debugging. - -# Returns: -# Detailed representation with type and metadata information. -# """ -# ... - - -# @runtime_checkable -# class Tag(Datagram, Protocol): -# """ -# Metadata associated with each data item in a stream. - -# Tags carry contextual information about data packets as they flow through -# the computational graph. They are immutable and provide metadata that -# helps with: -# - Data lineage tracking -# - Grouping and aggregation operations -# - Temporal information (timestamps) -# - Source identification -# - Processing context - -# Common examples include: -# - Timestamps indicating when data was created/processed -# - Source identifiers showing data origin -# - Processing metadata like batch IDs or session information -# - Grouping keys for aggregation operations -# - Quality indicators or confidence scores -# """ - -# def keys( -# self, -# include_all_info: bool = False, -# include_meta_columns: bool | Collection[str] = False, -# include_context: bool = False, -# include_system_tags: bool = False, -# ) -> tuple[str, ...]: -# """ -# Return tuple of column names. - -# Provides access to column names with filtering options for different -# column types. Default returns only data column names. - -# Args: -# include_all_info: If True, include all available information. This option supersedes all other inclusion options. -# include_meta_columns: Controls meta column inclusion. -# - False: Return only data column names (default) -# - True: Include all meta column names -# - Collection[str]: Include meta columns matching these prefixes. If absent, -# {orcapod.META_PREFIX} ('__') prefix is prepended to each key. -# include_context: Whether to include context column. -# include_source: Whether to include source info fields. - - -# Returns: -# Tuple of column names based on inclusion criteria. - -# Example: -# >>> datagram.keys() # Data columns only -# ('user_id', 'name', 'email') -# >>> datagram.keys(include_meta_columns=True) -# ('user_id', 'name', 'email', f'{orcapod.META_PREFIX}processed_at', f'{orcapod.META_PREFIX}pipeline_version') -# >>> datagram.keys(include_meta_columns=["pipeline"]) -# ('user_id', 'name', 'email',f'{orcapod.META_PREFIX}pipeline_version') -# >>> datagram.keys(include_context=True) -# ('user_id', 'name', 'email', f'{orcapod.CONTEXT_KEY}') -# """ -# ... - -# def types( -# self, -# include_all_info: bool = False, -# include_meta_columns: bool | Collection[str] = False, -# include_context: bool = False, -# include_system_tags: bool = False, -# ) -> TypeSpec: -# """ -# Return type specification mapping field names to Python types. - -# The TypeSpec enables type checking and validation throughout the system. - -# Args: -# include_all_info: If True, include all available information. This option supersedes all other inclusion options. -# include_meta_columns: Controls meta column type inclusion. -# - False: Exclude meta column types (default) -# - True: Include all meta column types -# - Collection[str]: Include meta column types matching these prefixes. If absent, -# {orcapod.META_PREFIX} ('__') prefix is prepended to each key. -# include_context: Whether to include context type. -# include_source: Whether to include source info fields. - -# Returns: -# TypeSpec mapping field names to their Python types. - -# Example: -# >>> datagram.types() -# {'user_id': , 'name': } -# """ -# ... - -# def arrow_schema( -# self, -# include_all_info: bool = False, -# include_meta_columns: bool | Collection[str] = False, -# include_context: bool = False, -# include_system_tags: bool = False, -# ) -> "pa.Schema": -# """ -# Return PyArrow schema representation. - -# The schema provides structured field and type information for efficient -# serialization and deserialization with PyArrow. - -# Args: -# include_all_info: If True, include all available information. This option supersedes all other inclusion options. -# include_meta_columns: Controls meta column schema inclusion. -# - False: Exclude meta columns (default) -# - True: Include all meta columns -# - Collection[str]: Include meta columns matching these prefixes. If absent, -# {orcapod.META_PREFIX} ('__') prefix is prepended to each key. -# include_context: Whether to include context column. -# include_source: Whether to include source info fields. - - -# Returns: -# PyArrow Schema describing the datagram structure. - -# Example: -# >>> schema = datagram.arrow_schema() -# >>> schema.names -# ['user_id', 'name'] -# """ -# ... - -# def as_dict( -# self, -# include_all_info: bool = False, -# include_meta_columns: bool | Collection[str] = False, -# include_context: bool = False, -# include_system_tags: bool = False, -# ) -> dict[str, DataValue]: -# """ -# Convert datagram to dictionary format. - -# Provides a simple key-value representation useful for debugging, -# serialization, and interop with dict-based APIs. - -# Args: -# include_all_info: If True, include all available information. This option supersedes all other inclusion options. -# include_meta_columns: Controls meta column inclusion. -# - False: Exclude all meta columns (default) -# - True: Include all meta columns -# - Collection[str]: Include meta columns matching these prefixes. If absent, -# {orcapod.META_PREFIX} ('__') prefix is prepended to each key. -# include_context: Whether to include the context key. -# include_source: Whether to include source info fields. - - -# Returns: -# Dictionary with requested columns as key-value pairs. - -# Example: -# >>> data = datagram.as_dict() # {'user_id': 123, 'name': 'Alice'} -# >>> full_data = datagram.as_dict( -# ... include_meta_columns=True, -# ... include_context=True -# ... ) -# """ -# ... - -# def as_table( -# self, -# include_all_info: bool = False, -# include_meta_columns: bool | Collection[str] = False, -# include_context: bool = False, -# include_system_tags: bool = False, -# ) -> "pa.Table": -# """ -# Convert datagram to PyArrow Table format. - -# Provides a standardized columnar representation suitable for analysis, -# processing, and interoperability with Arrow-based tools. - -# Args: -# include_all_info: If True, include all available information. This option supersedes all other inclusion options. -# include_meta_columns: Controls meta column inclusion. -# - False: Exclude all meta columns (default) -# - True: Include all meta columns -# - Collection[str]: Include meta columns matching these prefixes. If absent, -# {orcapod.META_PREFIX} ('__') prefix is prepended to each key. -# include_context: Whether to include the context column. -# include_source: Whether to include source info columns in the schema. - -# Returns: -# PyArrow Table with requested columns. - -# Example: -# >>> table = datagram.as_table() # Data columns only -# >>> full_table = datagram.as_table( -# ... include_meta_columns=True, -# ... include_context=True -# ... ) -# >>> filtered = datagram.as_table(include_meta_columns=["pipeline"]) # same as passing f"{orcapod.META_PREFIX}pipeline" -# """ -# ... - -# # TODO: add this back -# # def as_arrow_compatible_dict( -# # self, -# # include_all_info: bool = False, -# # include_meta_columns: bool | Collection[str] = False, -# # include_context: bool = False, -# # include_source: bool = False, -# # ) -> dict[str, Any]: -# # """Extended version with source info support.""" -# # ... - -# def as_datagram( -# self, -# include_all_info: bool = False, -# include_meta_columns: bool | Collection[str] = False, -# include_system_tags: bool = False, -# ) -> Datagram: -# """ -# Convert the packet to a Datagram. - -# Args: -# include_meta_columns: Controls meta column inclusion. -# - False: Exclude all meta columns (default) -# - True: Include all meta columns -# - Collection[str]: Include meta columns matching these prefixes. If absent, -# {orcapod.META_PREFIX} ('__') prefix is prepended to each key. - -# Returns: -# Datagram: Datagram representation of packet data -# """ -# ... - -# def system_tags(self) -> dict[str, DataValue]: -# """ -# Return metadata about the packet's source/origin. - -# Provides debugging and lineage information about where the packet -# originated. May include information like: -# - File paths for file-based sources -# - Database connection strings -# - API endpoints -# - Processing pipeline information - -# Returns: -# dict[str, str | None]: Source information for each data column as key-value pairs. -# """ -# ... - - -# @runtime_checkable -# class Packet(Datagram, Protocol): -# """ -# The actual data payload in a stream. - -# Packets represent the core data being processed through the computational -# graph. Unlike Tags (which are metadata), Packets contain the actual -# information that computations operate on. - -# Packets extend Datagram with additional capabilities for: -# - Source tracking and lineage -# - Content-based hashing for caching -# - Metadata inclusion for debugging - -# The distinction between Tag and Packet is crucial for understanding -# data flow: Tags provide context, Packets provide content. -# """ - -# def keys( -# self, -# include_all_info: bool = False, -# include_meta_columns: bool | Collection[str] = False, -# include_context: bool = False, -# include_source: bool = False, -# ) -> tuple[str, ...]: -# """ -# Return tuple of column names. - -# Provides access to column names with filtering options for different -# column types. Default returns only data column names. - -# Args: -# include_all_info: If True, include all available information. This option supersedes all other inclusion options. -# include_meta_columns: Controls meta column inclusion. -# - False: Return only data column names (default) -# - True: Include all meta column names -# - Collection[str]: Include meta columns matching these prefixes. If absent, -# {orcapod.META_PREFIX} ('__') prefix is prepended to each key. -# include_context: Whether to include context column. -# include_source: Whether to include source info fields. - - -# Returns: -# Tuple of column names based on inclusion criteria. - -# Example: -# >>> datagram.keys() # Data columns only -# ('user_id', 'name', 'email') -# >>> datagram.keys(include_meta_columns=True) -# ('user_id', 'name', 'email', f'{orcapod.META_PREFIX}processed_at', f'{orcapod.META_PREFIX}pipeline_version') -# >>> datagram.keys(include_meta_columns=["pipeline"]) -# ('user_id', 'name', 'email',f'{orcapod.META_PREFIX}pipeline_version') -# >>> datagram.keys(include_context=True) -# ('user_id', 'name', 'email', f'{orcapod.CONTEXT_KEY}') -# """ -# ... - -# def types( -# self, -# include_all_info: bool = False, -# include_meta_columns: bool | Collection[str] = False, -# include_context: bool = False, -# include_source: bool = False, -# ) -> TypeSpec: -# """ -# Return type specification mapping field names to Python types. - -# The TypeSpec enables type checking and validation throughout the system. - -# Args: -# include_all_info: If True, include all available information. This option supersedes all other inclusion options. -# include_meta_columns: Controls meta column type inclusion. -# - False: Exclude meta column types (default) -# - True: Include all meta column types -# - Collection[str]: Include meta column types matching these prefixes. If absent, -# {orcapod.META_PREFIX} ('__') prefix is prepended to each key. -# include_context: Whether to include context type. -# include_source: Whether to include source info fields. - -# Returns: -# TypeSpec mapping field names to their Python types. - -# Example: -# >>> datagram.types() -# {'user_id': , 'name': } -# """ -# ... - -# def arrow_schema( -# self, -# include_all_info: bool = False, -# include_meta_columns: bool | Collection[str] = False, -# include_context: bool = False, -# include_source: bool = False, -# ) -> "pa.Schema": -# """ -# Return PyArrow schema representation. - -# The schema provides structured field and type information for efficient -# serialization and deserialization with PyArrow. - -# Args: -# include_all_info: If True, include all available information. This option supersedes all other inclusion options. -# include_meta_columns: Controls meta column schema inclusion. -# - False: Exclude meta columns (default) -# - True: Include all meta columns -# - Collection[str]: Include meta columns matching these prefixes. If absent, -# {orcapod.META_PREFIX} ('__') prefix is prepended to each key. -# include_context: Whether to include context column. -# include_source: Whether to include source info fields. - - -# Returns: -# PyArrow Schema describing the datagram structure. - -# Example: -# >>> schema = datagram.arrow_schema() -# >>> schema.names -# ['user_id', 'name'] -# """ -# ... - -# def as_dict( -# self, -# include_all_info: bool = False, -# include_meta_columns: bool | Collection[str] = False, -# include_context: bool = False, -# include_source: bool = False, -# ) -> dict[str, DataValue]: -# """ -# Convert datagram to dictionary format. - -# Provides a simple key-value representation useful for debugging, -# serialization, and interop with dict-based APIs. - -# Args: -# include_all_info: If True, include all available information. This option supersedes all other inclusion options. -# include_meta_columns: Controls meta column inclusion. -# - False: Exclude all meta columns (default) -# - True: Include all meta columns -# - Collection[str]: Include meta columns matching these prefixes. If absent, -# {orcapod.META_PREFIX} ('__') prefix is prepended to each key. -# include_context: Whether to include the context key. -# include_source: Whether to include source info fields. - - -# Returns: -# Dictionary with requested columns as key-value pairs. - -# Example: -# >>> data = datagram.as_dict() # {'user_id': 123, 'name': 'Alice'} -# >>> full_data = datagram.as_dict( -# ... include_meta_columns=True, -# ... include_context=True -# ... ) -# """ -# ... - -# def as_table( -# self, -# include_all_info: bool = False, -# include_meta_columns: bool | Collection[str] = False, -# include_context: bool = False, -# include_source: bool = False, -# ) -> "pa.Table": -# """ -# Convert datagram to PyArrow Table format. - -# Provides a standardized columnar representation suitable for analysis, -# processing, and interoperability with Arrow-based tools. - -# Args: -# include_all_info: If True, include all available information. This option supersedes all other inclusion options. -# include_meta_columns: Controls meta column inclusion. -# - False: Exclude all meta columns (default) -# - True: Include all meta columns -# - Collection[str]: Include meta columns matching these prefixes. If absent, -# {orcapod.META_PREFIX} ('__') prefix is prepended to each key. -# include_context: Whether to include the context column. -# include_source: Whether to include source info columns in the schema. - -# Returns: -# PyArrow Table with requested columns. - -# Example: -# >>> table = datagram.as_table() # Data columns only -# >>> full_table = datagram.as_table( -# ... include_meta_columns=True, -# ... include_context=True -# ... ) -# >>> filtered = datagram.as_table(include_meta_columns=["pipeline"]) # same as passing f"{orcapod.META_PREFIX}pipeline" -# """ -# ... - -# # TODO: add this back -# # def as_arrow_compatible_dict( -# # self, -# # include_all_info: bool = False, -# # include_meta_columns: bool | Collection[str] = False, -# # include_context: bool = False, -# # include_source: bool = False, -# # ) -> dict[str, Any]: -# # """Extended version with source info support.""" -# # ... - -# def as_datagram( -# self, -# include_all_info: bool = False, -# include_meta_columns: bool | Collection[str] = False, -# include_source: bool = False, -# ) -> Datagram: -# """ -# Convert the packet to a Datagram. - -# Args: -# include_meta_columns: Controls meta column inclusion. -# - False: Exclude all meta columns (default) -# - True: Include all meta columns -# - Collection[str]: Include meta columns matching these prefixes. If absent, -# {orcapod.META_PREFIX} ('__') prefix is prepended to each key. - -# Returns: -# Datagram: Datagram representation of packet data -# """ -# ... - -# def source_info(self) -> dict[str, str | None]: -# """ -# Return metadata about the packet's source/origin. - -# Provides debugging and lineage information about where the packet -# originated. May include information like: -# - File paths for file-based sources -# - Database connection strings -# - API endpoints -# - Processing pipeline information - -# Returns: -# dict[str, str | None]: Source information for each data column as key-value pairs. -# """ -# ... - -# def with_source_info( -# self, -# **source_info: str | None, -# ) -> Self: -# """ -# Create new packet with updated source information. - -# Adds or updates source metadata for the packet. This is useful for -# tracking data provenance and lineage through the computational graph. - -# Args: -# **source_info: Source metadata as keyword arguments. - -# Returns: -# New packet instance with updated source information. - -# Example: -# >>> updated_packet = packet.with_source_info( -# ... file_path="/new/path/to/file.txt", -# ... source_id="source_123" -# ... ) -# """ -# ... - - -# @runtime_checkable -# class PodFunction(Protocol): -# """ -# A function suitable for use in a FunctionPod. - -# PodFunctions define the computational logic that operates on individual -# packets within a Pod. They represent pure functions that transform -# data values without side effects. - -# These functions are designed to be: -# - Stateless: No dependency on external state -# - Deterministic: Same inputs always produce same outputs -# - Serializable: Can be cached and distributed -# - Type-safe: Clear input/output contracts - -# PodFunctions accept named arguments corresponding to packet fields -# and return transformed data values. -# """ - -# def __call__(self, **kwargs: DataValue) -> None | DataValue: -# """ -# Execute the pod function with the given arguments. - -# The function receives packet data as named arguments and returns -# either transformed data or None (for filtering operations). - -# Args: -# **kwargs: Named arguments mapping packet fields to data values - -# Returns: -# None: Filter out this packet (don't include in output) -# DataValue: Single transformed value - -# Raises: -# TypeError: If required arguments are missing -# ValueError: If argument values are invalid -# """ -# ... - - -# @runtime_checkable -# class Labelable(Protocol): -# """ -# Protocol for objects that can have a human-readable label. - -# Labels provide meaningful names for objects in the computational graph, -# making debugging, visualization, and monitoring much easier. They serve -# as human-friendly identifiers that complement the technical identifiers -# used internally. - -# Labels are optional but highly recommended for: -# - Debugging complex computational graphs -# - Visualization and monitoring tools -# - Error messages and logging -# - User interfaces and dashboards -# """ - -# @property -# def label(self) -> str | None: -# """ -# Return the human-readable label for this object. - -# Labels should be descriptive and help users understand the purpose -# or role of the object in the computational graph. - -# Returns: -# str: Human-readable label for this object -# None: No label is set (will use default naming) -# """ -# ... - - -# @runtime_checkable -# class Stream(ContentIdentifiable, Labelable, Protocol): -# """ -# Base protocol for all streams in Orcapod. - -# Streams represent sequences of (Tag, Packet) pairs flowing through the -# computational graph. They are the fundamental data structure connecting -# kernels and carrying both data and metadata. - -# Streams can be either: -# - Static: Immutable snapshots created at a specific point in time -# - Live: Dynamic streams that stay current with upstream dependencies - -# All streams provide: -# - Iteration over (tag, packet) pairs -# - Type information and schema access -# - Lineage information (source kernel and upstream streams) -# - Basic caching and freshness tracking -# - Conversion to common formats (tables, dictionaries) -# """ - -# @property -# def substream_identities(self) -> tuple[str, ...]: -# """ -# Unique identifiers for sub-streams within this stream. - -# This property provides a way to identify and differentiate -# sub-streams that may be part of a larger stream. It is useful -# for tracking and managing complex data flows. - -# Returns: -# tuple[str, ...]: Unique identifiers for each sub-stream -# """ -# ... - -# @property -# def execution_engine(self) -> ExecutionEngine | None: -# """ -# The execution engine attached to this stream. By default, the stream -# will use this execution engine whenever it needs to perform computation. -# None means the stream is not attached to any execution engine and will default -# to running natively. -# """ - -# @execution_engine.setter -# def execution_engine(self, engine: ExecutionEngine | None) -> None: -# """ -# Set the execution engine for this stream. - -# This allows the stream to use a specific execution engine for -# computation, enabling optimized execution strategies and resource -# management. - -# Args: -# engine: The execution engine to attach to this stream -# """ -# ... - -# def get_substream(self, substream_id: str) -> "Stream": -# """ -# Retrieve a specific sub-stream by its identifier. - -# This method allows access to individual sub-streams within the -# main stream, enabling focused operations on specific data segments. - -# Args: -# substream_id: Unique identifier for the desired sub-stream. - -# Returns: -# Stream: The requested sub-stream if it exists -# """ -# ... - -# @property -# def source(self) -> "Kernel | None": -# """ -# The kernel that produced this stream. - -# This provides lineage information for tracking data flow through -# the computational graph. Root streams (like file sources) may -# have no source kernel. - -# Returns: -# Kernel: The source kernel that created this stream -# None: This is a root stream with no source kernel -# """ -# ... - -# @property -# def upstreams(self) -> tuple["Stream", ...]: -# """ -# Input streams used to produce this stream. - -# These are the streams that were provided as input to the source -# kernel when this stream was created. Used for dependency tracking -# and cache invalidation. - -# Returns: -# tuple[Stream, ...]: Upstream dependency streams (empty for sources) -# """ -# ... - -# def keys(self) -> tuple[tuple[str, ...], tuple[str, ...]]: -# """ -# Available keys/fields in the stream content. - -# Returns the field names present in both tags and packets. -# This provides schema information without requiring type details, -# useful for: -# - Schema inspection and exploration -# - Query planning and optimization -# - Field validation and mapping - -# Returns: -# tuple[tuple[str, ...], tuple[str, ...]]: (tag_keys, packet_keys) -# """ -# ... - -# def types(self, include_system_tags: bool = False) -> tuple[TypeSpec, TypeSpec]: -# """ -# Type specifications for the stream content. - -# Returns the type schema for both tags and packets in this stream. -# This information is used for: -# - Type checking and validation -# - Schema inference and planning -# - Compatibility checking between kernels - -# Returns: -# tuple[TypeSpec, TypeSpec]: (tag_types, packet_types) -# """ -# ... - -# @property -# def last_modified(self) -> datetime | None: -# """ -# When the stream's content was last modified. - -# This property is crucial for caching decisions and dependency tracking: -# - datetime: Content was last modified at this time (cacheable) -# - None: Content is never stable, always recompute (some dynamic streams) - -# Both static and live streams typically return datetime values, but -# live streams update this timestamp whenever their content changes. - -# Returns: -# datetime: Timestamp of last modification for most streams -# None: Stream content is never stable (some special dynamic streams) -# """ -# ... - -# @property -# def is_current(self) -> bool: -# """ -# Whether the stream is up-to-date with its dependencies. - -# A stream is current if its content reflects the latest state of its -# source kernel and upstream streams. This is used for cache validation -# and determining when refresh is needed. - -# For live streams, this should always return True since they stay -# current automatically. For static streams, this indicates whether -# the cached content is still valid. - -# Returns: -# bool: True if stream is up-to-date, False if refresh needed -# """ -# ... - -# def __iter__(self) -> Iterator[tuple[Tag, Packet]]: -# """ -# Iterate over (tag, packet) pairs in the stream. - -# This is the primary way to access stream data. The behavior depends -# on the stream type: -# - Static streams: Return cached/precomputed data -# - Live streams: May trigger computation and always reflect current state - -# Yields: -# tuple[Tag, Packet]: Sequential (tag, packet) pairs -# """ -# ... - -# def iter_packets( -# self, execution_engine: ExecutionEngine | None = None -# ) -> Iterator[tuple[Tag, Packet]]: -# """ -# Alias for __iter__ for explicit packet iteration. - -# Provides a more explicit method name when the intent is to iterate -# over packets specifically, improving code readability. - -# This method must return an immutable iterator -- that is, the returned iterator -# should not change and must consistently return identical tag,packet pairs across -# multiple iterations of the iterator. - -# Note that this is NOT to mean that multiple invocation of `iter_packets` must always -# return an identical iterator. The iterator returned by `iter_packets` may change -# between invocations, but the iterator itself must not change. Consequently, it should be understood -# that the returned iterators may be a burden on memory if the stream is large or infinite. - -# Yields: -# tuple[Tag, Packet]: Sequential (tag, packet) pairs -# """ -# ... - -# def run(self, execution_engine: ExecutionEngine | None = None) -> None: -# """ -# Execute the stream using the provided execution engine. - -# This method triggers computation of the stream content based on its -# source kernel and upstream streams. It returns a new stream instance -# containing the computed (tag, packet) pairs. - -# Args: -# execution_engine: The execution engine to use for computation - -# """ -# ... - -# async def run_async(self, execution_engine: ExecutionEngine | None = None) -> None: -# """ -# Asynchronously execute the stream using the provided execution engine. - -# This method triggers computation of the stream content based on its -# source kernel and upstream streams. It returns a new stream instance -# containing the computed (tag, packet) pairs. - -# Args: -# execution_engine: The execution engine to use for computation - -# """ -# ... - -# def as_df( -# self, -# include_data_context: bool = False, -# include_source: bool = False, -# include_system_tags: bool = False, -# include_content_hash: bool | str = False, -# execution_engine: ExecutionEngine | None = None, -# ) -> "pl.DataFrame | None": -# """ -# Convert the entire stream to a Polars DataFrame. -# """ -# ... - -# def as_table( -# self, -# include_data_context: bool = False, -# include_source: bool = False, -# include_system_tags: bool = False, -# include_content_hash: bool | str = False, -# execution_engine: ExecutionEngine | None = None, -# ) -> "pa.Table": -# """ -# Convert the entire stream to a PyArrow Table. - -# Materializes all (tag, packet) pairs into a single table for -# analysis and processing. This operation may be expensive for -# large streams or live streams that need computation. - -# If include_content_hash is True, an additional column called "_content_hash" -# containing the content hash of each packet is included. If include_content_hash -# is a string, it is used as the name of the content hash column. - -# Returns: -# pa.Table: Complete stream data as a PyArrow Table -# """ -# ... - -# def flow( -# self, execution_engine: ExecutionEngine | None = None -# ) -> Collection[tuple[Tag, Packet]]: -# """ -# Return the entire stream as a collection of (tag, packet) pairs. - -# This method materializes the stream content into a list or similar -# collection type. It is useful for small streams or when you need -# to process all data at once. - -# Args: -# execution_engine: Optional execution engine to use for computation. -# If None, the stream will use its default execution engine. -# """ -# ... - -# def join(self, other_stream: "Stream") -> "Stream": -# """ -# Join this stream with another stream. - -# Combines two streams into a single stream by merging their content. -# The resulting stream contains all (tag, packet) pairs from both -# streams, preserving their order. - -# Args: -# other_stream: The other stream to join with this one. - -# Returns: -# Self: New stream containing combined content from both streams. -# """ -# ... - -# def semi_join(self, other_stream: "Stream") -> "Stream": -# """ -# Perform a semi-join with another stream. - -# This operation filters this stream to only include packets that have -# corresponding tags in the other stream. The resulting stream contains -# all (tag, packet) pairs from this stream that match tags in the other. - -# Args: -# other_stream: The other stream to semi-join with this one. - -# Returns: -# Self: New stream containing filtered content based on the semi-join. -# """ -# ... - -# def map_tags( -# self, name_map: Mapping[str, str], drop_unmapped: bool = True -# ) -> "Stream": -# """ -# Map tag names in this stream to new names based on the provided mapping. -# """ -# ... - -# def map_packets( -# self, name_map: Mapping[str, str], drop_unmapped: bool = True -# ) -> "Stream": -# """ -# Map packet names in this stream to new names based on the provided mapping. -# """ -# ... - - -# @runtime_checkable -# class LiveStream(Stream, Protocol): -# """ -# A stream that automatically stays up-to-date with its upstream dependencies. - -# LiveStream extends the base Stream protocol with capabilities for "up-to-date" -# data flow and reactive computation. Unlike static streams which represent -# snapshots, LiveStreams provide the guarantee that their content always -# reflects the current state of their dependencies. - -# Key characteristics: -# - Automatically refresh the stream if changes in the upstreams are detected -# - Track last_modified timestamp when content changes -# - Support manual refresh triggering and invalidation -# - By design, LiveStream would return True for is_current except when auto-update fails. - -# LiveStreams are always returned by Kernel.__call__() methods, ensuring -# that normal kernel usage produces live, up-to-date results. - -# Caching behavior: -# - last_modified updates whenever content changes -# - Can be cached based on dependency timestamps -# - Invalidation happens automatically when upstreams change - -# Use cases: -# - Real-time data processing pipelines -# - Reactive user interfaces -# - Monitoring and alerting systems -# - Dynamic dashboard updates -# - Any scenario requiring current data -# """ - -# def refresh(self, force: bool = False) -> bool: -# """ -# Manually trigger a refresh of this stream's content. - -# Forces the stream to check its upstream dependencies and update -# its content if necessary. This is useful when: -# - You want to ensure the latest data before a critical operation -# - You need to force computation at a specific time -# - You're debugging data flow issues -# - You want to pre-compute results for performance -# Args: -# force: If True, always refresh even if the stream is current. -# If False, only refresh if the stream is not current. - -# Returns: -# bool: True if the stream was refreshed, False if it was already current. -# Note: LiveStream refreshes automatically on access, so this -# method may be a no-op for some implementations. However, it's -# always safe to call if you need to control when the cache is refreshed. -# """ -# ... - -# def invalidate(self) -> None: -# """ -# Mark this stream as invalid, forcing a refresh on next access. - -# This method is typically called when: -# - Upstream dependencies have changed -# - The source kernel has been modified -# - External data sources have been updated -# - Manual cache invalidation is needed - -# The stream will automatically refresh its content the next time -# it's accessed (via iteration, as_table(), etc.). - -# This is more efficient than immediate refresh when you know the -# data will be accessed later. -# """ -# ... - - -# @runtime_checkable -# class Kernel(ContentIdentifiable, Labelable, Protocol): -# """ -# The fundamental unit of computation in Orcapod. - -# Kernels are the building blocks of computational graphs, transforming -# zero, one, or more input streams into a single output stream. They -# encapsulate computation logic while providing consistent interfaces -# for validation, type checking, and execution. - -# Key design principles: -# - Immutable: Kernels don't change after creation -# - Deterministic: Same inputs always produce same outputs -# - Composable: Kernels can be chained and combined -# - Trackable: All invocations are recorded for lineage -# - Type-safe: Strong typing and validation throughout - -# Execution modes: -# - __call__(): Full-featured execution with tracking, returns LiveStream -# - forward(): Pure computation without side effects, returns Stream - -# The distinction between these modes enables both production use (with -# full tracking) and testing/debugging (without side effects). -# """ - -# @property -# def kernel_id(self) -> tuple[str, ...]: -# """ -# Return a unique identifier for this Pod. - -# The pod_id is used for caching and tracking purposes. It should -# uniquely identify the Pod's computational logic, parameters, and -# any relevant metadata that affects its behavior. - -# Returns: -# tuple[str, ...]: Unique identifier for this Pod -# """ -# ... - -# @property -# def data_context_key(self) -> str: -# """ -# Return the context key for this kernel's data processing. - -# The context key is used to interpret how data columns should be -# processed and converted. It provides semantic meaning to the data -# being processed by this kernel. - -# Returns: -# str: Context key for this kernel's data processing -# """ -# ... - -# @property -# def last_modified(self) -> datetime | None: -# """ -# When the kernel was last modified. For most kernels, this is the timestamp -# of the kernel creation. -# """ -# ... - -# def __call__( -# self, *streams: Stream, label: str | None = None, **kwargs -# ) -> LiveStream: -# """ -# Main interface for kernel invocation with full tracking and guarantees. - -# This is the primary way to invoke kernels in production. It provides -# a complete execution pipeline: -# 1. Validates input streams against kernel requirements -# 2. Registers the invocation with the computational graph -# 3. Calls forward() to perform the actual computation -# 4. Ensures the result is a LiveStream that stays current - -# The returned LiveStream automatically stays up-to-date with its -# upstream dependencies, making it suitable for real-time processing -# and reactive applications. - -# Args: -# *streams: Input streams to process (can be empty for source kernels) -# label: Optional label for this invocation (overrides kernel.label) -# **kwargs: Additional arguments for kernel configuration - -# Returns: -# LiveStream: Live stream that stays up-to-date with upstreams - -# Raises: -# ValidationError: If input streams are invalid for this kernel -# TypeMismatchError: If stream types are incompatible -# ValueError: If required arguments are missing -# """ -# ... - -# def forward(self, *streams: Stream) -> Stream: -# """ -# Perform the actual computation without side effects. - -# This method contains the core computation logic and should be -# overridden by subclasses. It performs pure computation without: -# - Registering with the computational graph -# - Performing validation (caller's responsibility) -# - Guaranteeing result type (may return static or live streams) - -# The returned stream must be accurate at the time of invocation but -# need not stay up-to-date with upstream changes. This makes forward() -# suitable for: -# - Testing and debugging -# - Batch processing where currency isn't required -# - Internal implementation details - -# Args: -# *streams: Input streams to process - -# Returns: -# Stream: Result of the computation (may be static or live) -# """ -# ... - -# def output_types( -# self, *streams: Stream, include_system_tags: bool = False -# ) -> tuple[TypeSpec, TypeSpec]: -# """ -# Determine output types without triggering computation. - -# This method performs type inference based on input stream types, -# enabling efficient type checking and stream property queries. -# It should be fast and not trigger any expensive computation. - -# Used for: -# - Pre-execution type validation -# - Query planning and optimization -# - Schema inference in complex pipelines -# - IDE support and developer tooling - -# Args: -# *streams: Input streams to analyze - -# Returns: -# tuple[TypeSpec, TypeSpec]: (tag_types, packet_types) for output - -# Raises: -# ValidationError: If input types are incompatible -# TypeError: If stream types cannot be processed -# """ -# ... - -# def validate_inputs(self, *streams: Stream) -> None: -# """ -# Validate input streams, raising exceptions if incompatible. - -# This method is called automatically by __call__ before computation -# to provide fail-fast behavior. It should check: -# - Number of input streams -# - Stream types and schemas -# - Any kernel-specific requirements -# - Business logic constraints - -# The goal is to catch errors early, before expensive computation -# begins, and provide clear error messages for debugging. - -# Args: -# *streams: Input streams to validate - -# Raises: -# ValidationError: If streams are invalid for this kernel -# TypeError: If stream types are incompatible -# ValueError: If stream content violates business rules -# """ -# ... - -# def identity_structure(self, streams: Collection[Stream] | None = None) -> Any: -# """ -# Generate a unique identity structure for this kernel and/or kernel invocation. -# When invoked without streams, it should return a structure -# that uniquely identifies the kernel itself (e.g., class name, parameters). -# When invoked with streams, it should include the identity of the streams -# to distinguish different invocations of the same kernel. - -# This structure is used for: -# - Caching and memoization -# - Debugging and error reporting -# - Tracking kernel invocations in computational graphs - -# Args: -# streams: Optional input streams for this invocation. If None, identity_structure is -# based solely on the kernel. If streams are provided, they are included in the identity -# to differentiate between different invocations of the same kernel. - -# Returns: -# Any: Unique identity structure (e.g., tuple of class name and stream identities) -# """ -# ... - - -# @runtime_checkable -# class Pod(Kernel, Protocol): -# """ -# Specialized kernel for packet-level processing with advanced caching. - -# Pods represent a different computational model from regular kernels: -# - Process data one packet at a time (enabling fine-grained parallelism) -# - Support just-in-time evaluation (computation deferred until needed) -# - Provide stricter type contracts (clear input/output schemas) -# - Enable advanced caching strategies (packet-level caching) - -# The Pod abstraction is ideal for: -# - Expensive computations that benefit from caching -# - Operations that can be parallelized at the packet level -# - Transformations with strict type contracts -# - Processing that needs to be deferred until access time -# - Functions that operate on individual data items - -# Pods use a different execution model where computation is deferred -# until results are actually needed, enabling efficient resource usage -# and fine-grained caching. -# """ - -# @property -# def version(self) -> str: ... - -# def get_record_id(self, packet: Packet, execution_engine_hash: str) -> str: ... - -# @property -# def tiered_pod_id(self) -> dict[str, str]: -# """ -# Return a dictionary representation of the tiered pod's unique identifier. -# The key is supposed to be ordered from least to most specific, allowing -# for hierarchical identification of the pod. - -# This is primarily used for tiered memoization/caching strategies. - -# Returns: -# dict[str, str]: Dictionary representation of the pod's ID -# """ -# ... - -# def input_packet_types(self) -> TypeSpec: -# """ -# TypeSpec for input packets that this Pod can process. - -# Defines the exact schema that input packets must conform to. -# Pods are typically much stricter about input types than regular -# kernels, requiring precise type matching for their packet-level -# processing functions. - -# This specification is used for: -# - Runtime type validation -# - Compile-time type checking -# - Schema inference and documentation -# - Input validation and error reporting - -# Returns: -# TypeSpec: Dictionary mapping field names to required packet types -# """ -# ... - -# def output_packet_types(self) -> TypeSpec: -# """ -# TypeSpec for output packets that this Pod produces. - -# Defines the schema of packets that will be produced by this Pod. -# This is typically determined by the Pod's computational function -# and is used for: -# - Type checking downstream kernels -# - Schema inference in complex pipelines -# - Query planning and optimization -# - Documentation and developer tooling - -# Returns: -# TypeSpec: Dictionary mapping field names to output packet types -# """ -# ... - -# async def async_call( -# self, -# tag: Tag, -# packet: Packet, -# record_id: str | None = None, -# execution_engine: ExecutionEngine | None = None, -# ) -> tuple[Tag, Packet | None]: ... - -# def call( -# self, -# tag: Tag, -# packet: Packet, -# record_id: str | None = None, -# execution_engine: ExecutionEngine | None = None, -# ) -> tuple[Tag, Packet | None]: -# """ -# Process a single packet with its associated tag. - -# This is the core method that defines the Pod's computational behavior. -# It processes one (tag, packet) pair at a time, enabling: -# - Fine-grained caching at the packet level -# - Parallelization opportunities -# - Just-in-time evaluation -# - Filtering operations (by returning None) - -# The method signature supports: -# - Tag transformation (modify metadata) -# - Packet transformation (modify content) -# - Filtering (return None to exclude packet) -# - Pass-through (return inputs unchanged) - -# Args: -# tag: Metadata associated with the packet -# packet: The data payload to process - -# Returns: -# tuple[Tag, Packet | None]: -# - Tag: Output tag (may be modified from input) -# - Packet: Processed packet, or None to filter it out - -# Raises: -# TypeError: If packet doesn't match input_packet_types -# ValueError: If packet data is invalid for processing -# """ -# ... - - -# @runtime_checkable -# class CachedPod(Pod, Protocol): -# async def async_call( -# self, -# tag: Tag, -# packet: Packet, -# record_id: str | None = None, -# execution_engine: ExecutionEngine | None = None, -# skip_cache_lookup: bool = False, -# skip_cache_insert: bool = False, -# ) -> tuple[Tag, Packet | None]: ... - -# def call( -# self, -# tag: Tag, -# packet: Packet, -# record_id: str | None = None, -# execution_engine: ExecutionEngine | None = None, -# skip_cache_lookup: bool = False, -# skip_cache_insert: bool = False, -# ) -> tuple[Tag, Packet | None]: -# """ -# Process a single packet with its associated tag. - -# This is the core method that defines the Pod's computational behavior. -# It processes one (tag, packet) pair at a time, enabling: -# - Fine-grained caching at the packet level -# - Parallelization opportunities -# - Just-in-time evaluation -# - Filtering operations (by returning None) - -# The method signature supports: -# - Tag transformation (modify metadata) -# - Packet transformation (modify content) -# - Filtering (return None to exclude packet) -# - Pass-through (return inputs unchanged) - -# Args: -# tag: Metadata associated with the packet -# packet: The data payload to process - -# Returns: -# tuple[Tag, Packet | None]: -# - Tag: Output tag (may be modified from input) -# - Packet: Processed packet, or None to filter it out - -# Raises: -# TypeError: If packet doesn't match input_packet_types -# ValueError: If packet data is invalid for processing -# """ -# ... - -# def get_all_records( -# self, include_system_columns: bool = False -# ) -> "pa.Table | None": -# """ -# Retrieve all records processed by this Pod. - -# This method returns a table containing all packets processed by the Pod, -# including metadata and system columns if requested. It is useful for: -# - Debugging and analysis -# - Auditing and data lineage tracking -# - Performance monitoring - -# Args: -# include_system_columns: Whether to include system columns in the output - -# Returns: -# pa.Table | None: A table containing all processed records, or None if no records are available -# """ -# ... - - -# @runtime_checkable -# class Source(Kernel, Stream, Protocol): -# """ -# Entry point for data into the computational graph. - -# Sources are special objects that serve dual roles: -# - As Kernels: Can be invoked to produce streams -# - As Streams: Directly provide data without upstream dependencies - -# Sources represent the roots of computational graphs and typically -# interface with external data sources. They bridge the gap between -# the outside world and the Orcapod computational model. - -# Common source types: -# - File readers (CSV, JSON, Parquet, etc.) -# - Database connections and queries -# - API endpoints and web services -# - Generated data sources (synthetic data) -# - Manual data input and user interfaces -# - Message queues and event streams - -# Sources have unique properties: -# - No upstream dependencies (upstreams is empty) -# - Can be both invoked and iterated -# - Serve as the starting point for data lineage -# - May have their own refresh/update mechanisms -# """ - -# @property -# def tag_keys(self) -> tuple[str, ...]: -# """ -# Return the keys used for the tag in the pipeline run records. -# This is used to store the run-associated tag info. -# """ -# ... - -# @property -# def packet_keys(self) -> tuple[str, ...]: -# """ -# Return the keys used for the packet in the pipeline run records. -# This is used to store the run-associated packet info. -# """ -# ... - -# def get_all_records( -# self, include_system_columns: bool = False -# ) -> "pa.Table | None": -# """ -# Retrieve all records from the source. - -# Args: -# include_system_columns: Whether to include system columns in the output - -# Returns: -# pa.Table | None: A table containing all records, or None if no records are available -# """ -# ... - -# def as_lazy_frame(self, sort_by_tags: bool = False) -> "pl.LazyFrame | None": ... - -# def as_df(self, sort_by_tags: bool = True) -> "pl.DataFrame | None": ... - -# def as_polars_df(self, sort_by_tags: bool = False) -> "pl.DataFrame | None": ... - -# def as_pandas_df(self, sort_by_tags: bool = False) -> "pd.DataFrame | None": ... - - -# @runtime_checkable -# class Tracker(Protocol): -# """ -# Records kernel invocations and stream creation for computational graph tracking. - -# Trackers are responsible for maintaining the computational graph by recording -# relationships between kernels, streams, and invocations. They enable: -# - Lineage tracking and data provenance -# - Caching and memoization strategies -# - Debugging and error analysis -# - Performance monitoring and optimization -# - Reproducibility and auditing - -# Multiple trackers can be active simultaneously, each serving different -# purposes (e.g., one for caching, another for debugging, another for -# monitoring). This allows for flexible and composable tracking strategies. - -# Trackers can be selectively activated/deactivated to control overhead -# and focus on specific aspects of the computational graph. -# """ - -# def set_active(self, active: bool = True) -> None: -# """ -# Set the active state of the tracker. - -# When active, the tracker will record all kernel invocations and -# stream creations. When inactive, no recording occurs, reducing -# overhead for performance-critical sections. - -# Args: -# active: True to activate recording, False to deactivate -# """ -# ... - -# def is_active(self) -> bool: -# """ -# Check if the tracker is currently recording invocations. - -# Returns: -# bool: True if tracker is active and recording, False otherwise -# """ -# ... - -# def record_kernel_invocation( -# self, kernel: Kernel, upstreams: tuple[Stream, ...], label: str | None = None -# ) -> None: -# """ -# Record a kernel invocation in the computational graph. - -# This method is called whenever a kernel is invoked. The tracker -# should record: -# - The kernel and its properties -# - The input streams that were used as input -# - Timing and performance information -# - Any relevant metadata - -# Args: -# kernel: The kernel that was invoked -# upstreams: The input streams used for this invocation -# """ -# ... - -# def record_source_invocation( -# self, source: Source, label: str | None = None -# ) -> None: -# """ -# Record a source invocation in the computational graph. - -# This method is called whenever a source is invoked. The tracker -# should record: -# - The source and its properties -# - Timing and performance information -# - Any relevant metadata - -# Args: -# source: The source that was invoked -# """ -# ... - -# def record_pod_invocation( -# self, pod: Pod, upstreams: tuple[Stream, ...], label: str | None = None -# ) -> None: -# """ -# Record a pod invocation in the computational graph. - -# This method is called whenever a pod is invoked. The tracker -# should record: -# - The pod and its properties -# - The upstream streams that were used as input -# - Timing and performance information -# - Any relevant metadata - -# Args: -# pod: The pod that was invoked -# upstreams: The input streams used for this invocation -# """ -# ... - - -# @runtime_checkable -# class TrackerManager(Protocol): -# """ -# Manages multiple trackers and coordinates their activity. - -# The TrackerManager provides a centralized way to: -# - Register and manage multiple trackers -# - Coordinate recording across all active trackers -# - Provide a single interface for graph recording -# - Enable dynamic tracker registration/deregistration - -# This design allows for: -# - Multiple concurrent tracking strategies -# - Pluggable tracking implementations -# - Easy testing and debugging (mock trackers) -# - Performance optimization (selective tracking) -# """ - -# def get_active_trackers(self) -> list[Tracker]: -# """ -# Get all currently active trackers. - -# Returns only trackers that are both registered and active, -# providing the list of trackers that will receive recording events. - -# Returns: -# list[Tracker]: List of trackers that are currently recording -# """ -# ... - -# def register_tracker(self, tracker: Tracker) -> None: -# """ -# Register a new tracker in the system. - -# The tracker will be included in future recording operations -# if it is active. Registration is separate from activation -# to allow for dynamic control of tracking overhead. - -# Args: -# tracker: The tracker to register -# """ -# ... - -# def deregister_tracker(self, tracker: Tracker) -> None: -# """ -# Remove a tracker from the system. - -# The tracker will no longer receive recording notifications -# even if it is still active. This is useful for: -# - Cleaning up temporary trackers -# - Removing failed or problematic trackers -# - Dynamic tracker management - -# Args: -# tracker: The tracker to remove -# """ -# ... - -# def record_kernel_invocation( -# self, kernel: Kernel, upstreams: tuple[Stream, ...], label: str | None = None -# ) -> None: -# """ -# Record a stream in all active trackers. - -# This method broadcasts the stream recording to all currently -# active and registered trackers. It provides a single point -# of entry for recording events, simplifying kernel implementations. - -# Args: -# stream: The stream to record in all active trackers -# """ -# ... - -# def record_source_invocation( -# self, source: Source, label: str | None = None -# ) -> None: -# """ -# Record a source invocation in the computational graph. - -# This method is called whenever a source is invoked. The tracker -# should record: -# - The source and its properties -# - Timing and performance information -# - Any relevant metadata - -# Args: -# source: The source that was invoked -# """ -# ... - -# def record_pod_invocation( -# self, pod: Pod, upstreams: tuple[Stream, ...], label: str | None = None -# ) -> None: -# """ -# Record a stream in all active trackers. - -# This method broadcasts the stream recording to all currently` -# active and registered trackers. It provides a single point -# of entry for recording events, simplifying kernel implementations. - -# Args: -# stream: The stream to record in all active trackers -# """ -# ... - -# def no_tracking(self) -> ContextManager[None]: ... diff --git a/src/orcapod/protocols/node_protocols.py b/src/orcapod/protocols/node_protocols.py new file mode 100644 index 00000000..f51e6de6 --- /dev/null +++ b/src/orcapod/protocols/node_protocols.py @@ -0,0 +1,104 @@ +"""Node protocols for orchestrator interaction. + +Defines the three node protocols (Source, Function, Operator) that +formalize the interface between orchestrators and graph nodes, plus +TypeGuard dispatch functions for runtime type narrowing. + +Each protocol exposes ``execute`` (sync) and ``async_execute`` (async). +Nodes own their execution — caching, per-packet logic, and persistence +are internal. Orchestrators are topology schedulers. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING, Protocol, TypeGuard, runtime_checkable + +if TYPE_CHECKING: + from orcapod.channels import ReadableChannel, WritableChannel + from orcapod.core.nodes import GraphNode + from orcapod.protocols.observability_protocols import ExecutionObserverProtocol + from orcapod.protocols.core_protocols import ( + PacketProtocol, + StreamProtocol, + TagProtocol, + ) + + +@runtime_checkable +class SourceNodeProtocol(Protocol): + """Protocol for source nodes in orchestrated execution.""" + + node_type: str + + def execute( + self, + *, + observer: "ExecutionObserverProtocol | None" = None, + ) -> list[tuple["TagProtocol", "PacketProtocol"]]: ... + + async def async_execute( + self, + output: "WritableChannel[tuple[TagProtocol, PacketProtocol]]", + *, + observer: "ExecutionObserverProtocol | None" = None, + ) -> None: ... + + +@runtime_checkable +class FunctionNodeProtocol(Protocol): + """Protocol for function nodes in orchestrated execution.""" + + node_type: str + + def execute( + self, + input_stream: "StreamProtocol", + *, + observer: "ExecutionObserverProtocol | None" = None, + error_policy: str = "continue", + ) -> list[tuple["TagProtocol", "PacketProtocol"]]: ... + + async def async_execute( + self, + input_channel: "ReadableChannel[tuple[TagProtocol, PacketProtocol]]", + output: "WritableChannel[tuple[TagProtocol, PacketProtocol]]", + *, + observer: "ExecutionObserverProtocol | None" = None, + ) -> None: ... + + +@runtime_checkable +class OperatorNodeProtocol(Protocol): + """Protocol for operator nodes in orchestrated execution.""" + + node_type: str + + def execute( + self, + *input_streams: "StreamProtocol", + observer: "ExecutionObserverProtocol | None" = None, + ) -> list[tuple["TagProtocol", "PacketProtocol"]]: ... + + async def async_execute( + self, + inputs: "Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]]", + output: "WritableChannel[tuple[TagProtocol, PacketProtocol]]", + *, + observer: "ExecutionObserverProtocol | None" = None, + ) -> None: ... + + +def is_source_node(node: "GraphNode") -> TypeGuard[SourceNodeProtocol]: + """Check if a node is a source node.""" + return node.node_type == "source" + + +def is_function_node(node: "GraphNode") -> TypeGuard[FunctionNodeProtocol]: + """Check if a node is a function node.""" + return node.node_type == "function" + + +def is_operator_node(node: "GraphNode") -> TypeGuard[OperatorNodeProtocol]: + """Check if a node is an operator node.""" + return node.node_type == "operator" diff --git a/src/orcapod/protocols/observability_protocols.py b/src/orcapod/protocols/observability_protocols.py new file mode 100644 index 00000000..dad3d0fd --- /dev/null +++ b/src/orcapod/protocols/observability_protocols.py @@ -0,0 +1,140 @@ +"""Observability protocols for pipeline execution tracking and logging. + +Defines: + +* :class:`PacketExecutionLoggerProtocol` — receives captured I/O from a single + packet execution and persists it to a configured sink. +* :class:`ExecutionObserverProtocol` — lifecycle hooks for pipeline/node/packet + events, plus a factory method for creating context-bound loggers. + +Both follow the same runtime-checkable Protocol pattern used throughout the +rest of the orcapod codebase. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +if TYPE_CHECKING: + from orcapod.core.nodes import GraphNode + from orcapod.pipeline.logging_capture import CapturedLogs + from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol + + +@runtime_checkable +class PacketExecutionLoggerProtocol(Protocol): + """Receives captured execution output and persists it. + + A logger is *bound* to a specific packet execution context (node, tag, + packet) when created by the Observer. It knows the destination (e.g. a + Delta Lake table) but does not know how the logs were collected — that is + the executor's responsibility. + """ + + def record(self, captured: "CapturedLogs") -> None: + """Persist the captured logs from a packet function execution. + + Called after every packet execution (success or failure), except for + cache hits when ``log_cache_hits=False`` (the default). + """ + ... + + +@runtime_checkable +class ExecutionObserverProtocol(Protocol): + """Observer protocol for pipeline execution lifecycle events. + + Instantiated once outside the pipeline and injected into the orchestrator. + Provides hooks for lifecycle events at the run, node, and packet level, and + acts as a factory for context-specific loggers. + + ``on_packet_start`` / ``on_packet_end`` / ``on_packet_crash`` are invoked + only for function nodes. ``on_node_start`` / ``on_node_end`` are invoked + for all node types. + """ + + def on_run_start(self, run_id: str) -> None: + """Called at the very start of an orchestrator ``run()`` call. + + Args: + run_id: A UUID string unique to this execution run. All loggers + created during the run will be stamped with this ID. + """ + ... + + def on_run_end(self, run_id: str) -> None: + """Called at the very end of an orchestrator ``run()`` call. + + Args: + run_id: The same UUID passed to ``on_run_start``. + """ + ... + + def on_node_start(self, node: "GraphNode") -> None: + """Called before a node begins processing its packets.""" + ... + + def on_node_end(self, node: "GraphNode") -> None: + """Called after a node finishes processing all packets.""" + ... + + def on_packet_start( + self, + node: "GraphNode", + tag: "TagProtocol", + packet: "PacketProtocol", + ) -> None: + """Called before a packet is processed by a function node.""" + ... + + def on_packet_end( + self, + node: "GraphNode", + tag: "TagProtocol", + input_packet: "PacketProtocol", + output_packet: "PacketProtocol | None", + cached: bool, + ) -> None: + """Called after a packet is successfully processed (or served from cache). + + Args: + cached: ``True`` when the result came from a database cache and + the user function was not executed. + """ + ... + + def on_packet_crash( + self, + node: "GraphNode", + tag: "TagProtocol", + packet: "PacketProtocol", + error: Exception, + ) -> None: + """Called when a packet's execution fails. + + Covers both user-function exceptions (captured on the worker) and + system-level crashes (e.g. ``WorkerCrashedError`` from Ray). The + pipeline continues processing remaining packets rather than aborting. + """ + ... + + def create_packet_logger( + self, + node: "GraphNode", + tag: "TagProtocol", + packet: "PacketProtocol", + pipeline_path: tuple[str, ...] = (), + ) -> PacketExecutionLoggerProtocol: + """Create a context-bound logger for a single packet execution. + + The returned logger is pre-stamped with the node label, run ID, and + packet identity so every ``record()`` call writes the correct context + without the executor needing to know anything about the pipeline. + + Args: + node: The graph node being executed. + tag: The tag for the packet being processed. + packet: The input packet being processed. + pipeline_path: The node's pipeline path for log storage scoping. + """ + ... diff --git a/src/orcapod/protocols/pipeline_protocols.py b/src/orcapod/protocols/pipeline_protocols.py index 04ce8538..d8c8e6c3 100644 --- a/src/orcapod/protocols/pipeline_protocols.py +++ b/src/orcapod/protocols/pipeline_protocols.py @@ -1,27 +1,27 @@ # Protocols for pipeline and nodes -from typing import Protocol, runtime_checkable, TYPE_CHECKING -from orcapod.protocols import core_protocols as cp +from typing import TYPE_CHECKING, Protocol, runtime_checkable +from orcapod.protocols import core_protocols as cp if TYPE_CHECKING: import pyarrow as pa -class Node(cp.Source, Protocol): +class NodeProtocol(cp.Source, Protocol): # def record_pipeline_outputs(self): # pass ... @runtime_checkable -class PodNode(cp.CachedPod, Protocol): +class PodNodeProtocol(cp.CachedPod, Protocol): def get_all_records( self, include_system_columns: bool = False ) -> "pa.Table | None": """ - Retrieve all tag and packet processed by this Pod. + Retrieve all tag and packet processed by this PodProtocol. - This method returns a table containing all packets processed by the Pod, + This method returns a table containing all packets processed by the PodProtocol, including metadata and system columns if requested. It is useful for: - Debugging and analysis - Auditing and data lineage tracking @@ -50,8 +50,8 @@ def flush(self): def add_pipeline_record( self, - tag: cp.Tag, - input_packet: cp.Packet, + tag: cp.TagProtocol, + input_packet: cp.PacketProtocol, packet_record_id: str, retrieved: bool | None = None, skip_cache_lookup: bool = False, diff --git a/src/orcapod/protocols/semantic_types_protocols.py b/src/orcapod/protocols/semantic_types_protocols.py index 855f8a07..96ea4028 100644 --- a/src/orcapod/protocols/semantic_types_protocols.py +++ b/src/orcapod/protocols/semantic_types_protocols.py @@ -1,29 +1,27 @@ -from typing import TYPE_CHECKING, Any, Protocol from collections.abc import Callable -from orcapod.types import PythonSchema, PythonSchemaLike +from typing import TYPE_CHECKING, Any, Protocol +from orcapod.types import DataType, Schema, SchemaLike if TYPE_CHECKING: import pyarrow as pa -class TypeConverter(Protocol): - def python_type_to_arrow_type(self, python_type: type) -> "pa.DataType": ... +class TypeConverterProtocol(Protocol): + def python_type_to_arrow_type(self, python_type: DataType) -> "pa.DataType": ... def python_schema_to_arrow_schema( - self, python_schema: PythonSchemaLike + self, python_schema: SchemaLike ) -> "pa.Schema": ... - def arrow_type_to_python_type(self, arrow_type: "pa.DataType") -> type: ... + def arrow_type_to_python_type(self, arrow_type: "pa.DataType") -> DataType: ... - def arrow_schema_to_python_schema( - self, arrow_schema: "pa.Schema" - ) -> PythonSchema: ... + def arrow_schema_to_python_schema(self, arrow_schema: "pa.Schema") -> Schema: ... def python_dicts_to_struct_dicts( self, python_dicts: list[dict[str, Any]], - python_schema: PythonSchemaLike | None = None, + python_schema: SchemaLike | None = None, ) -> list[dict[str, Any]]: ... def struct_dicts_to_python_dicts( @@ -35,7 +33,7 @@ def struct_dicts_to_python_dicts( def python_dicts_to_arrow_table( self, python_dicts: list[dict[str, Any]], - python_schema: PythonSchemaLike | None = None, + python_schema: SchemaLike | None = None, arrow_schema: "pa.Schema | None" = None, ) -> "pa.Table": ... @@ -53,11 +51,11 @@ def get_arrow_to_python_converter( # Core protocols -class SemanticStructConverter(Protocol): +class SemanticStructConverterProtocol(Protocol): """Protocol for converting between Python objects and semantic structs.""" @property - def python_type(self) -> type: + def python_type(self) -> DataType: """The Python type this converter can handle.""" ... @@ -74,7 +72,7 @@ def struct_dict_to_python(self, struct_dict: dict[str, Any]) -> Any: """Convert struct dictionary back to Python value.""" ... - def can_handle_python_type(self, python_type: type) -> bool: + def can_handle_python_type(self, python_type: DataType) -> bool: """Check if this converter can handle the given Python type.""" ... diff --git a/src/orcapod/semantic_types/pydata_utils.py b/src/orcapod/semantic_types/pydata_utils.py index 5acc0207..d1bccfdf 100644 --- a/src/orcapod/semantic_types/pydata_utils.py +++ b/src/orcapod/semantic_types/pydata_utils.py @@ -2,8 +2,9 @@ # dictionary of lists from types import UnionType -from typing import Any, Union, get_origin, get_args -from orcapod.types import PythonSchema +from typing import Any, Union + +from orcapod.types import DataType, Schema def pylist_to_pydict(pylist: list[dict]) -> dict: @@ -81,7 +82,7 @@ def pydict_to_pylist(pydict: dict) -> list[dict]: def infer_python_schema_from_pylist_data( data: list[dict], default_type: type = str, -) -> PythonSchema: +) -> Schema: """ Infer schema from sample data (best effort). @@ -96,9 +97,9 @@ def infer_python_schema_from_pylist_data( For production use, explicit schemas are recommended. """ if not data: - return {} + return Schema({}) - schema = {} + schema_data = {} # Get all possible field names all_fields = [] @@ -121,27 +122,29 @@ def infer_python_schema_from_pylist_data( if not non_none_values: # Handle case where all values are None - schema[field_name] = default_type | None + schema_data[field_name] = default_type | None continue # Infer type from non-None values inferred_type = _infer_type_from_values(non_none_values) if inferred_type is None: - schema[field_name] = default_type | None + schema_data[field_name] = default_type | None elif has_none: # Wrap with Optional if None values present - schema[field_name] = inferred_type | None if inferred_type != Any else Any + schema_data[field_name] = ( + inferred_type | None if inferred_type != Any else Any + ) else: - schema[field_name] = inferred_type + schema_data[field_name] = inferred_type - return schema + return Schema(schema_data) def infer_python_schema_from_pydict_data( data: dict[str, list[Any]], default_type: type = str, -) -> PythonSchema: +) -> Schema: """ Infer schema from columnar sample data (best effort). @@ -156,15 +159,15 @@ def infer_python_schema_from_pydict_data( For production use, explicit schemas are recommended. """ if not data: - return {} + return Schema({}) - schema: PythonSchema = {} + schema_data: dict[str, DataType] = {} # Infer type for each field for field_name, field_values in data.items(): if not field_values: # Handle case where field has empty list - schema[field_name] = default_type | None + schema_data[field_name] = default_type | None continue # Separate None and non-None values @@ -173,22 +176,22 @@ def infer_python_schema_from_pydict_data( if not non_none_values: # Handle case where all values are None - schema[field_name] = default_type | None + schema_data[field_name] = default_type | None continue # Infer type from non-None values inferred_type = _infer_type_from_values(non_none_values) if inferred_type is None: - schema[field_name] = default_type | None + schema_data[field_name] = default_type | None elif has_none: # Wrap with Optional if None values present # TODO: consider the case of Any - schema[field_name] = inferred_type | None + schema_data[field_name] = inferred_type | None else: - schema[field_name] = inferred_type + schema_data[field_name] = inferred_type - return schema + return Schema(schema_data) # TODO: reconsider this type hint -- use of Any effectively renders this type hint useless diff --git a/src/orcapod/semantic_types/semantic_registry.py b/src/orcapod/semantic_types/semantic_registry.py index aa1c604e..c2b299b6 100644 --- a/src/orcapod/semantic_types/semantic_registry.py +++ b/src/orcapod/semantic_types/semantic_registry.py @@ -1,12 +1,13 @@ -from typing import Any, TYPE_CHECKING from collections.abc import Mapping -from orcapod.protocols.semantic_types_protocols import SemanticStructConverter -from orcapod.utils.lazy_module import LazyModule +from typing import TYPE_CHECKING, Any -# from orcapod.semantic_types.type_inference import infer_python_schema_from_pylist_data -from orcapod.types import DataType, PythonSchema +from orcapod.protocols.semantic_types_protocols import SemanticStructConverterProtocol from orcapod.semantic_types import pydata_utils +# from orcapod.semantic_types.type_inference import infer_python_schema_from_pylist_data +from orcapod.types import DataType, Schema +from orcapod.utils.lazy_module import LazyModule + if TYPE_CHECKING: import pyarrow as pa else: @@ -23,27 +24,31 @@ class SemanticTypeRegistry: """ @staticmethod - def infer_python_schema_from_pylist(data: list[dict[str, Any]]) -> PythonSchema: + def infer_python_schema_from_pylist(data: list[dict[str, Any]]) -> Schema: """ Infer Python schema from a list of dictionaries (pylist) """ return pydata_utils.infer_python_schema_from_pylist_data(data) @staticmethod - def infer_python_schema_from_pydict(data: dict[str, list[Any]]) -> PythonSchema: + def infer_python_schema_from_pydict(data: dict[str, list[Any]]) -> Schema: # TODO: consider which data type is more efficient and use that pylist or pydict return pydata_utils.infer_python_schema_from_pylist_data( pydata_utils.pydict_to_pylist(data) ) - def __init__(self, converters: Mapping[str, SemanticStructConverter] | None = None): + def __init__( + self, converters: Mapping[str, SemanticStructConverterProtocol] | None = None + ): # Bidirectional mappings between Python types and struct signatures self._python_to_struct: dict[DataType, "pa.StructType"] = {} self._struct_to_python: dict["pa.StructType", DataType] = {} - self._struct_to_converter: dict["pa.StructType", SemanticStructConverter] = {} + self._struct_to_converter: dict[ + "pa.StructType", SemanticStructConverterProtocol + ] = {} # Name mapping for convenience - self._name_to_converter: dict[str, SemanticStructConverter] = {} + self._name_to_converter: dict[str, SemanticStructConverterProtocol] = {} self._struct_to_name: dict["pa.StructType", str] = {} # If initialized with a list of converters, register them @@ -52,7 +57,7 @@ def __init__(self, converters: Mapping[str, SemanticStructConverter] | None = No self.register_converter(semantic_type_name, converter) def register_converter( - self, semantic_type_name: str, converter: SemanticStructConverter + self, semantic_type_name: str, converter: SemanticStructConverterProtocol ) -> None: """ Register a semantic type converter. @@ -102,7 +107,7 @@ def register_converter( def get_converter_for_python_type( self, python_type: DataType - ) -> SemanticStructConverter | None: + ) -> SemanticStructConverterProtocol | None: """Get converter registered to the Python type.""" # Direct lookup first struct_signature = self._python_to_struct.get(python_type) @@ -126,13 +131,13 @@ def get_converter_for_python_type( def get_converter_for_semantic_type( self, semantic_type_name: str - ) -> SemanticStructConverter | None: + ) -> SemanticStructConverterProtocol | None: """Get converter registered to the semantic type name.""" return self._name_to_converter.get(semantic_type_name) def get_converter_for_struct_signature( self, struct_signature: "pa.StructType" - ) -> SemanticStructConverter | None: + ) -> SemanticStructConverterProtocol | None: """ Get converter registered to the Arrow struct signature. """ diff --git a/src/orcapod/semantic_types/semantic_struct_converters.py b/src/orcapod/semantic_types/semantic_struct_converters.py index 3ba45f55..63d1e236 100644 --- a/src/orcapod/semantic_types/semantic_struct_converters.py +++ b/src/orcapod/semantic_types/semantic_struct_converters.py @@ -5,12 +5,16 @@ making semantic types visible in schemas and preserved through operations. """ -from typing import Any, TYPE_CHECKING from pathlib import Path +from typing import TYPE_CHECKING, Any + +from orcapod.types import ContentHash from orcapod.utils.lazy_module import LazyModule if TYPE_CHECKING: import pyarrow as pa + + from orcapod.protocols.hashing_protocols import FileContentHasherProtocol else: pa = LazyModule("pyarrow") @@ -54,7 +58,7 @@ def _format_hash_string(self, hash_bytes: bytes, add_prefix: bool = False) -> st else: return hash_hex - def _compute_content_hash(self, content: bytes) -> bytes: + def _compute_content_hash(self, content: bytes) -> ContentHash: """ Compute SHA-256 hash of content bytes. @@ -66,16 +70,18 @@ def _compute_content_hash(self, content: bytes) -> bytes: """ import hashlib - return hashlib.sha256(content).digest() + digest = hashlib.sha256(content).digest() + return ContentHash(method=f"{self.semantic_type_name}:sha256", digest=digest) # Path-specific implementation class PathStructConverter(SemanticStructConverterBase): - """Converter for pathlib.Path objects to/from semantic structs.""" + """Converter for pathlib.Path objects to/from semantic structs of form { path: "/value/of/path"}""" - def __init__(self): + def __init__(self, file_hasher: "FileContentHasherProtocol"): super().__init__("path") self._python_type = Path + self._file_hasher = file_hasher # Define the Arrow struct type for paths self._arrow_struct_type = pa.struct( @@ -116,19 +122,18 @@ def can_handle_python_type(self, python_type: type) -> bool: def can_handle_struct_type(self, struct_type: pa.StructType) -> bool: """Check if this converter can handle the given struct type.""" # Check if struct has the expected fields - field_names = [field.name for field in struct_type] - expected_fields = {"path"} - - if set(field_names) != expected_fields: - return False + for field in self._arrow_struct_type: + if ( + field.name not in struct_type.names + or struct_type[field.name].type != field.type + ): + return False - # Check field types - field_types = {field.name: field.type for field in struct_type} - - return field_types["path"] == pa.large_string() + return True def is_semantic_struct(self, struct_dict: dict[str, Any]) -> bool: """Check if a struct dictionary represents this semantic type.""" + # TODO: infer this check based on identified struct type as defined in the __init__ return set(struct_dict.keys()) == {"path"} and isinstance( struct_dict["path"], str ) @@ -136,39 +141,28 @@ def is_semantic_struct(self, struct_dict: dict[str, Any]) -> bool: def hash_struct_dict( self, struct_dict: dict[str, Any], add_prefix: bool = False ) -> str: - """ - Compute hash of the file content pointed to by the path. + """Compute hash of a path semantic type by hashing the file content. Args: - struct_dict: Arrow struct dictionary with 'path' field - add_prefix: If True, prefix with semantic type and algorithm info + struct_dict: Dict with a "path" key containing a file path string. + add_prefix: If True, prefix with "path:sha256:...". Returns: - Hash string of the file content, optionally prefixed + Hash string of the file content. Raises: - FileNotFoundError: If the file doesn't exist - PermissionError: If the file can't be read - OSError: For other file system errors + FileNotFoundError: If the path does not exist. + IsADirectoryError: If the path is a directory. """ path_str = struct_dict.get("path") if path_str is None: - raise ValueError("Missing 'path' field in struct") + raise ValueError("Missing 'path' field in struct dict") path = Path(path_str) + if not path.exists(): + raise FileNotFoundError(f"Path does not exist: {path}") + if path.is_dir(): + raise IsADirectoryError(f"Path is a directory: {path}") - try: - # TODO: replace with FileHasher implementation - # Read file content and compute hash - content = path.read_bytes() - hash_bytes = self._compute_content_hash(content) - return self._format_hash_string(hash_bytes, add_prefix) - - except FileNotFoundError: - raise FileNotFoundError(f"File not found: {path}") - except PermissionError: - raise PermissionError(f"Permission denied reading file: {path}") - except IsADirectoryError: - raise ValueError(f"Path is a directory, not a file: {path}") - except OSError as e: - raise OSError(f"Error reading file {path}: {e}") + content_hash = self._file_hasher.hash_file(path) + return self._format_hash_string(content_hash.digest, add_prefix=add_prefix) diff --git a/src/orcapod/semantic_types/type_inference.py b/src/orcapod/semantic_types/type_inference.py index b51c2673..5ddc58aa 100644 --- a/src/orcapod/semantic_types/type_inference.py +++ b/src/orcapod/semantic_types/type_inference.py @@ -1,14 +1,14 @@ -from types import UnionType -from typing import Any, Union, get_origin, get_args from collections.abc import Collection, Mapping +from types import UnionType +from typing import Any, Union -from orcapod.types import PythonSchema +from orcapod.types import DataType, Schema def infer_python_schema_from_pylist_data( data: Collection[Mapping[str, Any]], default_type: type = str, -) -> PythonSchema: +) -> Schema: """ Infer schema from sample data (best effort). @@ -23,9 +23,9 @@ def infer_python_schema_from_pylist_data( For production use, explicit schemas are recommended. """ if not data: - return {} + return Schema.empty() - schema: PythonSchema = {} + schema_data: dict[str, DataType] = {} # Get all possible field names all_fields = [] @@ -48,28 +48,28 @@ def infer_python_schema_from_pylist_data( if not non_none_values: # Handle case where all values are None - schema[field_name] = default_type | None + schema_data[field_name] = default_type | None continue # Infer type from non-None values inferred_type = _infer_type_from_values(non_none_values) if inferred_type is None: - schema[field_name] = default_type | None + schema_data[field_name] = default_type | None elif has_none: # Wrap with Optional if None values present # TODO: consider the case of Any - schema[field_name] = inferred_type | None + schema_data[field_name] = inferred_type | None else: - schema[field_name] = inferred_type + schema_data[field_name] = inferred_type - return schema + return Schema(schema_data) def infer_python_schema_from_pydict_data( data: dict[str, list[Any]], default_type: type = str, -) -> PythonSchema: +) -> Schema: """ Infer schema from columnar sample data (best effort). @@ -84,15 +84,17 @@ def infer_python_schema_from_pydict_data( For production use, explicit schemas are recommended. """ if not data: - return {} + return Schema() - schema: PythonSchema = {} + schema_data = {} # Infer type for each field for field_name, field_values in data.items(): if not field_values: # Handle case where field has empty list - schema[field_name] = default_type | None + values = dict(schema_data) + values[field_name] = default_type | None + schema_data[field_name] = default_type | None continue # Separate None and non-None values @@ -101,26 +103,26 @@ def infer_python_schema_from_pydict_data( if not non_none_values: # Handle case where all values are None - schema[field_name] = default_type | None + schema_data[field_name] = default_type | None continue # Infer type from non-None values inferred_type = _infer_type_from_values(non_none_values) if inferred_type is None: - schema[field_name] = default_type | None + schema_data[field_name] = default_type | None elif has_none: # Wrap with Optional if None values present # TODO: consider the case of Any - schema[field_name] = inferred_type | None + schema_data[field_name] = inferred_type | None else: - schema[field_name] = inferred_type + schema_data[field_name] = inferred_type - return schema + return Schema(schema_data) # TODO: reconsider this type hint -- use of Any effectively renders this type hint useless -def _infer_type_from_values(values: list) -> type | UnionType | Any | None: +def _infer_type_from_values(values: list) -> DataType | None: """Infer type from a list of non-None values.""" if not values: return None @@ -301,7 +303,7 @@ def test_schema_inference(): print("Inferred Schema:") for field, field_type in sorted(schema.items()): - print(f" {field}: {field_type}") + print(f" {field}: {getattr(field_type, '__name__', field_type)}") return schema diff --git a/src/orcapod/semantic_types/universal_converter.py b/src/orcapod/semantic_types/universal_converter.py index c3ba97e2..be76e808 100644 --- a/src/orcapod/semantic_types/universal_converter.py +++ b/src/orcapod/semantic_types/universal_converter.py @@ -9,21 +9,19 @@ 5. Integrates seamlessly with semantic type registries """ +import hashlib +import logging import types -from typing import TypedDict, Any import typing from collections.abc import Callable, Mapping -import hashlib -import logging -from orcapod.contexts import DataContext, resolve_context -from orcapod.semantic_types.semantic_registry import SemanticTypeRegistry -from orcapod.semantic_types.type_inference import infer_python_schema_from_pylist_data # Handle generic types -from typing import get_origin, get_args +from typing import TYPE_CHECKING, Any, TypedDict, get_args, get_origin -from typing import TYPE_CHECKING -from orcapod.types import DataType, PythonSchemaLike +from orcapod.contexts import DataContext, resolve_context +from orcapod.semantic_types.semantic_registry import SemanticTypeRegistry +from orcapod.semantic_types.type_inference import infer_python_schema_from_pylist_data +from orcapod.types import DataType, Schema, SchemaLike from orcapod.utils.lazy_module import LazyModule if TYPE_CHECKING: @@ -135,13 +133,12 @@ def python_type_to_arrow_type(self, python_type: DataType) -> pa.DataType: return arrow_type - def python_schema_to_arrow_schema( - self, python_schema: PythonSchemaLike - ) -> pa.Schema: + def python_schema_to_arrow_schema(self, python_schema: SchemaLike) -> pa.Schema: """ - Convert a Python schema (dict of field names to types) to an Arrow schema. + Convert a Python schema (dict of field names to data types) to an Arrow schema. - This uses the main conversion logic and caches results for performance. + This uses the main conversion logic, using caches for known type conversion for + an improved performance. """ fields = [] for field_name, python_type in python_schema.items(): @@ -167,32 +164,33 @@ def arrow_type_to_python_type(self, arrow_type: pa.DataType) -> DataType: return python_type - def arrow_schema_to_python_schema(self, arrow_schema: pa.Schema) -> dict[str, type]: + def arrow_schema_to_python_schema(self, arrow_schema: pa.Schema) -> Schema: """ - Convert an Arrow schema to a Python schema (dict of field names to types). + Convert an Arrow schema to a Python Schema (mapping of field names to types). - This uses the main conversion logic and caches results for performance. + This uses the main conversion logic, using caches for known type conversion for + an improved performance. """ - python_schema = {} + fields = {} for field in arrow_schema: - python_type = self.arrow_type_to_python_type(field.type) - python_schema[field.name] = python_type + fields[field.name] = self.arrow_type_to_python_type(field.type) - return python_schema + return Schema(fields) def python_dicts_to_struct_dicts( self, python_dicts: list[dict[str, Any]], - python_schema: PythonSchemaLike | None = None, + python_schema: SchemaLike | None = None, ) -> list[dict[str, Any]]: """ - Convert a list of Python dictionaries to an Arrow table. + Convert a list of Python dictionaries to Arrow compatible list of structural dicts. This uses the main conversion logic and caches results for performance. """ if python_schema is None: python_schema = infer_python_schema_from_pylist_data(python_dicts) + # prepare a LUT of converters from Python to Arrow-compatible data type converters = { field_name: self.get_python_to_arrow_converter(python_type) for field_name, python_type in python_schema.items() @@ -216,7 +214,7 @@ def struct_dict_to_python_dict( arrow_schema: pa.Schema, ) -> list[dict[str, Any]]: """ - Convert a list of Arrow structs to Python dictionaries. + Convert a list of Arrow-compatible structural dictionaries to Python dictionaries. This uses the main conversion logic and caches results for performance. """ @@ -241,7 +239,7 @@ def struct_dict_to_python_dict( def python_dicts_to_arrow_table( self, python_dicts: list[dict[str, Any]], - python_schema: PythonSchemaLike | None = None, + python_schema: SchemaLike | None = None, arrow_schema: "pa.Schema | None" = None, ) -> pa.Table: """ @@ -565,7 +563,7 @@ def _get_or_create_typeddict_for_struct( return typeddict_class - # TODO: consider setting type of field_specs to PythonSchema + # TODO: consider setting type of field_specs to Schema def _generate_unique_type_name(self, field_specs: Mapping[str, DataType]) -> str: """Generate a unique name for TypedDict based on field specifications.""" @@ -634,7 +632,7 @@ def _create_python_to_arrow_converter( element_converter = self.get_python_to_arrow_converter(args[0]) return ( lambda value: [element_converter(item) for item in value] - if value + if value is not None else [] ) @@ -646,7 +644,7 @@ def _create_python_to_arrow_converter( {"key": key_converter(k), "value": value_converter(v)} for k, v in value.items() ] - if value + if value is not None else [] ) diff --git a/src/orcapod/sources/__init__.py b/src/orcapod/sources/__init__.py new file mode 100644 index 00000000..feae7eca --- /dev/null +++ b/src/orcapod/sources/__init__.py @@ -0,0 +1,4 @@ +"""Public re-export of orcapod.core.sources.""" + +from orcapod.core.sources import * # noqa: F401,F403 +from orcapod.core.sources import __all__ diff --git a/src/orcapod/streams/__init__.py b/src/orcapod/streams/__init__.py new file mode 100644 index 00000000..605ee617 --- /dev/null +++ b/src/orcapod/streams/__init__.py @@ -0,0 +1,4 @@ +"""Public re-export of orcapod.core.streams.""" + +from orcapod.core.streams import * # noqa: F401,F403 +from orcapod.core.streams import __all__ diff --git a/src/orcapod/core/system_constants.py b/src/orcapod/system_constants.py similarity index 71% rename from src/orcapod/core/system_constants.py rename to src/orcapod/system_constants.py index 0cc55038..65d1d83e 100644 --- a/src/orcapod/core/system_constants.py +++ b/src/orcapod/system_constants.py @@ -3,10 +3,14 @@ DATAGRAM_PREFIX = "_" SOURCE_INFO_PREFIX = "source_" POD_ID_PREFIX = "pod_id_" +PF_VARIATION_PREFIX = "pf_var_" +PF_EXECUTION_PREFIX = "pf_exec_" DATA_CONTEXT_KEY = "context_key" -INPUT_PACKET_HASH = "input_packet_hash" +INPUT_PACKET_HASH_COL = "input_packet_hash" PACKET_RECORD_ID = "packet_id" -SYSTEM_TAG_PREFIX = "tag" +SYSTEM_TAG_PREFIX_NAME = "tag" +SYSTEM_TAG_SOURCE_ID_FIELD = "source_id" +SYSTEM_TAG_RECORD_ID_FIELD = "record_id" POD_VERSION = "pod_version" EXECUTION_ENGINE = "execution_engine" POD_TIMESTAMP = "pod_ts" @@ -48,8 +52,16 @@ def POD_ID_PREFIX(self) -> str: return f"{self._global_prefix}{SYSTEM_COLUMN_PREFIX}{POD_ID_PREFIX}" @property - def INPUT_PACKET_HASH(self) -> str: - return f"{self._global_prefix}{SYSTEM_COLUMN_PREFIX}{INPUT_PACKET_HASH}" + def PF_VARIATION_PREFIX(self) -> str: + return f"{self._global_prefix}{SYSTEM_COLUMN_PREFIX}{PF_VARIATION_PREFIX}" + + @property + def PF_EXECUTION_PREFIX(self) -> str: + return f"{self._global_prefix}{SYSTEM_COLUMN_PREFIX}{PF_EXECUTION_PREFIX}" + + @property + def INPUT_PACKET_HASH_COL(self) -> str: + return f"{self._global_prefix}{SYSTEM_COLUMN_PREFIX}{INPUT_PACKET_HASH_COL}" @property def PACKET_RECORD_ID(self) -> str: @@ -57,7 +69,15 @@ def PACKET_RECORD_ID(self) -> str: @property def SYSTEM_TAG_PREFIX(self) -> str: - return f"{self._global_prefix}{DATAGRAM_PREFIX}{SYSTEM_TAG_PREFIX}{self.BLOCK_SEPARATOR}" + return f"{self._global_prefix}{DATAGRAM_PREFIX}{SYSTEM_TAG_PREFIX_NAME}_" + + @property + def SYSTEM_TAG_SOURCE_ID_PREFIX(self) -> str: + return f"{self.SYSTEM_TAG_PREFIX}{SYSTEM_TAG_SOURCE_ID_FIELD}" + + @property + def SYSTEM_TAG_RECORD_ID_PREFIX(self) -> str: + return f"{self.SYSTEM_TAG_PREFIX}{SYSTEM_TAG_RECORD_ID_FIELD}" @property def POD_VERSION(self) -> str: diff --git a/src/orcapod/types.py b/src/orcapod/types.py index 0f84d9c9..3f8938de 100644 --- a/src/orcapod/types.py +++ b/src/orcapod/types.py @@ -1,37 +1,547 @@ -from types import UnionType -from typing import TypeAlias -import os -from collections.abc import Collection, Mapping +"""Core type definitions for OrcaPod. + +Defines the fundamental data types, type aliases, and data structures used +throughout the OrcaPod framework, including: + + - Type aliases for data values, schemas, paths, and tags. + - ``Schema`` -- an immutable, hashable mapping of field names to Python types. + - ``ContentHash`` -- a content-addressable hash pairing a method name with + a raw digest, with convenience conversions to hex, int, UUID, and base64. +""" + +from __future__ import annotations import logging +import os +import uuid +from collections.abc import Collection, Iterator, Mapping +from dataclasses import dataclass +from enum import Enum +from types import UnionType +from typing import TYPE_CHECKING, Any, Self, TypeAlias + +if TYPE_CHECKING: + from orcapod.protocols.core_protocols import PacketFunctionExecutorProtocol + +import pyarrow as pa logger = logging.getLogger(__name__) -DataType: TypeAlias = type | UnionType | list[type] | tuple[type, ...] +# Mapping from Python types to Arrow types. +_PYTHON_TO_ARROW: dict[type, pa.DataType] = { + int: pa.int64(), + float: pa.float64(), + str: pa.string(), + bool: pa.bool_(), + bytes: pa.binary(), +} -PythonSchema: TypeAlias = dict[str, DataType] # dict of parameter names to their types +# Reverse mapping from Arrow types back to Python types. +_ARROW_TO_PYTHON: dict[pa.DataType, type] = {v: k for k, v in _PYTHON_TO_ARROW.items()} -PythonSchemaLike: TypeAlias = Mapping[ - str, DataType -] # Mapping of parameter names to their types +# TODO: revisit and consider a way to incorporate older Union type +DataType: TypeAlias = type | UnionType # | type[Union] +"""A Python type or union of types used to describe the data type of a single +field within a ``Schema``.""" -# Convenience alias for anything pathlike -PathLike = str | os.PathLike +# TODO: accomodate other Path-like objects +PathLike: TypeAlias = str | os.PathLike +"""Convenience alias for any filesystem-path-like object (``str`` or +``os.PathLike``).""" -# an (optional) string or a collection of (optional) string values -# Note that TagValue can be nested, allowing for an arbitrary depth of nested lists +# TODO: accomodate other common data types such as datetime TagValue: TypeAlias = int | str | None | Collection["TagValue"] +"""A tag metadata value: an int, string, ``None``, or an arbitrarily nested +collection thereof. Tags are used to label and organise packets and +datagrams.""" -# a pathset is a path or an arbitrary depth of nested list of paths PathSet: TypeAlias = PathLike | Collection[PathLike | None] +"""A single path or an arbitrarily nested collection of paths (with optional +``None`` entries). Used when operations need to address multiple files at +once, e.g. batch hashing.""" -# Simple data types that we support (with clear Polars correspondence) SupportedNativePythonData: TypeAlias = str | int | float | bool | bytes +"""The simple Python scalar types that have a direct Arrow / Polars +correspondence.""" ExtendedSupportedPythonData: TypeAlias = SupportedNativePythonData | PathSet +"""Native scalar types extended with filesystem paths.""" -# Extended data values that can be stored in packets -# Either the original PathSet or one of our supported simple data types DataValue: TypeAlias = ExtendedSupportedPythonData | Collection["DataValue"] | None +"""The universe of values that can appear in a packet column -- scalars, +paths, arbitrarily nested collections, or ``None``.""" PacketLike: TypeAlias = Mapping[str, DataValue] +"""A dict-like structure mapping field names to ``DataValue`` entries. Serves +as a lightweight, protocol-free representation of a packet.""" + +SchemaLike: TypeAlias = Mapping[str, DataType] +"""A dict-like structure mapping field names to ``DataType`` entries. +Accepted wherever a ``Schema`` is expected so callers can pass plain dicts.""" + + +class Schema(Mapping[str, DataType]): + """Immutable schema representing a mapping of field names to Python types. + + Serves as the canonical internal schema representation in OrcaPod, + with interop to/from Arrow schemas. Hashable and suitable for use + in content-addressable contexts. + + Args: + fields: An optional mapping of field names to their data types. + **kwargs: Additional field name / type pairs. These are merged with + ``fields``; keyword arguments take precedence on conflict. + + Example:: + + schema = Schema({"x": int, "y": float}) + schema = Schema(x=int, y=float) + """ + + def __init__( + self, + fields: Mapping[str, DataType] | None = None, + optional_fields: Collection[str] | None = None, + **kwargs: type, + ) -> None: + combined = dict(fields or {}) + combined.update(kwargs) + self._data: dict[str, DataType] = combined + self._optional: frozenset[str] = frozenset(optional_fields or ()) + + # ==================== Mapping interface ==================== + + def __getitem__(self, key: str) -> DataType: + return self._data[key] + + def __iter__(self) -> Iterator[str]: + return iter(self._data) + + def __len__(self) -> int: + return len(self._data) + + def __repr__(self) -> str: + return f"Schema({self._data!r})" + + # ==================== Value semantics ==================== + + def __eq__(self, other: object) -> bool: + if isinstance(other, Schema): + return self._data == other._data and self._optional == other._optional + if isinstance(other, Mapping): + return self._data == dict(other) + raise NotImplementedError( + f"Equality check is not implemented for object of type {type(other)}" + ) + + # ==================== Optionality ==================== + + @property + def optional_fields(self) -> frozenset[str]: + """Field names that are optional (have a default value in the source function).""" + return self._optional + + @property + def required_fields(self) -> frozenset[str]: + """Field names that must be present in an incoming packet.""" + return frozenset(self._data.keys()) - self._optional + + def is_required(self, field: str) -> bool: + """Return True if *field* must be present (has no default).""" + return field not in self._optional + + # ==================== Schema operations ==================== + + def merge(self, other: Mapping[str, type]) -> Schema: + """Return a new Schema that is the union of ``self`` and ``other``. + + Args: + other: A mapping of field names to types to merge in. + + Returns: + A new ``Schema`` containing all fields from both schemas. + + Raises: + ValueError: If any shared field has a different type in ``other``. + """ + conflicts = {k for k in other if k in self._data and self._data[k] != other[k]} + if conflicts: + raise ValueError(f"Schema merge conflict on fields: {conflicts}") + other_optional = other._optional if isinstance(other, Schema) else frozenset() + return Schema( + {**self._data, **other}, optional_fields=self._optional | other_optional + ) + + def with_values(self, other: dict[str, type] | None, **kwargs: type) -> Schema: + """Return a new Schema with the specified fields added or overridden. + + Unlike ``merge``, this method silently overrides existing fields when + a key already exists. + + Args: + other: An optional mapping of field names to types. + **kwargs: Additional field name / type pairs. + + Returns: + A new ``Schema`` with the updated fields. + """ + if other is None: + other = {} + return Schema({**self._data, **other, **kwargs}) + + def select(self, *fields: str) -> Schema: + """Return a new Schema containing only the specified fields. + + Args: + *fields: Names of the fields to keep. + + Returns: + A new ``Schema`` with only the requested fields. + + Raises: + KeyError: If any of the requested fields are not present. + """ + missing = set(fields) - self._data.keys() + if missing: + raise KeyError(f"Fields not in schema: {missing}") + kept = frozenset(fields) + return Schema( + {k: self._data[k] for k in fields}, optional_fields=self._optional & kept + ) + + def drop(self, *fields: str) -> Schema: + """Return a new Schema with the specified fields removed. + + Args: + *fields: Names of the fields to drop. Fields not present in the + schema are silently ignored. + + Returns: + A new ``Schema`` without the dropped fields. + """ + dropped = frozenset(fields) + return Schema( + {k: v for k, v in self._data.items() if k not in fields}, + optional_fields=self._optional - dropped, + ) + + def is_compatible_with(self, other: Schema) -> bool: + """Check whether ``other`` is a superset of this schema. + + Args: + other: The schema to compare against. + + Returns: + ``True`` if ``other`` contains every field in ``self`` with a + matching type. + """ + return all(other.get(k) == v for k, v in self._data.items()) + + # ==================== Convenience constructors ==================== + + @classmethod + def empty(cls) -> Schema: + """Create an empty schema with no fields. + + Returns: + A new ``Schema`` containing zero fields. + """ + return cls({}) + + +class ExecutorType(Enum): + """Pipeline execution strategy. + + Attributes: + SYNCHRONOUS: Current behavior -- ``static_process`` chain with + pull-based materialization. + ASYNC_CHANNELS: Push-based async channel execution via + ``async_execute``. + """ + + SYNCHRONOUS = "synchronous" + ASYNC_CHANNELS = "async_channels" + + +@dataclass(frozen=True, slots=True) +class PipelineConfig: + """Pipeline-level execution configuration. + + Attributes: + executor: Which execution strategy to use. + channel_buffer_size: Max items buffered per channel edge. + default_max_concurrency: Pipeline-wide default for per-node + concurrency. ``None`` means unlimited. + execution_engine: Optional packet-function executor applied to all + function nodes (e.g. ``RayExecutor``). ``None`` means in-process + execution. + execution_engine_opts: Resource/options dict forwarded to the engine + via ``with_options()`` (e.g. ``{"num_cpus": 4}``). + """ + + executor: ExecutorType = ExecutorType.SYNCHRONOUS + channel_buffer_size: int = 64 + default_max_concurrency: int | None = None + execution_engine: PacketFunctionExecutorProtocol | None = None + execution_engine_opts: dict[str, Any] | None = None + + +@dataclass(frozen=True, slots=True) +class NodeConfig: + """Per-node execution configuration. + + Attributes: + max_concurrency: Override for this node's concurrency limit. + ``None`` inherits from ``PipelineConfig.default_max_concurrency``. + ``1`` means sequential (rate-limited APIs, preserves ordering). + """ + + max_concurrency: int | None = None + + +def resolve_concurrency( + node_config: NodeConfig, pipeline_config: PipelineConfig +) -> int | None: + """Resolve effective concurrency from node and pipeline configs. + + Returns: + The concurrency limit to use, or ``None`` for unlimited. + + Raises: + ValueError: If the resolved value is ``<= 0``. + """ + if node_config.max_concurrency is not None: + result = node_config.max_concurrency + else: + result = pipeline_config.default_max_concurrency + if result is not None and result <= 0: + raise ValueError(f"max_concurrency must be >= 1, got {result}") + return result + + +class CacheMode(Enum): + """Controls operator pod caching behaviour. + + Attributes: + OFF: No cache writes, always compute. Default for operator pods. + LOG: Cache writes and computation. The operator always recomputes; + the cache serves as an append-only historical record. + REPLAY: Skip computation and flow cached results downstream. Only + appropriate when the user explicitly wants to use the historical + record (e.g. auditing or run-over-run comparison). + """ + + OFF = "off" + LOG = "log" + REPLAY = "replay" + + +@dataclass(frozen=True, slots=True) +class ColumnConfig: + """ + Configuration for column inclusion in DatagramProtocol/PacketProtocol/TagProtocol operations. + + Controls which column types to include when converting to tables, dicts, + or querying keys/types. + + Attributes: + meta: Include meta columns (with '__' prefix). + - False: exclude all meta columns (default) + - True: include all meta columns + - Collection[str]: include specific meta columns by name + (prefix '__' is added automatically if not present) + context: Include context column + source: Include source info columns (PacketProtocol only, ignored for others) + system_tags: Include system tag columns (TagProtocol only, ignored for others) + all_info: Include all available columns (overrides other settings) + + Examples: + >>> # Data columns only (default) + >>> ColumnConfig() + + >>> # Everything + >>> ColumnConfig(all_info=True) + >>> # Or use convenience method: + >>> ColumnConfig.all() + + >>> # Specific combinations + >>> ColumnConfig(meta=True, context=True) + >>> ColumnConfig(meta=["pipeline", "processed"], source=True) + + >>> # As dict (alternative syntax) + >>> {"meta": True, "source": True} + """ + + meta: bool | Collection[str] = False + context: bool = False + source: bool = False # Only relevant for PacketProtocol + system_tags: bool = False # Only relevant for TagProtocol + content_hash: bool | str = False # Only relevant for PacketProtocol + sort_by_tags: bool = False # Only relevant for TagProtocol + all_info: bool = False + + @classmethod + def all(cls) -> Self: + """Convenience: include all available columns""" + return cls( + meta=True, + context=True, + source=True, + system_tags=True, + content_hash=True, + sort_by_tags=True, + all_info=True, + ) + + @classmethod + def data_only(cls) -> Self: + """Convenience: include only data columns (default)""" + return cls() + + # TODO: consider renaming this to something more intuitive + @classmethod + def handle_config( + cls, config: Self | dict[str, Any] | None, all_info: bool = False + ) -> Self: + """ + Normalize column configuration input. + + Args: + config: ColumnConfig instance or dict to normalize. + all_info: If True, override config to include all columns. + + Returns: + Normalized ColumnConfig instance. + """ + if all_info: + return cls.all() + # TODO: properly handle non-boolean values when using all_info + + if config is None: + column_config = cls() + elif isinstance(config, dict): + column_config = cls(**config) + elif isinstance(config, cls): + column_config = config + else: + raise TypeError( + f"Invalid column config type: {type(config)}. " + "Expected ColumnConfig instance or dict." + ) + + return column_config + + +@dataclass(frozen=True, slots=True) +class ContentHash: + """Content-addressable hash pairing a hashing method with a raw digest. + + ``ContentHash`` is the standard way to represent hashes throughout OrcaPod. + It is immutable (frozen dataclass) and provides convenience methods to + convert the digest into hex strings, integers, UUIDs, base64, and + human-friendly display names. + + Attributes: + method: Identifier for the hashing algorithm / strategy used + (e.g. ``"arrow_v2.1"``). + digest: The raw hash bytes. + """ + + method: str + digest: bytes + + # TODO: make the default char count configurable + def to_hex(self, char_count: int | None = None) -> str: + """Convert the digest to a hexadecimal string. + + Args: + char_count: If given, truncate the hex string to this many + characters. + + Returns: + The full (or truncated) hex representation of the digest. + """ + hex_str = self.digest.hex() + return hex_str[:char_count] if char_count else hex_str + + def to_int(self, hexdigits: int | None = None) -> int: + """Convert the digest to an integer. + + Args: + hexdigits: Number of hex digits to use. If provided, the hex + string is truncated before conversion. + + Returns: + Integer representation of the (optionally truncated) digest. + """ + return int(self.to_hex(hexdigits), 16) + + def to_uuid(self, namespace: uuid.UUID = uuid.NAMESPACE_OID) -> uuid.UUID: + """Derive a deterministic UUID from the digest. + + Uses ``uuid5`` with the full hex string to ensure deterministic output. + + Args: + namespace: UUID namespace for ``uuid5`` generation. Defaults to + ``uuid.NAMESPACE_OID``. + + Returns: + A UUID derived from this hash. + """ + # Using uuid5 with the hex string ensures deterministic UUIDs + return uuid.uuid5(namespace, self.to_hex()) + + def to_base64(self) -> str: + """Convert the digest to a base64-encoded ASCII string. + + Returns: + Base64 representation of the digest. + """ + import base64 + + return base64.b64encode(self.digest).decode("ascii") + + def to_string( + self, prefix_method: bool = True, hexdigits: int | None = None + ) -> str: + """Convert the digest to a human-readable string. + + Args: + prefix_method: If ``True`` (the default), prepend the method name + followed by a colon (e.g. ``"sha256:abcd1234"``). + hexdigits: Optional number of hex digits to include. + + Returns: + String representation of the hash. + """ + if prefix_method: + return f"{self.method}:{self.to_hex(hexdigits)}" + return self.to_hex(hexdigits) + + def __str__(self) -> str: + return self.to_string() + + @classmethod + def from_string(cls, hash_string: str) -> "ContentHash": + """Parse a ``"method:hex_digest"`` string into a ``ContentHash``. + + Args: + hash_string: A string in the format ``"method:hex_digest"``. + + Returns: + A new ``ContentHash`` instance. + """ + method, hex_digest = hash_string.split(":", 1) + return cls(method, bytes.fromhex(hex_digest)) + + def display_name(self, length: int = 8) -> str: + """Return a short, human-friendly label for this hash. + + Args: + length: Number of hex characters to include after the method + prefix. Defaults to 8. + + Returns: + A string like ``"arrow_v2.1:1a2b3c4d"``. + """ + return f"{self.method}:{self.to_hex(length)}" diff --git a/src/orcapod/utils/arrow_data_utils.py b/src/orcapod/utils/arrow_data_utils.py new file mode 100644 index 00000000..141fedf5 --- /dev/null +++ b/src/orcapod/utils/arrow_data_utils.py @@ -0,0 +1,280 @@ +# Collection of functions to work with Arrow table data that underlies streams and/or datagrams +from __future__ import annotations + +from collections.abc import Collection +from typing import TYPE_CHECKING + +from orcapod.system_constants import constants +from orcapod.utils.lazy_module import LazyModule + +if TYPE_CHECKING: + import pyarrow as pa +else: + pa = LazyModule("pyarrow") + + +def drop_columns_with_prefix( + table: pa.Table, + prefix: str | tuple[str, ...], + exclude_columns: Collection[str] = (), +) -> pa.Table: + """Drop columns with a specific prefix from an Arrow table.""" + columns_to_drop = [ + col + for col in table.column_names + if col.startswith(prefix) and col not in exclude_columns + ] + return table.drop(columns=columns_to_drop) + + +def drop_system_columns( + table: pa.Table, + system_column_prefix: tuple[str, ...] = ( + constants.META_PREFIX, + constants.DATAGRAM_PREFIX, + ), +) -> pa.Table: + return drop_columns_with_prefix(table, system_column_prefix) + + +def get_system_columns(table: pa.Table) -> pa.Table: + """Get system columns from an Arrow table.""" + return table.select( + [ + col + for col in table.column_names + if col.startswith(constants.SYSTEM_TAG_PREFIX) + ] + ) + + +def add_system_tag_columns( + table: pa.Table, + schema_hash: str, + source_ids: str | Collection[str], + record_ids: Collection[str], +) -> pa.Table: + """Add paired source_id and record_id system tag columns to an Arrow table.""" + if not table.column_names: + raise ValueError("Table is empty") + + # Normalize source_ids + if isinstance(source_ids, str): + source_ids = [source_ids] * table.num_rows + else: + source_ids = list(source_ids) + if len(source_ids) != table.num_rows: + raise ValueError( + "Length of source_ids must match number of rows in the table." + ) + + record_ids = list(record_ids) + if len(record_ids) != table.num_rows: + raise ValueError("Length of record_ids must match number of rows in the table.") + + source_id_col_name = f"{constants.SYSTEM_TAG_SOURCE_ID_PREFIX}{constants.BLOCK_SEPARATOR}{schema_hash}" + record_id_col_name = f"{constants.SYSTEM_TAG_RECORD_ID_PREFIX}{constants.BLOCK_SEPARATOR}{schema_hash}" + + source_id_array = pa.array(source_ids, type=pa.large_string()) + record_id_array = pa.array(record_ids, type=pa.large_string()) + + table = table.append_column(source_id_col_name, source_id_array) + table = table.append_column(record_id_col_name, record_id_array) + return table + + +def append_to_system_tags(table: pa.Table, value: str) -> pa.Table: + """Append a value to the system tags column in an Arrow table.""" + if not table.column_names: + raise ValueError("Table is empty") + + column_name_map = { + c: f"{c}{constants.BLOCK_SEPARATOR}{value}" + if c.startswith(constants.SYSTEM_TAG_PREFIX) + else c + for c in table.column_names + } + return table.rename_columns(column_name_map) + + +def _parse_system_tag_column( + col_name: str, +) -> tuple[str, str, str] | None: + """Parse a system tag column name into (field_type, provenance_path, position). + + For example: + _tag_source_id::abc123::def456:0 + → field_type="source_id", provenance_path="abc123::def456", position="0" + + _tag_record_id::abc123::def456:0 + → field_type="record_id", provenance_path="abc123::def456", position="0" + + Returns None if the column doesn't end with a :position suffix. + """ + # Strip the trailing :position + base, sep, position = col_name.rpartition(constants.FIELD_SEPARATOR) + if not sep or not position.isdigit(): + return None + + # Determine field type by checking known prefixes + prefix = constants.SYSTEM_TAG_PREFIX + if not base.startswith(prefix): + return None + + after_prefix = base[len(prefix) :] # e.g. "source_id::abc123::def456" + + # Extract field_type and provenance_path + # field_type is everything before the first BLOCK_SEPARATOR + field_type, block_sep, provenance_path = after_prefix.partition( + constants.BLOCK_SEPARATOR + ) + if not block_sep: + return None + + return field_type, provenance_path, position + + +def sort_system_tag_values(table: pa.Table) -> pa.Table: + """Sort paired system tag values for columns that share the same provenance path. + + System tag columns come in (source_id, record_id) pairs. Columns that differ + only by their canonical position (the final :N) represent streams with the same + pipeline_hash that were joined. For commutativity, paired (source_id, record_id) + tuples must be sorted together per row so that the result is independent of + input order. + + Algorithm: + 1. Parse each system tag column into (field_type, provenance_path, position) + 2. Group by provenance_path — source_id and record_id at the same path+position + are paired + 3. For each group with >1 position, sort per-row by (source_id, record_id) tuples + 4. Assign sorted values back to both columns at each position + """ + sys_tag_cols = [ + c for c in table.column_names if c.startswith(constants.SYSTEM_TAG_PREFIX) + ] + + if not sys_tag_cols: + return table + + # Parse all system tag columns and group by provenance_path + # groups[provenance_path][position] = {field_type: col_name} + groups: dict[str, dict[str, dict[str, str]]] = {} + for col in sys_tag_cols: + parsed = _parse_system_tag_column(col) + if parsed is None: + continue + field_type, provenance_path, position = parsed + groups.setdefault(provenance_path, {}).setdefault(position, {})[field_type] = ( + col + ) + + source_id_field = constants.SYSTEM_TAG_SOURCE_ID_PREFIX[ + len(constants.SYSTEM_TAG_PREFIX) : + ] + record_id_field = constants.SYSTEM_TAG_RECORD_ID_PREFIX[ + len(constants.SYSTEM_TAG_PREFIX) : + ] + + # For each provenance_path group with >1 position, sort paired tuples per row + for provenance_path, positions in groups.items(): + if len(positions) <= 1: + continue + + # Sort positions numerically + sorted_positions = sorted(positions.keys(), key=int) + + # Collect paired column names for each position + paired_cols: list[tuple[str | None, str | None]] = [] + for pos in sorted_positions: + field_map = positions[pos] + sid_col = field_map.get(source_id_field) + rid_col = field_map.get(record_id_field) + paired_cols.append((sid_col, rid_col)) + + # Get values for all paired columns + sid_values = [] + rid_values = [] + for sid_col, rid_col in paired_cols: + sid_values.append( + table.column(sid_col).to_pylist() + if sid_col + else [None] * table.num_rows + ) + rid_values.append( + table.column(rid_col).to_pylist() + if rid_col + else [None] * table.num_rows + ) + + # Sort per row by (source_id, record_id) tuples + n_positions = len(sorted_positions) + sorted_sid: list[list] = [[] for _ in range(n_positions)] + sorted_rid: list[list] = [[] for _ in range(n_positions)] + + for row_idx in range(table.num_rows): + row_tuples = [ + (sid_values[pos_idx][row_idx], rid_values[pos_idx][row_idx]) + for pos_idx in range(n_positions) + ] + row_tuples.sort() + for pos_idx, (sid_val, rid_val) in enumerate(row_tuples): + sorted_sid[pos_idx].append(sid_val) + sorted_rid[pos_idx].append(rid_val) + + # Replace columns with sorted values + for pos_idx, (sid_col, rid_col) in enumerate(paired_cols): + if sid_col: + orig_type = table.column(sid_col).type + tbl_idx = table.column_names.index(sid_col) + table = table.drop(sid_col) + table = table.add_column( + tbl_idx, + sid_col, + pa.array(sorted_sid[pos_idx], type=orig_type), + ) + if rid_col: + orig_type = table.column(rid_col).type + tbl_idx = table.column_names.index(rid_col) + table = table.drop(rid_col) + table = table.add_column( + tbl_idx, + rid_col, + pa.array(sorted_rid[pos_idx], type=orig_type), + ) + + return table + + +def add_source_info( + table: pa.Table, + source_info: str | Collection[str] | None, + exclude_prefixes: Collection[str] = ( + constants.META_PREFIX, + constants.DATAGRAM_PREFIX, + ), + exclude_columns: Collection[str] = (), +) -> pa.Table: + """Add source information to an Arrow table.""" + # Create a new column with the source information + if source_info is None or isinstance(source_info, str): + source_column = [source_info] * table.num_rows + elif isinstance(source_info, Collection): + if len(source_info) != table.num_rows: + raise ValueError( + "Length of source_info collection must match number of rows in the table." + ) + source_column = source_info + + # identify columns for which source columns should be created + + for col in table.column_names: + if col.startswith(tuple(exclude_prefixes)) or col in exclude_columns: + continue + source_column = pa.array( + [f"{source_val}::{col}" for source_val in source_column], + type=pa.large_string(), + ) + table = table.append_column(f"{constants.SOURCE_PREFIX}{col}", source_column) + + return table diff --git a/src/orcapod/utils/git_utils.py b/src/orcapod/utils/git_utils.py index 18b7caa4..d60f15dc 100644 --- a/src/orcapod/utils/git_utils.py +++ b/src/orcapod/utils/git_utils.py @@ -32,10 +32,9 @@ def get_git_info(path): commit_hash = repo.head.commit.hexsha short_hash = repo.head.commit.hexsha[:7] - # Check if repository is dirty + # Check if repository is dirty (staged or unstaged changes only; + # untracked_files=False avoids a slow git ls-files subprocess call) is_dirty = repo.is_dirty(untracked_files=False) - # Check if there are untracked files - has_untracked_files = len(repo.untracked_files) > 0 # Get current branch name try: @@ -44,22 +43,12 @@ def get_git_info(path): # Handle detached HEAD state branch_name = "HEAD (detached)" - # Get more detailed dirty status - dirty_details = { - "staged": len(repo.index.diff("HEAD")) > 0, - "unstaged": len(repo.index.diff(None)) > 0, - "untracked": len(repo.untracked_files) > 0, - } - return { "is_repo": True, "commit_hash": commit_hash, "short_hash": short_hash, "is_dirty": is_dirty, - "has_untracked_files": has_untracked_files, "branch": branch_name, - "dirty_details": dirty_details, - "untracked_files": repo.untracked_files, "repo_root": repo.working_dir, } diff --git a/src/orcapod/utils/object_spec.py b/src/orcapod/utils/object_spec.py index 8ecfd0ac..652170f5 100644 --- a/src/orcapod/utils/object_spec.py +++ b/src/orcapod/utils/object_spec.py @@ -1,6 +1,5 @@ import importlib from typing import Any -from weakref import ref def parse_objectspec( @@ -21,6 +20,8 @@ def parse_objectspec( return ref_lut[ref_key] else: raise ValueError(f"Unknown reference: {ref_key}") + elif "_type" in obj_spec: + return _resolve_type_from_spec(obj_spec) else: # Recursively process dict return { @@ -35,6 +36,19 @@ def parse_objectspec( return obj_spec +def _resolve_type_from_spec(spec: dict[str, Any]) -> type: + """Resolve a ``{"_type": "module.ClassName"}`` spec to the actual Python type. + + Bare names without a dot (e.g. ``"bytes"``) are resolved from ``builtins``. + """ + type_str: str = spec["_type"] + if "." not in type_str: + type_str = f"builtins.{type_str}" + module_name, _, attr_name = type_str.rpartition(".") + module = importlib.import_module(module_name) + return getattr(module, attr_name) + + def _create_instance_from_spec( spec: dict[str, Any], ref_lut: dict[str, Any], validate: bool ) -> Any: diff --git a/src/orcapod/core/polars_data_utils.py b/src/orcapod/utils/polars_data_utils.py similarity index 97% rename from src/orcapod/core/polars_data_utils.py rename to src/orcapod/utils/polars_data_utils.py index 7757a1d1..fbd4f6db 100644 --- a/src/orcapod/core/polars_data_utils.py +++ b/src/orcapod/utils/polars_data_utils.py @@ -1,8 +1,9 @@ # Collection of functions to work with Arrow table data that underlies streams and/or datagrams -from orcapod.utils.lazy_module import LazyModule -from typing import TYPE_CHECKING -from orcapod.core.system_constants import constants from collections.abc import Collection +from typing import TYPE_CHECKING + +from orcapod.system_constants import constants +from orcapod.utils.lazy_module import LazyModule if TYPE_CHECKING: import polars as pl @@ -80,7 +81,7 @@ def append_to_system_tags(df: "pl.DataFrame", value: str) -> "pl.DataFrame": df.rename column_name_map = { - c: f"{c}:{value}" + c: f"{c}{constants.BLOCK_SEPARATOR}{value}" for c in df.columns if c.startswith(constants.SYSTEM_TAG_PREFIX) } diff --git a/src/orcapod/utils/types_utils.py b/src/orcapod/utils/schema_utils.py similarity index 69% rename from src/orcapod/utils/types_utils.py rename to src/orcapod/utils/schema_utils.py index 5c25d031..227470c1 100644 --- a/src/orcapod/utils/types_utils.py +++ b/src/orcapod/utils/schema_utils.py @@ -1,26 +1,27 @@ -# Library of functions for working with TypeSpecs and for extracting TypeSpecs from a function's signature +# Library of functions for working with Schemas and for extracting Schemas from a function's signature -from collections.abc import Callable, Collection, Sequence, Mapping -from typing import get_origin, get_args, Any -from orcapod.types import PythonSchema, PythonSchemaLike import inspect import logging import sys +from collections.abc import Callable, Collection, Mapping, Sequence +from typing import Any, get_args, get_origin, get_type_hints + +from orcapod.types import Schema, SchemaLike logger = logging.getLogger(__name__) -def verify_against_typespec(packet: dict, typespec: PythonSchema) -> bool: - """Verify that the dictionary's types match the expected types in the typespec.""" +def verify_packet_schema(packet: dict, schema: SchemaLike) -> bool: + """Verify that the dictionary's types match the expected types in the schema.""" from beartype.door import is_bearable - # verify that packet contains no keys not in typespec - if set(packet.keys()) - set(typespec.keys()): + # verify that packet contains no keys not in schema + if set(packet.keys()) - set(schema.keys()): logger.warning( - f"Packet contains keys not in typespec: {set(packet.keys()) - set(typespec.keys())}. " + f"PacketProtocol contains keys not in schema: {set(packet.keys()) - set(schema.keys())}. " ) return False - for key, type_info in typespec.items(): + for key, type_info in schema.items(): if key not in packet: logger.warning( f"Key '{key}' not found in packet. Assuming None but this behavior may change in the future" @@ -35,8 +36,8 @@ def verify_against_typespec(packet: dict, typespec: PythonSchema) -> bool: # TODO: is_subhint does not handle invariance properly # so when working with mutable types, we have to make sure to perform deep copy -def check_typespec_compatibility( - incoming_types: PythonSchema, receiving_types: PythonSchema +def check_schema_compatibility( + incoming_types: SchemaLike, receiving_types: Schema ) -> bool: from beartype.door import is_subhint @@ -49,15 +50,22 @@ def check_typespec_compatibility( f"Type mismatch for key '{key}': expected {receiving_types[key]}, got {type_info}." ) return False + + # Every receiving key must be present in incoming OR be optional (has a default) + for key in receiving_types: + if key not in incoming_types and key not in receiving_types.optional_fields: + logger.warning(f"Required key '{key}' missing from incoming types.") + return False + return True -def extract_function_typespecs( +def extract_function_schemas( func: Callable, output_keys: Collection[str], - input_typespec: PythonSchemaLike | None = None, - output_typespec: PythonSchemaLike | Sequence[type] | None = None, -) -> tuple[PythonSchema, PythonSchema]: + input_typespec: SchemaLike | None = None, + output_typespec: SchemaLike | Sequence[type] | None = None, +) -> tuple[Schema, Schema]: """ Extract input and output data types from a function signature. @@ -136,7 +144,7 @@ def extract_function_typespecs( >>> output_types {'count': , 'total': , 'repr': } """ - verified_output_types: PythonSchema = {} + verified_output_types: Schema = {} if output_typespec is not None: if isinstance(output_typespec, dict): verified_output_types = output_typespec @@ -148,23 +156,36 @@ def extract_function_typespecs( ) verified_output_types = {k: v for k, v in zip(output_keys, output_typespec)} + # Use get_type_hints to resolve annotations that may be stored as strings + # (e.g. when the defining module uses `from __future__ import annotations`). + # Fall back to an empty dict if hints cannot be resolved (e.g. for built-ins). + try: + resolved_hints = get_type_hints(func) + except Exception: + resolved_hints = {} + signature = inspect.signature(func) - param_info: PythonSchema = {} + param_info: Schema = {} + optional_params: set[str] = set() for name, param in signature.parameters.items(): if input_typespec and name in input_typespec: param_info[name] = input_typespec[name] + elif name in resolved_hints: + param_info[name] = resolved_hints[name] + elif param.annotation is not inspect.Parameter.empty: + # annotation is already a live type (no __future__ postponement) + param_info[name] = param.annotation else: - # check if the parameter has annotation - if param.annotation is not inspect.Signature.empty: - param_info[name] = param.annotation - else: - raise ValueError( - f"Parameter '{name}' has no type annotation and is not specified in input_types." - ) + raise ValueError( + f"Parameter '{name}' has no type annotation and is not specified in input_types." + ) + if param.default is not inspect.Parameter.empty: + optional_params.add(name) - return_annot = signature.return_annotation - inferred_output_types: PythonSchema = {} + # get_type_hints stores the return annotation under the key 'return' + return_annot = resolved_hints.get("return", signature.return_annotation) + inferred_output_types: Schema = {} if return_annot is not inspect.Signature.empty and return_annot is not None: output_item_types = [] if len(output_keys) == 0: @@ -174,9 +195,12 @@ def extract_function_typespecs( elif len(output_keys) == 1: # if only one return key, the entire annotation is inferred as the return type output_item_types = [return_annot] - elif (get_origin(return_annot) or return_annot) in (tuple, list, Sequence): + elif get_origin(return_annot) in (tuple, list) or ( + isinstance(get_origin(return_annot), type) + and issubclass(get_origin(return_annot), Sequence) + ): if get_origin(return_annot) is None: - # right type was specified but did not specified the type of items + # right type was specified but did not specify the type of items raise ValueError( f"Function return type annotation {return_annot} is a Sequence type but does not specify item types." ) @@ -211,23 +235,42 @@ def extract_function_typespecs( raise ValueError( f"Type for return item '{key}' is not specified in output_types and has no type annotation in function signature." ) - return param_info, inferred_output_types + # Reject bare container types (must have type parameters) + _BARE_CONTAINER_TYPES = {dict, list, set, tuple} + _BARE_CONTAINER_EXAMPLES = { + dict: "dict[str, int]", + list: "list[int]", + set: "set[int]", + tuple: "tuple[int, ...]", + } + for name, type_annot in {**param_info, **inferred_output_types}.items(): + if type_annot in _BARE_CONTAINER_TYPES: + example = _BARE_CONTAINER_EXAMPLES[type_annot] + raise ValueError( + f"Type annotation for '{name}' is bare {type_annot.__name__} " + f"without type parameters. Use e.g. {example} instead." + ) + + return Schema(param_info, optional_fields=optional_params), Schema( + inferred_output_types + ) -def get_typespec_from_dict( - data: Mapping, typespec: PythonSchema | None = None, default=str -) -> PythonSchema: +def infer_schema_from_dict( + data: Mapping, schema: SchemaLike | None = None, default=str +) -> Schema: """ - Returns a TypeSpec for the given dictionary. - The TypeSpec is a mapping from field name to Python type. If typespec is provided, then - it is used as a base when inferring types for the fields in dict + Returns a Schema for the given dictionary by inferring types from values. + If schema is provided, it is used as a base when inferring types for the fields in dict. """ - if typespec is None: - typespec = {} - return { - key: typespec.get(key, type(value) if value is not None else default) - for key, value in data.items() - } + if schema is None: + schema = {} + return Schema( + { + key: schema.get(key, type(value) if value is not None else default) + for key, value in data.items() + } + ) # def get_compatible_type(type1: Any, type2: Any) -> Any: @@ -305,62 +348,35 @@ def get_compatible_type(type1: Any, type2: Any) -> Any: return _GenericAlias(origin1, tuple(compatible_args)) -def union_typespecs(*typespecs: PythonSchema) -> PythonSchema: - # Merge the two TypeSpecs but raise an error if conflicts in types are found - merged = dict(typespecs[0]) - for typespec in typespecs[1:]: - for key, right_type in typespec.items(): +def union_schemas(*schemas: SchemaLike) -> Schema: + """Merge multiple schemas, raising an error if type conflicts are found.""" + merged = dict(schemas[0]) + for schema in schemas[1:]: + for key, right_type in schema.items(): merged[key] = ( get_compatible_type(merged[key], right_type) if key in merged else right_type ) - return merged + return Schema(merged) -def intersection_typespecs(*typespecs: PythonSchema) -> PythonSchema: +def intersection_schemas(*schemas: SchemaLike) -> Schema: """ - Returns the intersection of all TypeSpecs, only returning keys that are present in all typespecs. - If a key is present in both TypeSpecs, the type must be the same. + Returns the intersection of all schemas, only returning keys that are present in all schemas. + If a key is present in multiple schemas, the types must be compatible. """ + common_keys = set(schemas[0].keys()) + for schema in schemas[1:]: + common_keys.intersection_update(schema.keys()) - # Find common keys and ensure types match - - common_keys = set(typespecs[0].keys()) - for typespec in typespecs[1:]: - common_keys.intersection_update(typespec.keys()) - - intersection = {k: typespecs[0][k] for k in common_keys} - for typespec in typespecs[1:]: + intersection = {k: schemas[0][k] for k in common_keys} + for schema in schemas[1:]: for key in common_keys: try: - intersection[key] = get_compatible_type( - intersection[key], typespec[key] - ) + intersection[key] = get_compatible_type(intersection[key], schema[key]) except TypeError: - # If types are not compatible, raise an error raise TypeError( - f"Type conflict for key '{key}': {intersection[key]} vs {typespec[key]}" + f"Type conflict for key '{key}': {intersection[key]} vs {schema[key]}" ) - return intersection - - -# def intersection_typespecs(left: TypeSpec, right: TypeSpec) -> TypeSpec: -# """ -# Returns the intersection of two TypeSpecs, only returning keys that are present in both. -# If a key is present in both TypeSpecs, the type must be the same. -# """ - -# # Find common keys and ensure types match -# common_keys = set(left.keys()).intersection(set(right.keys())) -# intersection = {} -# for key in common_keys: -# try: -# intersection[key] = get_compatible_type(left[key], right[key]) -# except TypeError: -# # If types are not compatible, raise an error -# raise TypeError( -# f"Type conflict for key '{key}': {left[key]} vs {right[key]}" -# ) - -# return intersection + return Schema(intersection) diff --git a/src/sample.py b/src/sample.py new file mode 100644 index 00000000..3c10555b --- /dev/null +++ b/src/sample.py @@ -0,0 +1,7 @@ +from collections.abc import Mapping + + +def test() -> Mapping[str, type] | int: ... + + +x = test() diff --git a/superpowers/plans/2026-03-14-node-authority-plan.md b/superpowers/plans/2026-03-14-node-authority-plan.md new file mode 100644 index 00000000..52c6afbd --- /dev/null +++ b/superpowers/plans/2026-03-14-node-authority-plan.md @@ -0,0 +1,372 @@ +# Node Authority Implementation Plan + +> **For agentic workers:** REQUIRED: Use superpowers:subagent-driven-development or superpowers:executing-plans. + +**Goal:** Make nodes self-validating and self-persisting. Remove `store_result`, add schema validation, add `process()` to FunctionNode and OperatorNode. + +**Architecture:** Nodes validate input schemas, compute, persist, and cache internally. The orchestrator calls `process_packet` or `process` — nodes handle everything else. Pod vs Node distinction: pods return lazy streams, nodes return materialized lists. + +**Spec:** `superpowers/specs/2026-03-14-node-authority-design.md` + +--- + +## File Map + +### Modified Files + +| File | Changes | +|------|---------| +| `src/orcapod/protocols/node_protocols.py` | Remove `store_result`, remove `operator` property, add `process()` | +| `src/orcapod/core/nodes/function_node.py` | Revert process_packet to bundled, add schema validation, add `process()`, remove `store_result` | +| `src/orcapod/core/nodes/operator_node.py` | Add `process()` with validation, remove `store_result`, remove `operator` property | +| `src/orcapod/core/nodes/source_node.py` | Remove `store_result` | +| `src/orcapod/pipeline/sync_orchestrator.py` | Remove `store_result` calls, use `node.process()` for operators | +| `tests/test_core/nodes/test_node_store_result.py` | Rewrite as test_node_process.py | +| `tests/test_pipeline/test_sync_orchestrator.py` | Update if needed | + +--- + +### Task 1: FunctionNode — revert process_packet, add schema validation, add process() + +**Files:** +- Modify: `src/orcapod/core/nodes/function_node.py` +- Rename+Rewrite: `tests/test_core/nodes/test_node_store_result.py` → `tests/test_core/nodes/test_node_process.py` + +- [ ] **Step 1: Write tests for reverted process_packet (bundled behavior)** + +Create `tests/test_core/nodes/test_node_process.py` with FunctionNode tests: + +```python +"""Tests for node process methods (schema validation, persistence, caching).""" +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.core.function_pod import FunctionPod +from orcapod.core.nodes import FunctionNode +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.core.sources import ArrowTableSource +from orcapod.databases import InMemoryArrowDatabase + + +def double_value(value: int) -> int: + return value * 2 + + +@pytest.fixture +def function_node_with_db(): + table = pa.table({ + "key": pa.array(["a", "b"], type=pa.large_string()), + "value": pa.array([1, 2], type=pa.int64()), + }) + src = ArrowTableSource(table, tag_columns=["key"]) + pf = PythonPacketFunction(double_value, output_keys="result") + pod = FunctionPod(pf) + pipeline_db = InMemoryArrowDatabase() + result_db = InMemoryArrowDatabase() + node = FunctionNode( + pod, src, + pipeline_database=pipeline_db, + result_database=result_db, + ) + return node, pipeline_db, result_db + + +@pytest.fixture +def function_node_no_db(): + table = pa.table({ + "key": pa.array(["a", "b"], type=pa.large_string()), + "value": pa.array([1, 2], type=pa.int64()), + }) + src = ArrowTableSource(table, tag_columns=["key"]) + pf = PythonPacketFunction(double_value, output_keys="result") + pod = FunctionPod(pf) + return FunctionNode(pod, src) + + +class TestFunctionNodeProcessPacket: + def test_process_packet_returns_correct_result(self, function_node_no_db): + node = function_node_no_db + packets = list(node._input_stream.iter_packets()) + tag, packet = packets[0] + tag_out, result = node.process_packet(tag, packet) + assert result is not None + assert result.as_dict()["result"] == 2 + + def test_process_packet_writes_pipeline_record(self, function_node_with_db): + """process_packet should write pipeline provenance record.""" + node, pipeline_db, _ = function_node_with_db + packets = list(node._input_stream.iter_packets()) + tag, packet = packets[0] + tag_out, result = node.process_packet(tag, packet) + assert result is not None + + records = pipeline_db.get_all_records(node.pipeline_path) + assert records is not None + assert records.num_rows == 1 + + def test_process_packet_writes_to_result_db(self, function_node_with_db): + """process_packet should memoize via CachedFunctionPod.""" + node, _, _ = function_node_with_db + packets = list(node._input_stream.iter_packets()) + tag, packet = packets[0] + tag_out, result = node.process_packet(tag, packet) + assert result is not None + + cached = node._cached_function_pod.get_all_cached_outputs() + assert cached is not None + assert cached.num_rows == 1 + + def test_process_packet_caches_internally(self, function_node_with_db): + """process_packet should populate _cached_output_packets.""" + node, _, _ = function_node_with_db + packets = list(node._input_stream.iter_packets()) + tag, packet = packets[0] + node.process_packet(tag, packet) + assert len(node._cached_output_packets) == 1 + + def test_process_packet_validates_schema(self, function_node_with_db): + """process_packet should reject packets with wrong schema.""" + node, _, _ = function_node_with_db + # Create a packet with wrong schema + wrong_table = pa.table({ + "wrong_key": pa.array(["x"], type=pa.large_string()), + "wrong_col": pa.array([99], type=pa.int64()), + }) + wrong_src = ArrowTableSource(wrong_table, tag_columns=["wrong_key"]) + wrong_packets = list(wrong_src.iter_packets()) + wrong_tag, wrong_pkt = wrong_packets[0] + + with pytest.raises(Exception): # InputValidationError or ValueError + node.process_packet(wrong_tag, wrong_pkt) + + +class TestFunctionNodeProcess: + def test_process_returns_materialized_results(self, function_node_with_db): + node, _, _ = function_node_with_db + results = node.process(node._input_stream) + assert isinstance(results, list) + assert len(results) == 2 + values = sorted([pkt.as_dict()["result"] for _, pkt in results]) + assert values == [2, 4] + + def test_process_writes_pipeline_records(self, function_node_with_db): + node, pipeline_db, _ = function_node_with_db + node.process(node._input_stream) + records = pipeline_db.get_all_records(node.pipeline_path) + assert records is not None + assert records.num_rows == 2 + + def test_process_caches_internally(self, function_node_with_db): + node, _, _ = function_node_with_db + node.process(node._input_stream) + assert len(node._cached_output_packets) == 2 + + def test_process_validates_stream_schema(self, function_node_with_db): + node, _, _ = function_node_with_db + wrong_table = pa.table({ + "wrong_key": pa.array(["x"], type=pa.large_string()), + "wrong_col": pa.array([99], type=pa.int64()), + }) + wrong_stream = ArrowTableSource(wrong_table, tag_columns=["wrong_key"]) + with pytest.raises(Exception): + node.process(wrong_stream) +``` + +- [ ] **Step 2: Run tests, verify failures** (process_packet no longer writes pipeline records due to earlier refactor; process() doesn't exist yet; schema validation doesn't exist) + +Run: `uv run pytest tests/test_core/nodes/test_node_process.py -v` + +- [ ] **Step 3: Implement changes on FunctionNode** + +In `src/orcapod/core/nodes/function_node.py`: + +a. Add `_validate_input_schema(tag, packet)` method that checks tag keys, packet keys, and system tag column names against `self._input_stream.output_schema(columns={"system_tags": True})` / `self._input_stream.keys(all_info=True)`. + +b. Add `_validate_stream_schema(input_stream)` method that validates the stream's output_schema against expected. + +c. Refactor: rename `_process_and_store_packet` to `_process_packet_internal` — this is the core compute+persist+cache method (no schema validation). + +d. Update `process_packet` to call `_validate_input_schema` then `_process_packet_internal`. + +e. Add `process(input_stream)` that calls `_validate_stream_schema` once, then iterates calling `_process_packet_internal` per packet. + +f. Remove `store_result()` method. + +g. Update `iter_packets()`, `_iter_packets_sequential()`, `_iter_packets_concurrent()` to call `_process_packet_internal` (they already call `_process_and_store_packet` — just rename). + +- [ ] **Step 4: Run tests** + +Run: `uv run pytest tests/test_core/nodes/test_node_process.py -v` +Run: `uv run pytest tests/test_core/function_pod/ -v` (regression) + +- [ ] **Step 5: Commit** + +``` +git commit -m "refactor(function-node): self-validating process_packet, add process(), remove store_result" +``` + +### Task 2: OperatorNode — add process(), remove store_result and operator property + +**Files:** +- Modify: `src/orcapod/core/nodes/operator_node.py` +- Modify: `tests/test_core/nodes/test_node_process.py` (append) + +- [ ] **Step 1: Write tests for OperatorNode.process()** + +Append to `tests/test_core/nodes/test_node_process.py`: + +```python +from orcapod.core.nodes import OperatorNode +from orcapod.core.operators import SelectPacketColumns +from orcapod.core.operators.join import Join +from orcapod.types import CacheMode + + +@pytest.fixture +def operator_with_db(): + table = pa.table({ + "key": pa.array(["a", "b"], type=pa.large_string()), + "value": pa.array([10, 20], type=pa.int64()), + }) + src = ArrowTableSource(table, tag_columns=["key"]) + op = SelectPacketColumns(columns=["value"]) + db = InMemoryArrowDatabase() + node = OperatorNode( + op, input_streams=[src], + pipeline_database=db, + cache_mode=CacheMode.LOG, + ) + return node, db, src + + +@pytest.fixture +def operator_no_db(): + table = pa.table({ + "key": pa.array(["a", "b"], type=pa.large_string()), + "value": pa.array([10, 20], type=pa.int64()), + }) + src = ArrowTableSource(table, tag_columns=["key"]) + op = SelectPacketColumns(columns=["value"]) + node = OperatorNode(op, input_streams=[src]) + return node, src + + +class TestOperatorNodeProcess: + def test_process_returns_materialized_results(self, operator_no_db): + node, src = operator_no_db + results = node.process(src) + assert isinstance(results, list) + assert len(results) == 2 + + def test_process_writes_to_db_in_log_mode(self, operator_with_db): + node, db, src = operator_with_db + node.process(src) + records = node.get_all_records() + assert records is not None + assert records.num_rows == 2 + + def test_process_caches_internally(self, operator_no_db): + node, src = operator_no_db + node.process(src) + cached = list(node.iter_packets()) + assert len(cached) == 2 + + def test_process_validates_stream_schema(self, operator_no_db): + node, _ = operator_no_db + wrong_table = pa.table({ + "wrong": pa.array(["x"], type=pa.large_string()), + "bad": pa.array([1], type=pa.int64()), + }) + wrong_stream = ArrowTableSource(wrong_table, tag_columns=["wrong"]) + with pytest.raises(Exception): + node.process(wrong_stream) + + def test_process_noop_db_in_off_mode(self, operator_no_db): + node, src = operator_no_db + results = node.process(src) + assert len(results) == 2 + # No DB, so no records + assert node.get_all_records() is None +``` + +- [ ] **Step 2: Run tests, verify failures** + +- [ ] **Step 3: Implement OperatorNode.process()** + +In `src/orcapod/core/nodes/operator_node.py`: + +a. Add `_validate_input_schemas(*input_streams)` — validates each stream's schema against the corresponding upstream's expected schema. + +b. Add `process(*input_streams) → list[tuple[Tag, Packet]]`: + 1. Validate schemas + 2. `result_stream = self._operator.process(*input_streams)` + 3. Materialize: `output = list(result_stream.iter_packets())` + 4. Cache: `self._cached_output_stream = result_stream` (or re-materialize) + 5. DB persist if LOG mode: `self._store_output_stream(result_stream)` + 6. Return output + +c. Remove `store_result()` method. +d. Remove `operator` property (orchestrator no longer needs it). + +- [ ] **Step 4: Run tests** + +Run: `uv run pytest tests/test_core/nodes/test_node_process.py -v` +Run: `uv run pytest tests/test_core/operators/ -v` (regression) + +- [ ] **Step 5: Commit** + +``` +git commit -m "refactor(operator-node): add self-validating process(), remove store_result and operator property" +``` + +### Task 3: SourceNode — remove store_result + +**Files:** +- Modify: `src/orcapod/core/nodes/source_node.py` + +- [ ] **Step 1: Remove `store_result()` from SourceNode** +- [ ] **Step 2: Run tests, verify pass** +- [ ] **Step 3: Commit** + +``` +git commit -m "refactor(source-node): remove store_result" +``` + +### Task 4: Update protocols and orchestrator + +**Files:** +- Modify: `src/orcapod/protocols/node_protocols.py` +- Modify: `src/orcapod/pipeline/sync_orchestrator.py` + +- [ ] **Step 1: Update protocols** + +Remove `store_result` from all protocols. Remove `operator` property from `OperatorNodeProtocol`. Add `process` to `FunctionNodeProtocol` and `OperatorNodeProtocol`. + +- [ ] **Step 2: Update SyncPipelineOrchestrator** + +- `_execute_source`: remove `node.store_result(output)` call +- `_execute_function`: remove `node.store_result(tag, packet, result)` call (process_packet handles it) +- `_execute_operator`: replace `node.operator.process(*input_streams)` + `node.store_result(output)` with `output = node.process(*input_streams)` +- Remove `_materialize_as_stream` helper (no longer needed — node.process accepts streams, and the orchestrator still needs to wrap buffers as streams... actually keep this, the orchestrator needs it to create streams from buffers before passing to node.process) + +Wait — the orchestrator still needs `_materialize_as_stream` because it holds buffers (lists of tag/packet) and needs to wrap them as `StreamProtocol` before calling `node.process(*input_streams)`. Keep the helper. + +- [ ] **Step 3: Delete old test file and clean up** + +- Delete `tests/test_core/nodes/test_node_store_result.py` +- Run full suite: `uv run pytest tests/ --tb=short` (timeout=300000) + +- [ ] **Step 4: Commit** + +``` +git commit -m "refactor(orchestrator): update protocols and orchestrator for node authority pattern" +``` + +### Task 5: Final verification + +- [ ] Run full test suite: `uv run pytest tests/ -v --tb=short` +- [ ] Verify all sync orchestrator tests pass +- [ ] Verify all parity tests pass +- [ ] Verify backward compat: `node.run()` and `iter_packets()` still work diff --git a/superpowers/plans/2026-03-14-remove-populate-cache-plan.md b/superpowers/plans/2026-03-14-remove-populate-cache-plan.md new file mode 100644 index 00000000..d912de73 --- /dev/null +++ b/superpowers/plans/2026-03-14-remove-populate-cache-plan.md @@ -0,0 +1,132 @@ +# Remove populate_cache — Self-Caching Nodes Plan + +> **For agentic workers:** REQUIRED: Use superpowers:subagent-driven-development or superpowers:executing-plans. + +**Goal:** Remove `populate_cache` from all nodes; make `store_result` and `get_cached_results` self-caching. + +**Architecture:** Nodes build in-memory caches as a side effect of orchestrator calls. No external cache population step needed. + +**Spec:** `superpowers/specs/2026-03-14-remove-populate-cache-design.md` + +--- + +### Task 1: Update SourceNode — self-caching store_result, remove populate_cache + +**Files:** +- Modify: `src/orcapod/core/nodes/source_node.py` +- Modify: `tests/test_core/nodes/test_node_store_result.py` +- Modify: `tests/test_core/nodes/test_node_populate_cache.py` + +- [ ] Update `TestSourceNodeStoreResult` — add test verifying `store_result` populates cache: + +```python +def test_store_result_populates_internal_cache(self, source_and_node): + _, node = source_and_node + packets = list(node.iter_packets()) + node.store_result(packets) + # iter_packets should now return from cache + cached = list(node.iter_packets()) + assert len(cached) == len(packets) +``` + +- [ ] Run test, verify it fails (store_result is currently a no-op) +- [ ] Update `SourceNode.store_result` to populate `self._cached_results` +- [ ] Remove `SourceNode.populate_cache` method +- [ ] Remove `TestSourceNodePopulateCache` from test_node_populate_cache.py +- [ ] Remove `populate_cache` from `SourceNodeProtocol` in node_protocols.py +- [ ] Run all tests, verify pass +- [ ] Commit + +### Task 2: Update OperatorNode — self-caching store_result, remove populate_cache + +**Files:** +- Modify: `src/orcapod/core/nodes/operator_node.py` +- Modify: `tests/test_core/nodes/test_node_store_result.py` +- Modify: `tests/test_core/nodes/test_node_populate_cache.py` + +- [ ] Update `TestOperatorNodeStoreResult` — add test verifying `store_result` populates cache: + +```python +def test_store_result_populates_internal_cache(self, operator_with_db): + node, _ = operator_with_db + stream = node._operator.process(*node._input_streams) + output = list(stream.iter_packets()) + node.store_result(output) + # iter_packets should work from cache + cached = list(node.iter_packets()) + assert len(cached) == 2 +``` + +- [ ] Run test, verify it fails +- [ ] Update `OperatorNode.store_result` to also set `_cached_output_stream` +- [ ] Remove `OperatorNode.populate_cache` method +- [ ] Remove `TestOperatorNodePopulateCache` from test_node_populate_cache.py +- [ ] Remove `populate_cache` from `OperatorNodeProtocol` in node_protocols.py +- [ ] Run all tests, verify pass +- [ ] Commit + +### Task 3: Update FunctionNode — self-caching store_result + get_cached_results, remove populate_cache + +**Files:** +- Modify: `src/orcapod/core/nodes/function_node.py` +- Modify: `tests/test_core/nodes/test_node_store_result.py` +- Modify: `tests/test_core/nodes/test_function_node_get_cached.py` +- Modify: `tests/test_core/nodes/test_node_populate_cache.py` + +- [ ] Add test verifying `store_result` populates `_cached_output_packets`: + +```python +def test_store_result_populates_internal_cache(self, function_node_with_db): + node, _, _ = function_node_with_db + packets = list(node._input_stream.iter_packets()) + tag, packet = packets[0] + tag_out, result = node.process_packet(tag, packet) + node.store_result(tag, packet, result) + assert len(node._cached_output_packets) == 1 +``` + +- [ ] Add test verifying `get_cached_results` populates `_cached_output_packets`: + +```python +def test_get_cached_results_populates_internal_cache(self, function_node_with_db): + node = function_node_with_db + packets = list(node._input_stream.iter_packets()) + # Process and store all packets + entry_ids = [] + for tag, packet in packets: + tag_out, result = node.process_packet(tag, packet) + node.store_result(tag, packet, result) + entry_ids.append(node.compute_pipeline_entry_id(tag, packet)) + # Clear internal cache + node._cached_output_packets.clear() + # get_cached_results should repopulate it + node.get_cached_results(entry_ids) + assert len(node._cached_output_packets) == 2 +``` + +- [ ] Run tests, verify they fail +- [ ] Update `FunctionNode.store_result` to append to `_cached_output_packets` +- [ ] Update `FunctionNode.get_cached_results` to populate `_cached_output_packets` +- [ ] Remove `FunctionNode.populate_cache` method +- [ ] Remove `TestFunctionNodePopulateCache` from test_node_populate_cache.py +- [ ] Remove `populate_cache` from `FunctionNodeProtocol` in node_protocols.py +- [ ] Run all tests, verify pass +- [ ] Commit + +### Task 4: Remove Pipeline._apply_results and update Pipeline.run() + +**Files:** +- Modify: `src/orcapod/pipeline/graph.py` +- Modify: `tests/test_pipeline/test_sync_orchestrator.py` + +- [ ] Remove `_apply_results` method from Pipeline +- [ ] Update `Pipeline.run()` — remove `_apply_results` calls +- [ ] Verify `test_run_populates_node_caches` still passes (nodes self-cache now) +- [ ] Run all tests +- [ ] Commit + +### Task 5: Clean up test_node_populate_cache.py + +- [ ] If test_node_populate_cache.py is now empty, delete it +- [ ] Run all tests +- [ ] Commit diff --git a/superpowers/plans/2026-03-14-sync-orchestrator-plan.md b/superpowers/plans/2026-03-14-sync-orchestrator-plan.md new file mode 100644 index 00000000..57a889d1 --- /dev/null +++ b/superpowers/plans/2026-03-14-sync-orchestrator-plan.md @@ -0,0 +1,2084 @@ +# Sync Pipeline Orchestrator Implementation Plan + +> **For agentic workers:** REQUIRED: Use superpowers:subagent-driven-development (if subagents available) or superpowers:executing-plans to implement this plan. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Implement a synchronous pipeline orchestrator with per-packet observability, uniform compute/store node protocols, and Pipeline.run() integration. + +**Architecture:** The orchestrator walks a compiled node graph topologically, maintaining materialized buffers between nodes. Each node type exposes a protocol (SourceNodeProtocol, FunctionNodeProtocol, OperatorNodeProtocol) with uniform `store_result` / `populate_cache` methods. TypeGuard functions provide runtime dispatch with static type narrowing. An ExecutionObserver protocol enables per-packet hooks via dependency injection. + +**Tech Stack:** Python 3.12+, PyArrow, NetworkX, Polars (for DB joins), pytest + pytest-asyncio + +**Spec:** `superpowers/specs/2026-03-14-sync-orchestrator-design.md` + +--- + +## File Map + +### New Files + +| File | Responsibility | +|------|---------------| +| `src/orcapod/protocols/node_protocols.py` | SourceNodeProtocol, FunctionNodeProtocol, OperatorNodeProtocol, TypeGuard dispatch functions | +| `src/orcapod/pipeline/observer.py` | ExecutionObserver protocol, NoOpObserver default | +| `src/orcapod/pipeline/result.py` | OrchestratorResult dataclass | +| `src/orcapod/pipeline/sync_orchestrator.py` | SyncPipelineOrchestrator | +| `tests/test_pipeline/test_sync_orchestrator.py` | Orchestrator integration tests | +| `tests/test_pipeline/test_observer.py` | Observer hook tests | +| `tests/test_core/nodes/test_node_store_result.py` | store_result tests for all node types | +| `tests/test_core/nodes/test_node_populate_cache.py` | populate_cache tests for all node types | +| `tests/test_core/nodes/test_function_node_get_cached.py` | get_cached_results tests | + +### Modified Files + +| File | Changes | +|------|---------| +| `src/orcapod/core/nodes/source_node.py` | Add `store_result()`, `populate_cache()`, modify `iter_packets()` for cache check | +| `src/orcapod/core/nodes/operator_node.py` | Add `get_cached_output()`, `store_result()`, `populate_cache()` | +| `src/orcapod/core/nodes/function_node.py` | Extract pipeline record from `process_packet()` into `store_result()`, add `get_cached_results()`, `populate_cache()` | +| `src/orcapod/core/nodes/__init__.py` | Re-export GraphNode (no change needed, already correct) | +| `src/orcapod/pipeline/graph.py` | Update `Pipeline.run()`, add `_apply_results()` | +| `src/orcapod/pipeline/__init__.py` | Export SyncPipelineOrchestrator, update AsyncPipelineOrchestrator import | +| `src/orcapod/pipeline/orchestrator.py` | Rename to `async_orchestrator.py` | + +--- + +## Chunk 1: Protocols, Observer, and Result Types + +### Task 1: Node Protocols and TypeGuard Dispatch + +**Files:** +- Create: `src/orcapod/protocols/node_protocols.py` +- Create: `tests/test_protocols/test_node_protocols.py` + +- [ ] **Step 1: Write tests for TypeGuard dispatch** + +```python +# tests/test_protocols/test_node_protocols.py +"""Tests for node protocol TypeGuard dispatch functions.""" +from __future__ import annotations + +import pytest + +from orcapod.core.nodes import FunctionNode, OperatorNode, SourceNode +from orcapod.protocols.node_protocols import ( + is_function_node, + is_operator_node, + is_source_node, +) + + +class TestTypeGuardDispatch: + """TypeGuard functions correctly narrow node types.""" + + def test_is_source_node_true(self, source_node): + assert is_source_node(source_node) is True + + def test_is_source_node_false_for_function(self, function_node): + assert is_source_node(function_node) is False + + def test_is_function_node_true(self, function_node): + assert is_function_node(function_node) is True + + def test_is_function_node_false_for_operator(self, operator_node): + assert is_function_node(operator_node) is False + + def test_is_operator_node_true(self, operator_node): + assert is_operator_node(operator_node) is True + + def test_is_operator_node_false_for_source(self, source_node): + assert is_operator_node(source_node) is False + + +# --- Fixtures --- + +@pytest.fixture +def _sample_source(): + import pyarrow as pa + from orcapod.core.sources import ArrowTableSource + + table = pa.table({ + "key": pa.array(["a", "b"], type=pa.large_string()), + "value": pa.array([1, 2], type=pa.int64()), + }) + return ArrowTableSource(table, tag_columns=["key"]) + + +@pytest.fixture +def source_node(_sample_source): + return SourceNode(_sample_source) + + +@pytest.fixture +def function_node(_sample_source): + from orcapod.core.function_pod import FunctionPod + from orcapod.core.packet_function import PythonPacketFunction + + pf = PythonPacketFunction(lambda value: value * 2, output_keys="result") + pod = FunctionPod(pf) + return FunctionNode(pod, _sample_source) + + +@pytest.fixture +def operator_node(_sample_source): + from orcapod.core.operators import SelectPacketColumns + + op = SelectPacketColumns(columns=["value"]) + return OperatorNode(op, input_streams=[_sample_source]) +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `uv run pytest tests/test_protocols/test_node_protocols.py -v` +Expected: FAIL — `ModuleNotFoundError: No module named 'orcapod.protocols.node_protocols'` + +- [ ] **Step 3: Implement node protocols and TypeGuard functions** + +```python +# src/orcapod/protocols/node_protocols.py +"""Node protocols for orchestrator interaction. + +Defines the three node protocols (Source, Function, Operator) that +formalize the interface between orchestrators and graph nodes, plus +TypeGuard dispatch functions for runtime type narrowing. +""" +from __future__ import annotations + +from collections.abc import Iterator +from typing import TYPE_CHECKING, Protocol, TypeGuard, runtime_checkable + +if TYPE_CHECKING: + from orcapod.core.nodes import GraphNode + from orcapod.protocols.core_protocols import ( + PacketProtocol, + StreamProtocol, + TagProtocol, + ) + from orcapod.protocols.core_protocols.operator_pod import OperatorPodProtocol + + +@runtime_checkable +class SourceNodeProtocol(Protocol): + """Protocol for source nodes in orchestrated execution.""" + + node_type: str + + def iter_packets(self) -> Iterator[tuple["TagProtocol", "PacketProtocol"]]: ... + def store_result( + self, results: list[tuple["TagProtocol", "PacketProtocol"]] + ) -> None: ... + def populate_cache( + self, results: list[tuple["TagProtocol", "PacketProtocol"]] + ) -> None: ... + + +@runtime_checkable +class FunctionNodeProtocol(Protocol): + """Protocol for function nodes in orchestrated execution.""" + + node_type: str + + def get_cached_results( + self, entry_ids: list[str] + ) -> dict[str, tuple["TagProtocol", "PacketProtocol"]]: ... + + def compute_pipeline_entry_id( + self, tag: "TagProtocol", packet: "PacketProtocol" + ) -> str: ... + + def process_packet( + self, tag: "TagProtocol", packet: "PacketProtocol" + ) -> tuple["TagProtocol", "PacketProtocol | None"]: ... + + def store_result( + self, + tag: "TagProtocol", + input_packet: "PacketProtocol", + output_packet: "PacketProtocol | None", + ) -> None: ... + + def populate_cache( + self, results: list[tuple["TagProtocol", "PacketProtocol"]] + ) -> None: ... + + +@runtime_checkable +class OperatorNodeProtocol(Protocol): + """Protocol for operator nodes in orchestrated execution.""" + + node_type: str + + @property + def operator(self) -> "OperatorPodProtocol": ... + + def get_cached_output(self) -> "StreamProtocol | None": ... + def store_result( + self, results: list[tuple["TagProtocol", "PacketProtocol"]] + ) -> None: ... + def populate_cache( + self, results: list[tuple["TagProtocol", "PacketProtocol"]] + ) -> None: ... + + +def is_source_node(node: "GraphNode") -> TypeGuard[SourceNodeProtocol]: + """Check if a node is a source node.""" + return node.node_type == "source" + + +def is_function_node(node: "GraphNode") -> TypeGuard[FunctionNodeProtocol]: + """Check if a node is a function node.""" + return node.node_type == "function" + + +def is_operator_node(node: "GraphNode") -> TypeGuard[OperatorNodeProtocol]: + """Check if a node is an operator node.""" + return node.node_type == "operator" +``` + +Note: `OperatorNode` currently stores the operator as `self._operator` (private). Task 5 +adds a public `operator` property. The protocol declares the public property so the +orchestrator uses `node.operator.process(...)` consistently. + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `uv run pytest tests/test_protocols/test_node_protocols.py -v` +Expected: PASS (all 6 tests) + +- [ ] **Step 5: Commit** + +```bash +git add src/orcapod/protocols/node_protocols.py tests/test_protocols/test_node_protocols.py +git commit -m "feat(protocols): add node protocols and TypeGuard dispatch for orchestrator" +``` + +### Task 2: ExecutionObserver Protocol and NoOpObserver + +**Files:** +- Create: `src/orcapod/pipeline/observer.py` +- Create: `tests/test_pipeline/test_observer.py` + +- [ ] **Step 1: Write tests for NoOpObserver** + +```python +# tests/test_pipeline/test_observer.py +"""Tests for ExecutionObserver protocol and NoOpObserver.""" +from __future__ import annotations + +from orcapod.pipeline.observer import ExecutionObserver, NoOpObserver + + +class TestNoOpObserver: + """NoOpObserver satisfies the protocol and does nothing.""" + + def test_satisfies_protocol(self): + observer = NoOpObserver() + assert isinstance(observer, ExecutionObserver) + + def test_on_node_start_noop(self): + observer = NoOpObserver() + observer.on_node_start(None) # type: ignore[arg-type] + + def test_on_node_end_noop(self): + observer = NoOpObserver() + observer.on_node_end(None) # type: ignore[arg-type] + + def test_on_packet_start_noop(self): + observer = NoOpObserver() + observer.on_packet_start(None, None, None) # type: ignore[arg-type] + + def test_on_packet_end_noop(self): + observer = NoOpObserver() + observer.on_packet_end(None, None, None, None, cached=False) # type: ignore[arg-type] +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `uv run pytest tests/test_pipeline/test_observer.py -v` +Expected: FAIL — `ModuleNotFoundError` + +- [ ] **Step 3: Implement observer** + +```python +# src/orcapod/pipeline/observer.py +"""Execution observer protocol for pipeline orchestration. + +Provides hooks for monitoring node and packet-level execution events +during orchestrated pipeline runs. +""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +if TYPE_CHECKING: + from orcapod.core.nodes import GraphNode + from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol + + +@runtime_checkable +class ExecutionObserver(Protocol): + """Observer protocol for pipeline execution events. + + ``on_packet_start`` / ``on_packet_end`` are only invoked for function + nodes. ``on_node_start`` / ``on_node_end`` are invoked for all node + types. + """ + + def on_node_start(self, node: "GraphNode") -> None: ... + def on_node_end(self, node: "GraphNode") -> None: ... + def on_packet_start( + self, + node: "GraphNode", + tag: "TagProtocol", + packet: "PacketProtocol", + ) -> None: ... + def on_packet_end( + self, + node: "GraphNode", + tag: "TagProtocol", + input_packet: "PacketProtocol", + output_packet: "PacketProtocol | None", + cached: bool, + ) -> None: ... + + +class NoOpObserver: + """Default observer that does nothing.""" + + def on_node_start(self, node: "GraphNode") -> None: + pass + + def on_node_end(self, node: "GraphNode") -> None: + pass + + def on_packet_start( + self, + node: "GraphNode", + tag: "TagProtocol", + packet: "PacketProtocol", + ) -> None: + pass + + def on_packet_end( + self, + node: "GraphNode", + tag: "TagProtocol", + input_packet: "PacketProtocol", + output_packet: "PacketProtocol | None", + cached: bool, + ) -> None: + pass +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `uv run pytest tests/test_pipeline/test_observer.py -v` +Expected: PASS + +- [ ] **Step 5: Commit** + +```bash +git add src/orcapod/pipeline/observer.py tests/test_pipeline/test_observer.py +git commit -m "feat(pipeline): add ExecutionObserver protocol and NoOpObserver" +``` + +### Task 3: OrchestratorResult Dataclass + +**Files:** +- Create: `src/orcapod/pipeline/result.py` + +- [ ] **Step 1: Create the result dataclass** + +```python +# src/orcapod/pipeline/result.py +"""Result type returned by pipeline orchestrators.""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol + + +@dataclass +class OrchestratorResult: + """Result of an orchestrator run. + + Attributes: + node_outputs: Mapping from graph node to its computed (tag, packet) + pairs. Empty when ``materialize_results=False``. + """ + + node_outputs: dict[Any, list[tuple["TagProtocol", "PacketProtocol"]]] = field( + default_factory=dict + ) +``` + +- [ ] **Step 2: Commit** + +```bash +git add src/orcapod/pipeline/result.py +git commit -m "feat(pipeline): add OrchestratorResult dataclass" +``` + +--- + +## Chunk 2: Node Refactoring — SourceNode and OperatorNode + +### Task 4: SourceNode — store_result and populate_cache + +**Files:** +- Modify: `src/orcapod/core/nodes/source_node.py` +- Create: `tests/test_core/nodes/test_node_populate_cache.py` +- Create: `tests/test_core/nodes/test_node_store_result.py` + +- [ ] **Step 1: Write tests for SourceNode.populate_cache** + +```python +# tests/test_core/nodes/test_node_populate_cache.py +"""Tests for populate_cache on all node types.""" +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.core.nodes import SourceNode +from orcapod.core.sources import ArrowTableSource + + +@pytest.fixture +def source_and_node(): + table = pa.table({ + "key": pa.array(["a", "b"], type=pa.large_string()), + "value": pa.array([1, 2], type=pa.int64()), + }) + src = ArrowTableSource(table, tag_columns=["key"]) + node = SourceNode(src) + return src, node + + +class TestSourceNodePopulateCache: + def test_iter_packets_uses_cache_when_populated(self, source_and_node): + src, node = source_and_node + original = list(node.iter_packets()) + assert len(original) == 2 + + # Populate cache with only the first packet + node.populate_cache([original[0]]) + + cached = list(node.iter_packets()) + assert len(cached) == 1 + + def test_iter_packets_delegates_to_stream_when_no_cache(self, source_and_node): + _, node = source_and_node + result = list(node.iter_packets()) + assert len(result) == 2 +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `uv run pytest tests/test_core/nodes/test_node_populate_cache.py::TestSourceNodePopulateCache -v` +Expected: FAIL — `AttributeError: 'SourceNode' object has no attribute 'populate_cache'` + +- [ ] **Step 3: Implement SourceNode.populate_cache and modify iter_packets** + +In `src/orcapod/core/nodes/source_node.py`, add `_cached_results` field in `__init__`, +modify `iter_packets()` to check cache, and add `populate_cache()`: + +In `__init__` after `self.stream = stream`: +```python +self._cached_results: list[tuple[cp.TagProtocol, cp.PacketProtocol]] | None = None +``` + +Replace `iter_packets`: +```python +def iter_packets(self) -> Iterator[tuple[cp.TagProtocol, cp.PacketProtocol]]: + if self.stream is None: + raise RuntimeError( + "SourceNode in read-only mode has no stream data available" + ) + if self._cached_results is not None: + return iter(self._cached_results) + return self.stream.iter_packets() +``` + +Add new method: +```python +def populate_cache( + self, results: list[tuple[cp.TagProtocol, cp.PacketProtocol]] +) -> None: + """Populate the in-memory cache with externally-provided results. + + After calling this, ``iter_packets()`` returns from the cache + instead of delegating to the wrapped stream. + """ + self._cached_results = list(results) +``` + +Also add `_cached_results = None` in the `from_descriptor` read-only path. + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `uv run pytest tests/test_core/nodes/test_node_populate_cache.py::TestSourceNodePopulateCache -v` +Expected: PASS + +- [ ] **Step 5: Write tests for SourceNode.store_result** + +```python +# tests/test_core/nodes/test_node_store_result.py +"""Tests for store_result on all node types.""" +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.core.nodes import SourceNode +from orcapod.core.sources import ArrowTableSource + + +@pytest.fixture +def source_and_node(): + table = pa.table({ + "key": pa.array(["a", "b"], type=pa.large_string()), + "value": pa.array([1, 2], type=pa.int64()), + }) + src = ArrowTableSource(table, tag_columns=["key"]) + node = SourceNode(src) + return src, node + + +class TestSourceNodeStoreResult: + def test_store_result_noop_without_db(self, source_and_node): + """store_result should be a no-op when no DB is configured.""" + _, node = source_and_node + packets = list(node.iter_packets()) + # Should not raise + node.store_result(packets) +``` + +- [ ] **Step 6: Run test to verify it fails** + +Run: `uv run pytest tests/test_core/nodes/test_node_store_result.py::TestSourceNodeStoreResult -v` +Expected: FAIL — `AttributeError: 'SourceNode' object has no attribute 'store_result'` + +- [ ] **Step 7: Implement SourceNode.store_result** + +In `src/orcapod/core/nodes/source_node.py`: + +```python +def store_result( + self, results: list[tuple[cp.TagProtocol, cp.PacketProtocol]] +) -> None: + """Persist source data snapshot to the pipeline DB if configured. + + Currently a no-op. Future implementations may store a snapshot of + what the pipeline consumed from this source. + """ + pass +``` + +- [ ] **Step 8: Run tests to verify they pass** + +Run: `uv run pytest tests/test_core/nodes/test_node_store_result.py::TestSourceNodeStoreResult tests/test_core/nodes/test_node_populate_cache.py::TestSourceNodePopulateCache -v` +Expected: PASS + +- [ ] **Step 9: Run existing SourceNode tests to verify no regressions** + +Run: `uv run pytest tests/test_core/sources/ tests/test_pipeline/ -v` +Expected: All existing tests PASS + +- [ ] **Step 10: Commit** + +```bash +git add src/orcapod/core/nodes/source_node.py tests/test_core/nodes/test_node_populate_cache.py tests/test_core/nodes/test_node_store_result.py +git commit -m "feat(source-node): add store_result and populate_cache for orchestrator support" +``` + +### Task 5: OperatorNode — get_cached_output, store_result, populate_cache + +**Files:** +- Modify: `src/orcapod/core/nodes/operator_node.py` +- Modify: `tests/test_core/nodes/test_node_populate_cache.py` +- Modify: `tests/test_core/nodes/test_node_store_result.py` + +- [ ] **Step 1: Write tests for OperatorNode.populate_cache** + +Append to `tests/test_core/nodes/test_node_populate_cache.py`: + +```python +from orcapod.core.nodes import OperatorNode +from orcapod.core.operators import SelectPacketColumns + + +@pytest.fixture +def operator_node_with_data(): + table = pa.table({ + "key": pa.array(["a", "b"], type=pa.large_string()), + "value": pa.array([10, 20], type=pa.int64()), + }) + src = ArrowTableSource(table, tag_columns=["key"]) + op = SelectPacketColumns(columns=["value"]) + node = OperatorNode(op, input_streams=[src]) + return node + + +class TestOperatorNodePopulateCache: + def test_iter_packets_uses_cache_when_populated(self, operator_node_with_data): + node = operator_node_with_data + # Run normally first to get output + node.run() + original = list(node.iter_packets()) + assert len(original) == 2 + + # Clear cache, then populate with subset + node.clear_cache() + node.populate_cache([original[0]]) + cached = list(node.iter_packets()) + assert len(cached) == 1 +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `uv run pytest tests/test_core/nodes/test_node_populate_cache.py::TestOperatorNodePopulateCache -v` +Expected: FAIL — `AttributeError: 'OperatorNode' object has no attribute 'populate_cache'` + +- [ ] **Step 3: Write tests for OperatorNode.store_result and get_cached_output** + +Append to `tests/test_core/nodes/test_node_store_result.py`: + +```python +from orcapod.core.nodes import OperatorNode +from orcapod.core.operators import SelectPacketColumns +from orcapod.databases import InMemoryArrowDatabase +from orcapod.types import CacheMode + + +@pytest.fixture +def operator_with_db(): + table = pa.table({ + "key": pa.array(["a", "b"], type=pa.large_string()), + "value": pa.array([10, 20], type=pa.int64()), + }) + src = ArrowTableSource(table, tag_columns=["key"]) + op = SelectPacketColumns(columns=["value"]) + db = InMemoryArrowDatabase() + node = OperatorNode( + op, input_streams=[src], + pipeline_database=db, + cache_mode=CacheMode.LOG, + ) + return node, db + + +@pytest.fixture +def operator_no_db(): + table = pa.table({ + "key": pa.array(["a", "b"], type=pa.large_string()), + "value": pa.array([10, 20], type=pa.int64()), + }) + src = ArrowTableSource(table, tag_columns=["key"]) + op = SelectPacketColumns(columns=["value"]) + return OperatorNode(op, input_streams=[src]) + + +class TestOperatorNodeStoreResult: + def test_store_result_writes_to_db_in_log_mode(self, operator_with_db): + node, db = operator_with_db + # Compute via operator directly (not through node.run) + stream = node._operator.process(*node._input_streams) + output = list(stream.iter_packets()) + node.store_result(output) + + records = node.get_all_records() + assert records is not None + assert records.num_rows == 2 + + def test_store_result_noop_in_off_mode(self, operator_no_db): + node = operator_no_db + stream = node._operator.process(*node._input_streams) + output = list(stream.iter_packets()) + # Should not raise + node.store_result(output) + + def test_get_cached_output_returns_none_in_off_mode(self, operator_no_db): + assert operator_no_db.get_cached_output() is None + + def test_get_cached_output_returns_none_in_log_mode(self, operator_with_db): + node, _ = operator_with_db + assert node.get_cached_output() is None + + def test_get_cached_output_returns_stream_in_replay_mode(self, operator_with_db): + node, db = operator_with_db + # First store some results + stream = node._operator.process(*node._input_streams) + output = list(stream.iter_packets()) + node.store_result(output) + + # Switch to REPLAY mode + node._cache_mode = CacheMode.REPLAY + cached = node.get_cached_output() + assert cached is not None + cached_packets = list(cached.iter_packets()) + assert len(cached_packets) == 2 +``` + +- [ ] **Step 4: Run tests to verify they fail** + +Run: `uv run pytest tests/test_core/nodes/test_node_store_result.py::TestOperatorNodeStoreResult -v` +Expected: FAIL — `AttributeError` + +- [ ] **Step 5: Implement OperatorNode methods** + +In `src/orcapod/core/nodes/operator_node.py`, add three new methods. + +Add a public `operator` property (if not already present): +```python +@property +def operator(self) -> OperatorPodProtocol: + """Return the wrapped operator pod.""" + return self._operator +``` + +Add `get_cached_output`: +```python +def get_cached_output(self) -> "StreamProtocol | None": + """Return cached output stream in REPLAY mode, else None. + + Returns: + The cached stream if REPLAY mode and DB records exist, + otherwise None. + """ + if self._pipeline_database is None: + return None + if self._cache_mode != CacheMode.REPLAY: + return None + self._replay_from_cache() + return self._cached_output_stream +``` + +Add `store_result`: +```python +def store_result( + self, + results: "list[tuple[TagProtocol, PacketProtocol]]", +) -> None: + """Persist computed results to the pipeline DB. + + Wraps the materialized results as an ArrowTableStream and stores + via the existing ``_store_output_stream`` logic. No-op if no DB + is attached or cache mode is OFF. + + Args: + results: Materialized (tag, packet) pairs from computation. + """ + if self._pipeline_database is None: + return + if self._cache_mode == CacheMode.OFF: + return + + from orcapod.core.operators.static_output_pod import StaticOutputOperatorPod + + stream = StaticOutputOperatorPod._materialize_to_stream(results) + self._store_output_stream(stream) +``` + +Add `populate_cache`: +```python +def populate_cache( + self, + results: "list[tuple[TagProtocol, PacketProtocol]]", +) -> None: + """Populate in-memory cache from externally-provided results. + + After calling this, ``iter_packets()`` / ``as_table()`` return + from the cache without recomputation. Empty lists clear the cache + and set an empty stream. + + Args: + results: Materialized (tag, packet) pairs. + """ + if not results: + self._cached_output_stream = None + self._cached_output_table = None + self._update_modified_time() + return + + from orcapod.core.operators.static_output_pod import StaticOutputOperatorPod + + self._cached_output_stream = StaticOutputOperatorPod._materialize_to_stream( + results + ) + self._update_modified_time() +``` + +Note: `StaticOutputOperatorPod._materialize_to_stream` raises `ValueError` on empty lists, +so `populate_cache` handles the empty case explicitly. The method exists as a static helper +on the base operator class (verified at `static_output_pod.py:195`). + +- [ ] **Step 6: Run tests to verify they pass** + +Run: `uv run pytest tests/test_core/nodes/test_node_store_result.py::TestOperatorNodeStoreResult tests/test_core/nodes/test_node_populate_cache.py::TestOperatorNodePopulateCache -v` +Expected: PASS + +- [ ] **Step 7: Run existing OperatorNode tests to verify no regressions** + +Run: `uv run pytest tests/test_core/operators/ tests/test_pipeline/ -v` +Expected: All existing tests PASS + +- [ ] **Step 8: Commit** + +```bash +git add src/orcapod/core/nodes/operator_node.py tests/test_core/nodes/test_node_populate_cache.py tests/test_core/nodes/test_node_store_result.py +git commit -m "feat(operator-node): add get_cached_output, store_result, populate_cache" +``` + +--- + +## Chunk 3: FunctionNode Refactoring + +### Task 6: FunctionNode — Pure process_packet, store_result, get_cached_results, populate_cache + +The current `FunctionNode.process_packet()` bundles computation (via CachedFunctionPod) +with pipeline record writing (via `add_pipeline_record`). We need to: + +1. Extract pipeline record writing from `process_packet` into `store_result` +2. Keep `process_packet` handling computation + function-level memoization (CachedFunctionPod) +3. Add `get_cached_results` factored out of `iter_packets()` Phase 1 +4. Add `populate_cache` +5. Keep existing `iter_packets()` and `run()` working + +**Key insight:** `process_packet` is NOT pure — it writes to the result DB via +CachedFunctionPod. But that's function-level memoization (the function pod's own concern). +What moves to `store_result` is ONLY the pipeline provenance record +(`add_pipeline_record`). This keeps a clean boundary: function-level concerns in +`process_packet`, pipeline-level concerns in `store_result`. + +**Backward compatibility:** The existing `iter_packets()` calls `process_packet` then +`add_pipeline_record` inline. After this refactoring, we rename the original bundled +method to `_process_and_store_packet` and update all call sites in `iter_packets()`, +`_iter_packets_sequential()`, and `_iter_packets_concurrent()`. + +**Files:** +- Modify: `src/orcapod/core/nodes/function_node.py` +- Create: `tests/test_core/nodes/test_function_node_get_cached.py` +- Modify: `tests/test_core/nodes/test_node_populate_cache.py` +- Modify: `tests/test_core/nodes/test_node_store_result.py` + +- [ ] **Step 1: Write tests for FunctionNode.process_packet (no pipeline record) and store_result** + +```python +# Append to tests/test_core/nodes/test_node_store_result.py + +from orcapod.core.function_pod import FunctionPod +from orcapod.core.nodes import FunctionNode +from orcapod.core.packet_function import PythonPacketFunction + + +def double_value(value: int) -> int: + return value * 2 + + +@pytest.fixture +def function_node_with_db(): + table = pa.table({ + "key": pa.array(["a", "b"], type=pa.large_string()), + "value": pa.array([1, 2], type=pa.int64()), + }) + src = ArrowTableSource(table, tag_columns=["key"]) + pf = PythonPacketFunction(double_value, output_keys="result") + pod = FunctionPod(pf) + pipeline_db = InMemoryArrowDatabase() + result_db = InMemoryArrowDatabase() + node = FunctionNode( + pod, src, + pipeline_database=pipeline_db, + result_database=result_db, + ) + return node, pipeline_db, result_db + + +@pytest.fixture +def function_node_no_db(): + table = pa.table({ + "key": pa.array(["a", "b"], type=pa.large_string()), + "value": pa.array([1, 2], type=pa.int64()), + }) + src = ArrowTableSource(table, tag_columns=["key"]) + pf = PythonPacketFunction(double_value, output_keys="result") + pod = FunctionPod(pf) + return FunctionNode(pod, src) + + +class TestFunctionNodeProcessPacket: + def test_process_packet_does_not_write_pipeline_record(self, function_node_with_db): + """process_packet handles computation but NOT pipeline provenance.""" + node, pipeline_db, _ = function_node_with_db + packets = list(node._input_stream.iter_packets()) + tag, packet = packets[0] + + tag_out, result = node.process_packet(tag, packet) + assert result is not None + + # No pipeline record should exist (only store_result writes those) + records = pipeline_db.get_all_records(node.pipeline_path) + assert records is None + + def test_process_packet_writes_to_result_db(self, function_node_with_db): + """process_packet handles function-level memoization via CachedFunctionPod.""" + node, _, result_db = function_node_with_db + packets = list(node._input_stream.iter_packets()) + tag, packet = packets[0] + + tag_out, result = node.process_packet(tag, packet) + assert result is not None + + # Result DB should have the cached computation + cached_results = node._cached_function_pod.get_all_cached_outputs() + assert cached_results is not None + assert cached_results.num_rows == 1 + + def test_process_packet_returns_correct_result(self, function_node_no_db): + node = function_node_no_db + packets = list(node._input_stream.iter_packets()) + tag, packet = packets[0] + + tag_out, result = node.process_packet(tag, packet) + assert result is not None + assert result.as_dict()["result"] == 2 # double of 1 + + +class TestFunctionNodeStoreResult: + def test_store_result_writes_pipeline_record(self, function_node_with_db): + node, pipeline_db, _ = function_node_with_db + packets = list(node._input_stream.iter_packets()) + tag, packet = packets[0] + + tag_out, result = node.process_packet(tag, packet) + node.store_result(tag, packet, result) + + records = pipeline_db.get_all_records(node.pipeline_path) + assert records is not None + assert records.num_rows == 1 + + def test_store_result_does_not_write_to_result_db(self, function_node_with_db): + """store_result only writes pipeline records, not result cache.""" + node, _, result_db = function_node_with_db + packets = list(node._input_stream.iter_packets()) + tag, packet = packets[0] + + # process_packet first (writes to result DB) + tag_out, result = node.process_packet(tag, packet) + cached_before = node._cached_function_pod.get_all_cached_outputs() + count_before = cached_before.num_rows if cached_before is not None else 0 + + # store_result should NOT add more to result DB + node.store_result(tag, packet, result) + cached_after = node._cached_function_pod.get_all_cached_outputs() + count_after = cached_after.num_rows if cached_after is not None else 0 + assert count_after == count_before + + def test_store_result_noop_without_db(self, function_node_no_db): + node = function_node_no_db + packets = list(node._input_stream.iter_packets()) + tag, packet = packets[0] + + tag_out, result = node.process_packet(tag, packet) + # Should not raise + node.store_result(tag, packet, result) + + def test_store_result_handles_none_output(self, function_node_with_db): + node, pipeline_db, _ = function_node_with_db + packets = list(node._input_stream.iter_packets()) + tag, packet = packets[0] + + # Should not raise; should be a no-op for None output + node.store_result(tag, packet, None) + records = pipeline_db.get_all_records(node.pipeline_path) + assert records is None +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `uv run pytest tests/test_core/nodes/test_node_store_result.py::TestFunctionNodeProcessPacket tests/test_core/nodes/test_node_store_result.py::TestFunctionNodeStoreResult -v` +Expected: FAIL — first test fails because `process_packet` still writes pipeline records + +- [ ] **Step 3: Refactor FunctionNode.process_packet and add store_result** + +In `src/orcapod/core/nodes/function_node.py`: + +1. Rename current `process_packet` to `_process_and_store_packet` (used by existing + `iter_packets()` for backward compatibility — keeps bundled compute+pipeline record). + +2. Create new `process_packet` that handles computation + function-level memoization + but NOT pipeline records: + +```python +def process_packet( + self, + tag: TagProtocol, + packet: PacketProtocol, +) -> tuple[TagProtocol, PacketProtocol | None]: + """Process a single packet with function-level memoization. + + Delegates to ``CachedFunctionPod`` (when DB is attached) for + computation and result-level caching, or to the raw ``FunctionPod`` + otherwise. Does NOT write pipeline provenance records — use + ``store_result`` for that. + + Args: + tag: The tag associated with the packet. + packet: The input packet to process. + + Returns: + A ``(tag, output_packet)`` tuple; output_packet is ``None`` if + the function filters the packet out. + """ + if self._cached_function_pod is not None: + return self._cached_function_pod.process_packet(tag, packet) + return self._function_pod.process_packet(tag, packet) +``` + +3. Add `store_result` (pipeline provenance only): + +```python +def store_result( + self, + tag: TagProtocol, + input_packet: PacketProtocol, + output_packet: PacketProtocol | None, +) -> None: + """Record pipeline provenance for a processed packet. + + Writes a pipeline record associating this (tag + system_tags + + input_packet) with the output packet record ID. Does NOT write + to the result DB — that is handled by ``process_packet`` via + ``CachedFunctionPod``. + + No-op if no pipeline DB is attached or output is None. + + Args: + tag: The tag associated with the packet. + input_packet: The original input packet. + output_packet: The computation result, or None if filtered. + """ + if output_packet is None: + return + if self._pipeline_database is None: + return + + result_computed = True + if self._cached_function_pod is not None: + result_computed = bool( + output_packet.get_meta_value( + self._cached_function_pod.RESULT_COMPUTED_FLAG, True + ) + ) + + self.add_pipeline_record( + tag, + input_packet, + packet_record_id=output_packet.datagram_id, + computed=result_computed, + ) +``` + +4. Update `iter_packets()` to call `_process_and_store_packet` instead of `process_packet` + in Phase 2 (around line 787). Also update `_iter_packets_sequential` (line 826) **and + `_iter_packets_concurrent` (line 854)** — all three call sites must use the bundled + version for backward compatibility. + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `uv run pytest tests/test_core/nodes/test_node_store_result.py::TestFunctionNodeProcessPacket tests/test_core/nodes/test_node_store_result.py::TestFunctionNodeStoreResult -v` +Expected: PASS + +- [ ] **Step 5: Run ALL existing FunctionNode tests to verify no regressions** + +Run: `uv run pytest tests/test_core/function_pod/ tests/test_pipeline/ -v` +Expected: All existing tests PASS. The `iter_packets()` path uses `_process_and_store_packet` +which preserves the original bundled behavior. + +- [ ] **Step 6: Commit** + +```bash +git add src/orcapod/core/nodes/function_node.py tests/test_core/nodes/test_node_store_result.py +git commit -m "refactor(function-node): separate pure process_packet from store_result" +``` + +- [ ] **Step 7: Write tests for get_cached_results** + +```python +# tests/test_core/nodes/test_function_node_get_cached.py +"""Tests for FunctionNode.get_cached_results.""" +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.core.function_pod import FunctionPod +from orcapod.core.nodes import FunctionNode +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.core.sources import ArrowTableSource +from orcapod.databases import InMemoryArrowDatabase + + +def double_value(value: int) -> int: + return value * 2 + + +@pytest.fixture +def function_node_with_db(): + table = pa.table({ + "key": pa.array(["a", "b"], type=pa.large_string()), + "value": pa.array([1, 2], type=pa.int64()), + }) + src = ArrowTableSource(table, tag_columns=["key"]) + pf = PythonPacketFunction(double_value, output_keys="result") + pod = FunctionPod(pf) + pipeline_db = InMemoryArrowDatabase() + result_db = InMemoryArrowDatabase() + node = FunctionNode( + pod, src, + pipeline_database=pipeline_db, + result_database=result_db, + ) + return node + + +class TestGetCachedResults: + def test_returns_empty_dict_when_no_db(self): + table = pa.table({ + "key": pa.array(["a"], type=pa.large_string()), + "value": pa.array([1], type=pa.int64()), + }) + src = ArrowTableSource(table, tag_columns=["key"]) + pf = PythonPacketFunction(double_value, output_keys="result") + pod = FunctionPod(pf) + node = FunctionNode(pod, src) + + result = node.get_cached_results([]) + assert result == {} + + def test_returns_empty_dict_when_db_empty(self, function_node_with_db): + result = function_node_with_db.get_cached_results(["nonexistent"]) + assert result == {} + + def test_returns_cached_results_for_matching_entry_ids(self, function_node_with_db): + node = function_node_with_db + packets = list(node._input_stream.iter_packets()) + + # Process and store two packets + entry_ids = [] + for tag, packet in packets: + tag_out, result = node.process_packet(tag, packet) + node.store_result(tag, packet, result) + entry_ids.append(node.compute_pipeline_entry_id(tag, packet)) + + # Retrieve cached results for both + cached = node.get_cached_results(entry_ids) + assert len(cached) == 2 + assert all(eid in cached for eid in entry_ids) + + def test_filters_to_requested_entry_ids_only(self, function_node_with_db): + node = function_node_with_db + packets = list(node._input_stream.iter_packets()) + + entry_ids = [] + for tag, packet in packets: + tag_out, result = node.process_packet(tag, packet) + node.store_result(tag, packet, result) + entry_ids.append(node.compute_pipeline_entry_id(tag, packet)) + + # Request only the first entry ID + cached = node.get_cached_results([entry_ids[0]]) + assert len(cached) == 1 + assert entry_ids[0] in cached + assert entry_ids[1] not in cached +``` + +- [ ] **Step 8: Run tests to verify they fail** + +Run: `uv run pytest tests/test_core/nodes/test_function_node_get_cached.py -v` +Expected: FAIL — `AttributeError: 'FunctionNode' object has no attribute 'get_cached_results'` + +- [ ] **Step 9: Implement get_cached_results** + +In `src/orcapod/core/nodes/function_node.py`: + +```python +def get_cached_results( + self, entry_ids: list[str] +) -> dict[str, tuple[TagProtocol, PacketProtocol]]: + """Retrieve cached results for specific pipeline entry IDs. + + Looks up the pipeline DB and result DB, joins them, and filters + to the requested entry IDs. Returns a mapping from entry ID to + (tag, output_packet). + + Args: + entry_ids: Pipeline entry IDs to look up. + + Returns: + Mapping from entry_id to (tag, output_packet) for found entries. + Empty dict if no DB is attached or no matches found. + """ + if self._cached_function_pod is None or not entry_ids: + return {} + + PIPELINE_ENTRY_ID_COL = "__pipeline_entry_id" + entry_id_set = set(entry_ids) + + taginfo = self._pipeline_database.get_all_records( + self.pipeline_path, + record_id_column=PIPELINE_ENTRY_ID_COL, + ) + results = self._cached_function_pod._result_database.get_all_records( + self._cached_function_pod.record_path, + record_id_column=constants.PACKET_RECORD_ID, + ) + + if taginfo is None or results is None: + return {} + + joined = ( + pl.DataFrame(taginfo) + .join( + pl.DataFrame(results), + on=constants.PACKET_RECORD_ID, + how="inner", + ) + .to_arrow() + ) + + if joined.num_rows == 0: + return {} + + # Filter to requested entry IDs + all_entry_ids = joined.column(PIPELINE_ENTRY_ID_COL).to_pylist() + mask = [eid in entry_id_set for eid in all_entry_ids] + filtered = joined.filter(pa.array(mask)) + + if filtered.num_rows == 0: + return {} + + tag_keys = self._input_stream.keys()[0] + drop_cols = [ + c for c in filtered.column_names + if c.startswith(constants.META_PREFIX) or c == PIPELINE_ENTRY_ID_COL + ] + data_table = filtered.drop( + [c for c in drop_cols if c in filtered.column_names] + ) + + from orcapod.core.streams.arrow_table_stream import ArrowTableStream + + stream = ArrowTableStream(data_table, tag_columns=tag_keys) + filtered_entry_ids = [eid for eid, m in zip(all_entry_ids, mask) if m] + + result_dict: dict[str, tuple[TagProtocol, PacketProtocol]] = {} + for entry_id, (tag, packet) in zip( + filtered_entry_ids, stream.iter_packets() + ): + result_dict[entry_id] = (tag, packet) + + return result_dict +``` + +- [ ] **Step 10: Run tests to verify they pass** + +Run: `uv run pytest tests/test_core/nodes/test_function_node_get_cached.py -v` +Expected: PASS + +- [ ] **Step 11: Write tests for FunctionNode.populate_cache** + +Append to `tests/test_core/nodes/test_node_populate_cache.py`: + +```python +from orcapod.core.function_pod import FunctionPod +from orcapod.core.nodes import FunctionNode +from orcapod.core.packet_function import PythonPacketFunction + + +class TestFunctionNodePopulateCache: + def test_iter_packets_uses_cache_when_populated(self): + table = pa.table({ + "key": pa.array(["a", "b"], type=pa.large_string()), + "value": pa.array([1, 2], type=pa.int64()), + }) + src = ArrowTableSource(table, tag_columns=["key"]) + pf = PythonPacketFunction(lambda value: value * 2, output_keys="result") + pod = FunctionPod(pf) + node = FunctionNode(pod, src) + + # Run to get results + original = list(node.iter_packets()) + assert len(original) == 2 + + # Clear and populate with subset + node.clear_cache() + node.populate_cache([original[0]]) + cached = list(node.iter_packets()) + assert len(cached) == 1 +``` + +- [ ] **Step 12: Run test to verify it fails** + +Run: `uv run pytest tests/test_core/nodes/test_node_populate_cache.py::TestFunctionNodePopulateCache -v` +Expected: FAIL — `AttributeError: 'FunctionNode' object has no attribute 'populate_cache'` + +- [ ] **Step 13: Implement FunctionNode.populate_cache** + +In `src/orcapod/core/nodes/function_node.py`: + +```python +def populate_cache( + self, results: list[tuple[TagProtocol, PacketProtocol]] +) -> None: + """Populate in-memory cache from externally-provided results. + + After calling this, ``iter_packets()`` returns from the cache + without upstream iteration or computation. + + Args: + results: Materialized (tag, packet) pairs. + """ + self._cached_output_packets.clear() + for i, (tag, packet) in enumerate(results): + self._cached_output_packets[i] = (tag, packet) + self._cached_input_iterator = None + self._needs_iterator = False + self._update_modified_time() +``` + +- [ ] **Step 14: Run ALL tests** + +Run: `uv run pytest tests/test_core/nodes/ tests/test_core/function_pod/ tests/test_pipeline/ -v` +Expected: ALL PASS + +- [ ] **Step 15: Commit** + +```bash +git add src/orcapod/core/nodes/function_node.py tests/test_core/nodes/test_function_node_get_cached.py tests/test_core/nodes/test_node_populate_cache.py +git commit -m "feat(function-node): add get_cached_results, populate_cache for orchestrator" +``` + +--- + +## Chunk 4: SyncPipelineOrchestrator and Pipeline Integration + +### Task 7: SyncPipelineOrchestrator + +**Files:** +- Create: `src/orcapod/pipeline/sync_orchestrator.py` +- Create: `tests/test_pipeline/test_sync_orchestrator.py` + +- [ ] **Step 1: Write test for basic linear pipeline (Source → FunctionPod)** + +```python +# tests/test_pipeline/test_sync_orchestrator.py +"""Tests for the synchronous pipeline orchestrator.""" +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.core.function_pod import FunctionPod +from orcapod.core.nodes import FunctionNode, OperatorNode, SourceNode +from orcapod.core.operators import SelectPacketColumns +from orcapod.core.operators.join import Join +from orcapod.core.operators.mappers import MapPackets +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.core.sources import ArrowTableSource +from orcapod.databases import InMemoryArrowDatabase +from orcapod.pipeline import Pipeline +from orcapod.pipeline.observer import ExecutionObserver +from orcapod.pipeline.sync_orchestrator import SyncPipelineOrchestrator + + +def _make_source(tag_col, packet_col, data): + table = pa.table({ + tag_col: pa.array(data[tag_col], type=pa.large_string()), + packet_col: pa.array(data[packet_col], type=pa.int64()), + }) + return ArrowTableSource(table, tag_columns=[tag_col]) + + +def double_value(value: int) -> int: + return value * 2 + + +def add_values(value: int, score: int) -> int: + return value + score + + +class TestSyncOrchestratorLinear: + """Source -> FunctionPod.""" + + def test_linear_pipeline(self): + src = _make_source("key", "value", {"key": ["a", "b", "c"], "value": [1, 2, 3]}) + pf = PythonPacketFunction(double_value, output_keys="result") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="linear", pipeline_database=InMemoryArrowDatabase()) + with pipeline: + pod(src, label="doubler") + + orch = SyncPipelineOrchestrator() + result = orch.run(pipeline._node_graph) + + # Verify results exist for all nodes + assert len(result.node_outputs) > 0 + + # Find the function node output + fn_outputs = [ + v for k, v in result.node_outputs.items() + if k.node_type == "function" + ] + assert len(fn_outputs) == 1 + assert len(fn_outputs[0]) == 3 + values = sorted([pkt.as_dict()["result"] for _, pkt in fn_outputs[0]]) + assert values == [2, 4, 6] + + +class TestSyncOrchestratorDiamond: + """Two sources -> Join -> FunctionPod.""" + + def test_diamond_dag(self): + src_a = _make_source("key", "value", {"key": ["a", "b"], "value": [10, 20]}) + src_b = _make_source("key", "score", {"key": ["a", "b"], "score": [100, 200]}) + pf = PythonPacketFunction(add_values, output_keys="total") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="diamond", pipeline_database=InMemoryArrowDatabase()) + with pipeline: + joined = Join()(src_a, src_b, label="join") + pod(joined, label="adder") + + orch = SyncPipelineOrchestrator() + result = orch.run(pipeline._node_graph) + + fn_outputs = [ + v for k, v in result.node_outputs.items() + if k.node_type == "function" + ] + assert len(fn_outputs) == 1 + values = sorted([pkt.as_dict()["total"] for _, pkt in fn_outputs[0]]) + assert values == [110, 220] + + +class TestSyncOrchestratorObserver: + """Observer hooks fire in correct order.""" + + def test_observer_hooks_fire(self): + src = _make_source("key", "value", {"key": ["a"], "value": [1]}) + pf = PythonPacketFunction(double_value, output_keys="result") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="obs", pipeline_database=InMemoryArrowDatabase()) + with pipeline: + pod(src, label="doubler") + + events = [] + + class RecordingObserver: + def on_node_start(self, node): + events.append(("node_start", node.node_type)) + + def on_node_end(self, node): + events.append(("node_end", node.node_type)) + + def on_packet_start(self, node, tag, packet): + events.append(("packet_start",)) + + def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): + events.append(("packet_end", cached)) + + orch = SyncPipelineOrchestrator(observer=RecordingObserver()) + orch.run(pipeline._node_graph) + + # Source: node_start, node_end + # Function: node_start, packet_start, packet_end, node_end + assert events[0] == ("node_start", "source") + assert events[1] == ("node_end", "source") + assert events[2] == ("node_start", "function") + assert events[3] == ("packet_start",) + assert events[4] == ("packet_end", False) + assert events[5] == ("node_end", "function") + + +class TestSyncOrchestratorUnknownNodeType: + """Unknown node types raise TypeError.""" + + def test_raises_on_unknown_node_type(self): + import networkx as nx + + class FakeNode: + node_type = "unknown" + + G = nx.DiGraph() + G.add_node(FakeNode()) + + orch = SyncPipelineOrchestrator() + with pytest.raises(TypeError, match="Unknown node type"): + orch.run(G) +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `uv run pytest tests/test_pipeline/test_sync_orchestrator.py -v` +Expected: FAIL — `ModuleNotFoundError` + +- [ ] **Step 3: Implement SyncPipelineOrchestrator** + +```python +# src/orcapod/pipeline/sync_orchestrator.py +"""Synchronous pipeline orchestrator. + +Walks a compiled pipeline's node graph topologically, executing each node +with materialized buffers and per-packet observer hooks for function nodes. +""" +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any + +from orcapod.pipeline.observer import NoOpObserver +from orcapod.pipeline.result import OrchestratorResult +from orcapod.protocols.node_protocols import ( + is_function_node, + is_operator_node, + is_source_node, +) + +if TYPE_CHECKING: + import networkx as nx + + from orcapod.pipeline.observer import ExecutionObserver + from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol + +logger = logging.getLogger(__name__) + + +class SyncPipelineOrchestrator: + """Execute a compiled pipeline synchronously with observer hooks. + + Walks the node graph in topological order. For each node: + - SourceNode: materializes iter_packets() into a buffer + - FunctionNode: per-packet execution with cache lookup + observer hooks + - OperatorNode: bulk execution via operator.process() + + All nodes have store_result called after computation. The orchestrator + returns an OrchestratorResult with all node outputs. + + Args: + observer: Optional execution observer for hooks. Defaults to + NoOpObserver. + """ + + def __init__(self, observer: "ExecutionObserver | None" = None) -> None: + self._observer = observer or NoOpObserver() + + def run( + self, + graph: "nx.DiGraph", + materialize_results: bool = True, + ) -> OrchestratorResult: + """Execute the node graph synchronously. + + Args: + graph: A NetworkX DiGraph with GraphNode objects as vertices. + materialize_results: If True, keep all node outputs in memory + and return them. If False, discard buffers after downstream + consumption (only DB-persisted results survive). + + Returns: + OrchestratorResult with node outputs. + """ + import networkx as nx + + topo_order = list(nx.topological_sort(graph)) + buffers: dict[Any, list[tuple[TagProtocol, PacketProtocol]]] = {} + processed: set[Any] = set() + + for node in topo_order: + if is_source_node(node): + buffers[node] = self._execute_source(node) + elif is_function_node(node): + upstream_buffer = self._gather_upstream(node, graph, buffers) + buffers[node] = self._execute_function(node, upstream_buffer) + elif is_operator_node(node): + upstream_buffers = self._gather_upstream_multi( + node, graph, buffers + ) + buffers[node] = self._execute_operator(node, upstream_buffers) + else: + raise TypeError( + f"Unknown node type: {getattr(node, 'node_type', None)!r}" + ) + + processed.add(node) + + if not materialize_results: + self._gc_buffers(node, graph, buffers, processed) + + return OrchestratorResult(node_outputs=buffers) + + def _execute_source(self, node): + """Execute a source node: materialize its packets.""" + self._observer.on_node_start(node) + output = list(node.iter_packets()) + node.store_result(output) + self._observer.on_node_end(node) + return output + + def _execute_function(self, node, upstream_buffer): + """Execute a function node with per-packet hooks.""" + self._observer.on_node_start(node) + + upstream_entries = [ + (tag, packet, node.compute_pipeline_entry_id(tag, packet)) + for tag, packet in upstream_buffer + ] + entry_ids = [eid for _, _, eid in upstream_entries] + + cached = node.get_cached_results(entry_ids=entry_ids) + + output = [] + for tag, packet, entry_id in upstream_entries: + self._observer.on_packet_start(node, tag, packet) + if entry_id in cached: + tag_out, result = cached[entry_id] + self._observer.on_packet_end( + node, tag, packet, result, cached=True + ) + output.append((tag_out, result)) + else: + tag_out, result = node.process_packet(tag, packet) + node.store_result(tag, packet, result) + self._observer.on_packet_end( + node, tag, packet, result, cached=False + ) + if result is not None: + output.append((tag_out, result)) + + self._observer.on_node_end(node) + return output + + def _execute_operator(self, node, upstream_buffers): + """Execute an operator node: bulk stream processing.""" + self._observer.on_node_start(node) + + cached = node.get_cached_output() + if cached is not None: + output = list(cached.iter_packets()) + else: + input_streams = [ + self._materialize_as_stream(buf, upstream_node) + for buf, upstream_node in upstream_buffers + ] + result_stream = node.operator.process(*input_streams) + output = list(result_stream.iter_packets()) + node.store_result(output) + + self._observer.on_node_end(node) + return output + + def _gather_upstream(self, node, graph, buffers): + """Gather a single upstream buffer (for function nodes).""" + predecessors = list(graph.predecessors(node)) + if len(predecessors) != 1: + raise ValueError( + f"FunctionNode expects exactly 1 upstream, got {len(predecessors)}" + ) + return buffers[predecessors[0]] + + def _gather_upstream_multi(self, node, graph, buffers): + """Gather multiple upstream buffers with their nodes (for operator nodes). + + Returns list of (buffer, upstream_node) tuples preserving the order + that matches the operator's input_streams order. + """ + predecessors = list(graph.predecessors(node)) + # Match predecessor order to the node's upstreams order + upstream_order = { + id(upstream): i for i, upstream in enumerate(node.upstreams) + } + sorted_preds = sorted( + predecessors, + key=lambda p: upstream_order.get(id(p), 0), + ) + return [(buffers[p], p) for p in sorted_preds] + + @staticmethod + def _materialize_as_stream(buf, upstream_node): + """Wrap a (tag, packet) buffer as an ArrowTableStream. + + Args: + buf: List of (tag, packet) tuples. + upstream_node: The node that produced this buffer (used to + determine tag column names). + + Returns: + An ArrowTableStream. + """ + from orcapod.core.streams.arrow_table_stream import ArrowTableStream + from orcapod.utils import arrow_utils + + if not buf: + raise ValueError("Cannot materialize empty buffer as stream") + + # Use selective columns matching the proven pattern in + # StaticOutputOperatorPod._materialize_to_stream: + # system_tags for tags, source info for packets. + tag_tables = [ + tag.as_table(columns={"system_tags": True}) for tag, _ in buf + ] + packet_tables = [ + pkt.as_table(columns={"source": True}) for _, pkt in buf + ] + + import pyarrow as pa + + combined_tags = pa.concat_tables(tag_tables) + combined_packets = pa.concat_tables(packet_tables) + + user_tag_keys = tuple(buf[0][0].keys()) + source_info = buf[0][1].source_info() + + full_table = arrow_utils.hstack_tables(combined_tags, combined_packets) + + return ArrowTableStream( + full_table, + tag_columns=user_tag_keys, + source_info=source_info, + ) + + @staticmethod + def _gc_buffers(current_node, graph, buffers, processed): + """Discard buffers no longer needed by any unprocessed downstream.""" + for pred in graph.predecessors(current_node): + if pred not in buffers: + continue + all_successors_done = all( + succ in processed for succ in graph.successors(pred) + ) + if all_successors_done: + del buffers[pred] +``` + +- [ ] **Step 4: Run tests to verify they pass** + +Run: `uv run pytest tests/test_pipeline/test_sync_orchestrator.py -v` +Expected: PASS + +- [ ] **Step 5: Commit** + +```bash +git add src/orcapod/pipeline/sync_orchestrator.py tests/test_pipeline/test_sync_orchestrator.py +git commit -m "feat(pipeline): implement SyncPipelineOrchestrator" +``` + +### Task 8: Pipeline.run() Integration and File Reorganization + +**Files:** +- Modify: `src/orcapod/pipeline/graph.py` +- Rename: `src/orcapod/pipeline/orchestrator.py` → `src/orcapod/pipeline/async_orchestrator.py` +- Modify: `src/orcapod/pipeline/__init__.py` + +- [ ] **Step 1: Write test for Pipeline.run() with default orchestrator** + +Append to `tests/test_pipeline/test_sync_orchestrator.py`: + +```python +class TestPipelineRunIntegration: + """Pipeline.run() with orchestrator parameter.""" + + def test_default_run_uses_sync_orchestrator(self): + """Pipeline.run() without args should use SyncPipelineOrchestrator.""" + src = _make_source("key", "value", {"key": ["a", "b"], "value": [1, 2]}) + pf = PythonPacketFunction(double_value, output_keys="result") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="default", pipeline_database=InMemoryArrowDatabase()) + with pipeline: + pod(src, label="doubler") + + pipeline.run() + + records = pipeline.doubler.get_all_records() + assert records is not None + assert records.num_rows == 2 + values = sorted(records.column("result").to_pylist()) + assert values == [2, 4] + + def test_run_with_explicit_orchestrator(self): + src = _make_source("key", "value", {"key": ["a", "b"], "value": [1, 2]}) + pf = PythonPacketFunction(double_value, output_keys="result") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="explicit", pipeline_database=InMemoryArrowDatabase()) + with pipeline: + pod(src, label="doubler") + + events = [] + + class RecordingObserver: + def on_node_start(self, node): + events.append(("node_start", node.node_type)) + def on_node_end(self, node): + events.append(("node_end", node.node_type)) + def on_packet_start(self, node, tag, packet): + events.append(("packet_start",)) + def on_packet_end(self, node, tag, input_pkt, output_pkt, cached): + events.append(("packet_end",)) + + orch = SyncPipelineOrchestrator(observer=RecordingObserver()) + pipeline.run(orchestrator=orch) + + # Observer events should have fired + assert len(events) > 0 + + # Results should be accessible via node + records = pipeline.doubler.get_all_records() + assert records is not None + + def test_run_populates_node_caches(self): + """After run(), iter_packets()/as_table() should work on nodes.""" + src = _make_source("key", "value", {"key": ["a", "b"], "value": [1, 2]}) + pf = PythonPacketFunction(double_value, output_keys="result") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="cache", pipeline_database=InMemoryArrowDatabase()) + with pipeline: + pod(src, label="doubler") + + pipeline.run() + + # as_table should work after orchestrated execution + table = pipeline.doubler.as_table() + assert table.num_rows == 2 +``` + +- [ ] **Step 2: Run tests to verify they fail** + +Run: `uv run pytest tests/test_pipeline/test_sync_orchestrator.py::TestPipelineRunIntegration -v` +Expected: FAIL (Pipeline.run signature doesn't accept orchestrator yet) + +- [ ] **Step 3: Rename orchestrator.py to async_orchestrator.py** + +```bash +git mv src/orcapod/pipeline/orchestrator.py src/orcapod/pipeline/async_orchestrator.py +``` + +- [ ] **Step 4: Update Pipeline.run() in graph.py** + +In `src/orcapod/pipeline/graph.py`, update the `run` method to accept an orchestrator +parameter and the `_run_async` method to import from the new location: + +```python +def run( + self, + orchestrator=None, + config: PipelineConfig | None = None, + execution_engine: cp.PacketFunctionExecutorProtocol | None = None, + execution_engine_opts: "dict[str, Any] | None" = None, +) -> None: + """Execute all compiled nodes. + + Args: + orchestrator: Optional orchestrator instance. When provided, + the orchestrator drives execution. When omitted and no + async mode is requested, defaults to + SyncPipelineOrchestrator. + config: Pipeline configuration (legacy parameter). + execution_engine: Optional packet-function executor (legacy). + execution_engine_opts: Engine options dict (legacy). + """ + from orcapod.pipeline.sync_orchestrator import SyncPipelineOrchestrator + from orcapod.pipeline.result import OrchestratorResult + from orcapod.types import ExecutorType, PipelineConfig + + explicit_config = config is not None + config = config or PipelineConfig() + + effective_engine = ( + execution_engine + if execution_engine is not None + else config.execution_engine + ) + effective_opts = ( + execution_engine_opts + if execution_engine_opts is not None + else config.execution_engine_opts + ) + + if not self._compiled: + self.compile() + + if effective_engine is not None: + self._apply_execution_engine(effective_engine, effective_opts) + + if orchestrator is not None: + result = orchestrator.run(self._node_graph) + self._apply_results(result) + else: + use_async = config.executor == ExecutorType.ASYNC_CHANNELS or ( + effective_engine is not None and not explicit_config + ) + if use_async: + self._run_async(config) + else: + orch = SyncPipelineOrchestrator() + result = orch.run(self._node_graph) + self._apply_results(result) + + self.flush() +``` + +Add the `_apply_results` method: + +```python +def _apply_results(self, result: "OrchestratorResult") -> None: + """Populate node caches from orchestrator results.""" + for node, outputs in result.node_outputs.items(): + if hasattr(node, "populate_cache"): + node.populate_cache(outputs) +``` + +Update `_run_async` import path: + +```python +def _run_async(self, config: PipelineConfig) -> None: + """Run the pipeline asynchronously using the orchestrator.""" + from orcapod.pipeline.async_orchestrator import AsyncPipelineOrchestrator + + orchestrator = AsyncPipelineOrchestrator() + orchestrator.run(self, config) +``` + +- [ ] **Step 5: Update pipeline __init__.py** + +```python +# src/orcapod/pipeline/__init__.py +from .async_orchestrator import AsyncPipelineOrchestrator +from .graph import Pipeline +from .serialization import LoadStatus, PIPELINE_FORMAT_VERSION +from .sync_orchestrator import SyncPipelineOrchestrator + +__all__ = [ + "AsyncPipelineOrchestrator", + "LoadStatus", + "PIPELINE_FORMAT_VERSION", + "Pipeline", + "SyncPipelineOrchestrator", +] +``` + +- [ ] **Step 6: Run tests to verify they pass** + +Run: `uv run pytest tests/test_pipeline/test_sync_orchestrator.py -v` +Expected: PASS + +- [ ] **Step 7: Run ALL tests to verify no regressions** + +Run: `uv run pytest tests/ -v` +Expected: ALL PASS. The async orchestrator tests should still pass since the import +path is updated in `__init__.py`. + +- [ ] **Step 8: Commit** + +```bash +git add -A +git commit -m "feat(pipeline): integrate SyncPipelineOrchestrator into Pipeline.run()" +``` + +### Task 9: Sync vs Async Parity Tests + +**Files:** +- Modify: `tests/test_pipeline/test_sync_orchestrator.py` + +- [ ] **Step 1: Write parity tests** + +Append to `tests/test_pipeline/test_sync_orchestrator.py`: + +```python +class TestSyncAsyncParity: + """Sync orchestrator should produce same DB results as async.""" + + def test_linear_pipeline_parity(self): + src = _make_source("key", "value", {"key": ["a", "b"], "value": [10, 20]}) + pf = PythonPacketFunction(double_value, output_keys="result") + pod = FunctionPod(pf) + + # Sync via orchestrator + sync_pipeline = Pipeline( + name="sync", pipeline_database=InMemoryArrowDatabase() + ) + with sync_pipeline: + pod(src, label="doubler") + sync_pipeline.run() + sync_records = sync_pipeline.doubler.get_all_records() + sync_values = sorted(sync_records.column("result").to_pylist()) + + # Async + from orcapod.pipeline import AsyncPipelineOrchestrator + + async_pipeline = Pipeline( + name="async", pipeline_database=InMemoryArrowDatabase() + ) + with async_pipeline: + pod(src, label="doubler") + AsyncPipelineOrchestrator().run(async_pipeline) + async_records = async_pipeline.doubler.get_all_records() + async_values = sorted(async_records.column("result").to_pylist()) + + assert sync_values == async_values + + def test_diamond_pipeline_parity(self): + src_a = _make_source("key", "value", {"key": ["a", "b"], "value": [10, 20]}) + src_b = _make_source("key", "score", {"key": ["a", "b"], "score": [100, 200]}) + pf = PythonPacketFunction(add_values, output_keys="total") + pod = FunctionPod(pf) + + sync_pipeline = Pipeline( + name="sync_d", pipeline_database=InMemoryArrowDatabase() + ) + with sync_pipeline: + joined = Join()(src_a, src_b, label="join") + pod(joined, label="adder") + sync_pipeline.run() + sync_values = sorted( + sync_pipeline.adder.get_all_records().column("total").to_pylist() + ) + + from orcapod.pipeline import AsyncPipelineOrchestrator + + async_pipeline = Pipeline( + name="async_d", pipeline_database=InMemoryArrowDatabase() + ) + with async_pipeline: + joined = Join()(src_a, src_b, label="join") + pod(joined, label="adder") + AsyncPipelineOrchestrator().run(async_pipeline) + async_values = sorted( + async_pipeline.adder.get_all_records().column("total").to_pylist() + ) + + assert sync_values == async_values +``` + +- [ ] **Step 2: Run parity tests** + +Run: `uv run pytest tests/test_pipeline/test_sync_orchestrator.py::TestSyncAsyncParity -v` +Expected: PASS + +- [ ] **Step 3: Final full test suite run** + +Run: `uv run pytest tests/ -v` +Expected: ALL PASS + +- [ ] **Step 4: Commit** + +```bash +git add tests/test_pipeline/test_sync_orchestrator.py +git commit -m "test(pipeline): add sync vs async parity tests" +``` diff --git a/superpowers/plans/2026-03-15-async-orchestrator-refactor-plan.md b/superpowers/plans/2026-03-15-async-orchestrator-refactor-plan.md new file mode 100644 index 00000000..2414ca10 --- /dev/null +++ b/superpowers/plans/2026-03-15-async-orchestrator-refactor-plan.md @@ -0,0 +1,1717 @@ +# Async Orchestrator Refactor Implementation Plan + +> **For agentic workers:** REQUIRED: Use superpowers:subagent-driven-development (if subagents available) or superpowers:executing-plans to implement this plan. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Refactor both orchestrators to use slim node protocols where nodes own their execution and orchestrators are pure topology schedulers. + +**Architecture:** Node protocols slim to `execute()` + `async_execute()` with observer injection. Orchestrators call these methods and collect results. Per-packet logic (cache lookup, observer hooks) moves inside nodes. AsyncPipelineOrchestrator adopts the same `run(graph) -> OrchestratorResult` interface as the sync orchestrator. + +**Tech Stack:** Python, asyncio, networkx, pyarrow, pytest, pytest-asyncio + +**Spec:** `superpowers/specs/2026-03-15-async-orchestrator-refactor-design.md` + +--- + +## Chunk 1: Protocol Changes and SourceNode + +### Task 1: Slim down node protocols + +**Files:** +- Modify: `src/orcapod/protocols/node_protocols.py` + +- [ ] **Step 1: Write failing test — protocols have new shape** + +Add a test that imports the new protocol shapes and verifies TypeGuard dispatch still works. + +```python +# tests/test_pipeline/test_node_protocols.py (new file) +"""Tests for revised node protocols.""" + +from __future__ import annotations + +import pytest +from unittest.mock import MagicMock, AsyncMock + +from orcapod.protocols.node_protocols import ( + SourceNodeProtocol, + FunctionNodeProtocol, + OperatorNodeProtocol, + is_source_node, + is_function_node, + is_operator_node, +) + + +class TestSourceNodeProtocol: + def test_requires_execute(self): + """SourceNodeProtocol requires execute method.""" + + class GoodSource: + node_type = "source" + + def execute(self, *, observer=None): + return [] + + async def async_execute(self, output, *, observer=None): + pass + + assert isinstance(GoodSource(), SourceNodeProtocol) + + def test_rejects_old_iter_packets_only(self): + """SourceNodeProtocol no longer accepts iter_packets alone.""" + + class OldSource: + node_type = "source" + + def iter_packets(self): + return iter([]) + + assert not isinstance(OldSource(), SourceNodeProtocol) + + +class TestFunctionNodeProtocol: + def test_requires_execute_and_async_execute(self): + class GoodFunction: + node_type = "function" + + def execute(self, input_stream, *, observer=None): + return [] + + async def async_execute(self, input_channel, output, *, observer=None): + pass + + assert isinstance(GoodFunction(), FunctionNodeProtocol) + + def test_rejects_old_protocol(self): + """Old protocol with get_cached_results etc. is not sufficient.""" + + class OldFunction: + node_type = "function" + + def get_cached_results(self, entry_ids): + return {} + + def compute_pipeline_entry_id(self, tag, packet): + return "" + + def execute_packet(self, tag, packet): + return (tag, None) + + def execute(self, input_stream): + return [] + + # Missing async_execute → not a valid FunctionNodeProtocol + assert not isinstance(OldFunction(), FunctionNodeProtocol) + + +class TestOperatorNodeProtocol: + def test_requires_execute_and_async_execute(self): + class GoodOperator: + node_type = "operator" + + def execute(self, *input_streams, observer=None): + return [] + + async def async_execute(self, inputs, output, *, observer=None): + pass + + assert isinstance(GoodOperator(), OperatorNodeProtocol) + + def test_rejects_old_protocol(self): + """Old protocol with get_cached_output is not sufficient.""" + + class OldOperator: + node_type = "operator" + + def execute(self, *input_streams): + return [] + + def get_cached_output(self): + return None + + # Missing async_execute → not valid + assert not isinstance(OldOperator(), OperatorNodeProtocol) + + +class TestTypeGuardDispatch: + def test_dispatch_source(self): + node = MagicMock() + node.node_type = "source" + assert is_source_node(node) + assert not is_function_node(node) + assert not is_operator_node(node) + + def test_dispatch_function(self): + node = MagicMock() + node.node_type = "function" + assert is_function_node(node) + + def test_dispatch_operator(self): + node = MagicMock() + node.node_type = "operator" + assert is_operator_node(node) +``` + +Create `tests/test_pipeline/test_node_protocols.py` with the above content. + +- [ ] **Step 2: Run test to verify it fails** + +Run: `uv run pytest tests/test_pipeline/test_node_protocols.py -v` +Expected: FAIL — protocol shapes don't match yet. + +- [ ] **Step 3: Update node protocols** + +Replace the contents of `src/orcapod/protocols/node_protocols.py`: + +```python +"""Node protocols for orchestrator interaction. + +Defines the three node protocols (Source, Function, Operator) that +formalize the interface between orchestrators and graph nodes, plus +TypeGuard dispatch functions for runtime type narrowing. + +Each protocol exposes ``execute`` (sync) and ``async_execute`` (async). +Nodes own their execution — caching, per-packet logic, and persistence +are internal. Orchestrators are topology schedulers. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING, Protocol, TypeGuard, runtime_checkable + +if TYPE_CHECKING: + from orcapod.channels import ReadableChannel, WritableChannel + from orcapod.core.nodes import GraphNode + from orcapod.pipeline.observer import ExecutionObserver + from orcapod.protocols.core_protocols import ( + PacketProtocol, + StreamProtocol, + TagProtocol, + ) + + +@runtime_checkable +class SourceNodeProtocol(Protocol): + """Protocol for source nodes in orchestrated execution.""" + + node_type: str + + def execute( + self, + *, + observer: "ExecutionObserver | None" = None, + ) -> list[tuple["TagProtocol", "PacketProtocol"]]: ... + + async def async_execute( + self, + output: "WritableChannel[tuple[TagProtocol, PacketProtocol]]", + *, + observer: "ExecutionObserver | None" = None, + ) -> None: ... + + +@runtime_checkable +class FunctionNodeProtocol(Protocol): + """Protocol for function nodes in orchestrated execution.""" + + node_type: str + + def execute( + self, + input_stream: "StreamProtocol", + *, + observer: "ExecutionObserver | None" = None, + ) -> list[tuple["TagProtocol", "PacketProtocol"]]: ... + + async def async_execute( + self, + input_channel: "ReadableChannel[tuple[TagProtocol, PacketProtocol]]", + output: "WritableChannel[tuple[TagProtocol, PacketProtocol]]", + *, + observer: "ExecutionObserver | None" = None, + ) -> None: ... + + +@runtime_checkable +class OperatorNodeProtocol(Protocol): + """Protocol for operator nodes in orchestrated execution.""" + + node_type: str + + def execute( + self, + *input_streams: "StreamProtocol", + observer: "ExecutionObserver | None" = None, + ) -> list[tuple["TagProtocol", "PacketProtocol"]]: ... + + async def async_execute( + self, + inputs: "Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]]", + output: "WritableChannel[tuple[TagProtocol, PacketProtocol]]", + *, + observer: "ExecutionObserver | None" = None, + ) -> None: ... + + +def is_source_node(node: "GraphNode") -> TypeGuard[SourceNodeProtocol]: + """Check if a node is a source node.""" + return node.node_type == "source" + + +def is_function_node(node: "GraphNode") -> TypeGuard[FunctionNodeProtocol]: + """Check if a node is a function node.""" + return node.node_type == "function" + + +def is_operator_node(node: "GraphNode") -> TypeGuard[OperatorNodeProtocol]: + """Check if a node is an operator node.""" + return node.node_type == "operator" +``` + +- [ ] **Step 4: Run protocol tests to verify they pass** + +Run: `uv run pytest tests/test_pipeline/test_node_protocols.py -v` +Expected: PASS + +- [ ] **Step 5: Commit** + +```bash +git add tests/test_pipeline/test_node_protocols.py src/orcapod/protocols/node_protocols.py +git commit -m "refactor(protocols): slim node protocols to execute + async_execute with observer (PLT-922)" +``` + +### Task 2: Delete AsyncExecutableProtocol + +**Files:** +- Delete: `src/orcapod/protocols/core_protocols/async_executable.py` +- Modify: `src/orcapod/protocols/core_protocols/__init__.py` + +- [ ] **Step 1: Remove AsyncExecutableProtocol import and re-export** + +In `src/orcapod/protocols/core_protocols/__init__.py`, remove line 4 +(`from .async_executable import AsyncExecutableProtocol`) and remove +`"AsyncExecutableProtocol"` from `__all__`. + +- [ ] **Step 2: Delete the file** + +```bash +rm src/orcapod/protocols/core_protocols/async_executable.py +``` + +- [ ] **Step 3: Check for other imports of AsyncExecutableProtocol** + +Run: `uv run grep -r "AsyncExecutableProtocol" src/ tests/` + +If any imports remain, remove them. This protocol was defined but not used +by any caller. + +- [ ] **Step 4: Run full test suite to verify nothing breaks** + +Run: `uv run pytest tests/ -x -q` +Expected: All tests pass. + +- [ ] **Step 5: Commit** + +```bash +git add -u +git commit -m "refactor(protocols): remove AsyncExecutableProtocol (PLT-922)" +``` + +### Task 3: Add SourceNode.execute() with observer injection + +**Files:** +- Modify: `src/orcapod/core/nodes/source_node.py:228-255` +- Test: `tests/test_pipeline/test_node_protocols.py` (extend) + +- [ ] **Step 1: Write failing test for SourceNode.execute()** + +Append to `tests/test_pipeline/test_node_protocols.py`: + +```python +import pyarrow as pa +from orcapod.core.sources import ArrowTableSource +from orcapod.core.nodes import SourceNode + + +class TestSourceNodeExecute: + def _make_source_node(self): + table = pa.table({ + "key": pa.array(["a", "b", "c"], type=pa.large_string()), + "value": pa.array([1, 2, 3], type=pa.int64()), + }) + src = ArrowTableSource(table, tag_columns=["key"]) + return SourceNode(src) + + def test_execute_returns_list(self): + node = self._make_source_node() + result = node.execute() + assert isinstance(result, list) + assert len(result) == 3 + + def test_execute_populates_cached_results(self): + node = self._make_source_node() + node.execute() + assert node._cached_results is not None + assert len(node._cached_results) == 3 + + def test_execute_with_observer(self): + node = self._make_source_node() + events = [] + + class Obs: + def on_node_start(self, n): + events.append(("start", n.node_type)) + def on_node_end(self, n): + events.append(("end", n.node_type)) + def on_packet_start(self, n, t, p): + pass + def on_packet_end(self, n, t, ip, op, cached): + pass + + node.execute(observer=Obs()) + assert events == [("start", "source"), ("end", "source")] + + def test_execute_without_observer(self): + """execute() works fine with no observer.""" + node = self._make_source_node() + result = node.execute() + assert len(result) == 3 +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `uv run pytest tests/test_pipeline/test_node_protocols.py::TestSourceNodeExecute -v` +Expected: FAIL — `execute()` method doesn't exist yet. + +- [ ] **Step 3: Implement SourceNode.execute()** + +Add to `src/orcapod/core/nodes/source_node.py`, before the `run()` method +(around line 237): + +```python +def execute( + self, + *, + observer: Any = None, +) -> list[tuple[cp.TagProtocol, cp.PacketProtocol]]: + """Execute this source: materialize packets and return. + + Args: + observer: Optional execution observer for hooks. + + Returns: + List of (tag, packet) tuples. + """ + if self.stream is None: + raise RuntimeError( + "SourceNode in read-only mode has no stream data available" + ) + if observer is not None: + observer.on_node_start(self) + result = list(self.stream.iter_packets()) + self._cached_results = result + if observer is not None: + observer.on_node_end(self) + return result +``` + +- [ ] **Step 4: Run test to verify it passes** + +Run: `uv run pytest tests/test_pipeline/test_node_protocols.py::TestSourceNodeExecute -v` +Expected: PASS + +- [ ] **Step 5: Commit** + +```bash +git add src/orcapod/core/nodes/source_node.py tests/test_pipeline/test_node_protocols.py +git commit -m "feat(source-node): add execute() with observer injection (PLT-922)" +``` + +### Task 4: Tighten SourceNode.async_execute() signature + observer + +**Files:** +- Modify: `src/orcapod/core/nodes/source_node.py:240-254` +- Test: `tests/test_pipeline/test_node_protocols.py` (extend) + +- [ ] **Step 1: Write failing test for tightened async_execute** + +Append to `tests/test_pipeline/test_node_protocols.py`: + +```python +import pytest +from orcapod.channels import Channel + + +class TestSourceNodeAsyncExecuteProtocol: + @pytest.mark.asyncio + async def test_tightened_signature(self): + """async_execute takes output only, no inputs.""" + table = pa.table({ + "key": pa.array(["a", "b"], type=pa.large_string()), + "value": pa.array([1, 2], type=pa.int64()), + }) + src = ArrowTableSource(table, tag_columns=["key"]) + node = SourceNode(src) + + output_ch = Channel(buffer_size=16) + # New signature: just output + observer + await node.async_execute(output_ch.writer, observer=None) + rows = await output_ch.reader.collect() + assert len(rows) == 2 + + @pytest.mark.asyncio + async def test_async_execute_with_observer(self): + table = pa.table({ + "key": pa.array(["a"], type=pa.large_string()), + "value": pa.array([1], type=pa.int64()), + }) + src = ArrowTableSource(table, tag_columns=["key"]) + node = SourceNode(src) + events = [] + + class Obs: + def on_node_start(self, n): + events.append("start") + def on_node_end(self, n): + events.append("end") + def on_packet_start(self, n, t, p): + pass + def on_packet_end(self, n, t, ip, op, cached): + pass + + output_ch = Channel(buffer_size=16) + await node.async_execute(output_ch.writer, observer=Obs()) + assert events == ["start", "end"] +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `uv run pytest tests/test_pipeline/test_node_protocols.py::TestSourceNodeAsyncExecuteProtocol -v` +Expected: FAIL — old signature takes `inputs, output`. + +- [ ] **Step 3: Update SourceNode.async_execute()** + +Replace the `async_execute` method in `src/orcapod/core/nodes/source_node.py` +(around line 240): + +```python +async def async_execute( + self, + output: WritableChannel[tuple[cp.TagProtocol, cp.PacketProtocol]], + *, + observer: Any = None, +) -> None: + """Push all (tag, packet) pairs from the wrapped stream to the output channel. + + Args: + output: Channel to write results to. + observer: Optional execution observer for hooks. + """ + if self.stream is None: + raise RuntimeError( + "SourceNode in read-only mode has no stream data available" + ) + try: + if observer is not None: + observer.on_node_start(self) + for tag, packet in self.stream.iter_packets(): + await output.send((tag, packet)) + if observer is not None: + observer.on_node_end(self) + finally: + await output.close() +``` + +Also remove `Sequence` from the imports since it's no longer needed for the +signature (keep `Iterator`). + +- [ ] **Step 4: Run tests** + +Run: `uv run pytest tests/test_pipeline/test_node_protocols.py -v` +Expected: All PASS. + +- [ ] **Step 5: Commit** + +```bash +git add src/orcapod/core/nodes/source_node.py tests/test_pipeline/test_node_protocols.py +git commit -m "refactor(source-node): tighten async_execute signature + observer (PLT-922)" +``` + +## Chunk 2: FunctionNode and OperatorNode Changes + +### Task 5: Add observer parameter to FunctionNode.execute() + +**Files:** +- Modify: `src/orcapod/core/nodes/function_node.py:488-512` +- Test: `tests/test_pipeline/test_node_protocols.py` (extend) + +- [ ] **Step 1: Write failing test** + +Append to `tests/test_pipeline/test_node_protocols.py`: + +```python +from orcapod.core.function_pod import FunctionPod +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.core.nodes import FunctionNode + + +def double_value(value: int) -> int: + return value * 2 + + +class TestFunctionNodeExecute: + def _make_function_node(self): + table = pa.table({ + "key": pa.array(["a", "b"], type=pa.large_string()), + "value": pa.array([1, 2], type=pa.int64()), + }) + src = ArrowTableSource(table, tag_columns=["key"]) + pf = PythonPacketFunction(double_value, output_keys="result") + pod = FunctionPod(pf) + return FunctionNode(pod, src) + + def test_execute_with_observer(self): + node = self._make_function_node() + events = [] + + class Obs: + def on_node_start(self, n): + events.append(("node_start", n.node_type)) + def on_node_end(self, n): + events.append(("node_end", n.node_type)) + def on_packet_start(self, n, t, p): + events.append(("packet_start",)) + def on_packet_end(self, n, t, ip, op, cached): + events.append(("packet_end", cached)) + + input_stream = node._input_stream + result = node.execute(input_stream, observer=Obs()) + + assert len(result) == 2 + assert events[0] == ("node_start", "function") + assert events[-1] == ("node_end", "function") + # Should have packet_start/packet_end for each packet + packet_events = [e for e in events if e[0].startswith("packet")] + assert len(packet_events) == 4 # 2 start + 2 end + + def test_execute_without_observer(self): + node = self._make_function_node() + input_stream = node._input_stream + result = node.execute(input_stream) + assert len(result) == 2 + values = sorted([pkt.as_dict()["result"] for _, pkt in result]) + assert values == [2, 4] +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `uv run pytest tests/test_pipeline/test_node_protocols.py::TestFunctionNodeExecute -v` +Expected: FAIL — `execute()` doesn't accept `observer` keyword. + +- [ ] **Step 3: Update FunctionNode.execute()** + +In `src/orcapod/core/nodes/function_node.py`, modify the `execute` method +(line 488) to add observer parameter and internal hooks. The method should: + +1. Accept `*, observer=None` keyword parameter +2. Call `observer.on_node_start(self)` at the start +3. For each packet: compute entry ID, check cache, call + `observer.on_packet_start` / `on_packet_end(cached=...)` around execution +4. Call `observer.on_node_end(self)` at the end + +```python +def execute( + self, + input_stream: StreamProtocol, + *, + observer: Any = None, +) -> list[tuple[TagProtocol, PacketProtocol]]: + """Execute all packets from a stream: compute, persist, and cache. + + Args: + input_stream: The input stream to process. + observer: Optional execution observer for hooks. + + Returns: + Materialized list of (tag, output_packet) pairs, excluding + None outputs. + """ + if observer is not None: + observer.on_node_start(self) + + # Gather entry IDs and check cache + upstream_entries = [ + (tag, packet, self.compute_pipeline_entry_id(tag, packet)) + for tag, packet in input_stream.iter_packets() + ] + entry_ids = [eid for _, _, eid in upstream_entries] + cached = self.get_cached_results(entry_ids=entry_ids) + + output: list[tuple[TagProtocol, PacketProtocol]] = [] + for tag, packet, entry_id in upstream_entries: + if observer is not None: + observer.on_packet_start(self, tag, packet) + + if entry_id in cached: + tag_out, result = cached[entry_id] + if observer is not None: + observer.on_packet_end(self, tag, packet, result, cached=True) + output.append((tag_out, result)) + else: + tag_out, result = self._process_packet_internal(tag, packet) + if observer is not None: + observer.on_packet_end(self, tag, packet, result, cached=False) + if result is not None: + output.append((tag_out, result)) + + if observer is not None: + observer.on_node_end(self) + return output +``` + +- [ ] **Step 4: Run test to verify it passes** + +Run: `uv run pytest tests/test_pipeline/test_node_protocols.py::TestFunctionNodeExecute -v` +Expected: PASS + +- [ ] **Step 5: Commit** + +```bash +git add src/orcapod/core/nodes/function_node.py tests/test_pipeline/test_node_protocols.py +git commit -m "feat(function-node): add observer injection to execute() (PLT-922)" +``` + +### Task 6: Tighten FunctionNode.async_execute() signature + observer + +**Files:** +- Modify: `src/orcapod/core/nodes/function_node.py:1142-1263` +- Test: `tests/test_pipeline/test_node_protocols.py` (extend) + +- [ ] **Step 1: Write failing test** + +Append to `tests/test_pipeline/test_node_protocols.py`: + +```python +class TestFunctionNodeAsyncExecute: + @pytest.mark.asyncio + async def test_tightened_signature(self): + """async_execute takes single input_channel, not Sequence.""" + table = pa.table({ + "key": pa.array(["a", "b"], type=pa.large_string()), + "value": pa.array([1, 2], type=pa.int64()), + }) + src = ArrowTableSource(table, tag_columns=["key"]) + pf = PythonPacketFunction(double_value, output_keys="result") + pod = FunctionPod(pf) + node = FunctionNode(pod, src) + + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + for tag, packet in src.iter_packets(): + await input_ch.writer.send((tag, packet)) + await input_ch.writer.close() + + # New signature: single input_channel, not list + await node.async_execute(input_ch.reader, output_ch.writer) + rows = await output_ch.reader.collect() + assert len(rows) == 2 + values = sorted([pkt.as_dict()["result"] for _, pkt in rows]) + assert values == [2, 4] + + @pytest.mark.asyncio + async def test_async_execute_with_observer(self): + table = pa.table({ + "key": pa.array(["a"], type=pa.large_string()), + "value": pa.array([1], type=pa.int64()), + }) + src = ArrowTableSource(table, tag_columns=["key"]) + pf = PythonPacketFunction(double_value, output_keys="result") + pod = FunctionPod(pf) + node = FunctionNode(pod, src) + + events = [] + class Obs: + def on_node_start(self, n): events.append("node_start") + def on_node_end(self, n): events.append("node_end") + def on_packet_start(self, n, t, p): events.append("pkt_start") + def on_packet_end(self, n, t, ip, op, cached): events.append("pkt_end") + + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + for tag, packet in src.iter_packets(): + await input_ch.writer.send((tag, packet)) + await input_ch.writer.close() + + await node.async_execute(input_ch.reader, output_ch.writer, observer=Obs()) + assert "node_start" in events + assert "node_end" in events +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `uv run pytest tests/test_pipeline/test_node_protocols.py::TestFunctionNodeAsyncExecute -v` +Expected: FAIL — old signature takes `inputs` (Sequence) and `pipeline_config`. + +- [ ] **Step 3: Update FunctionNode.async_execute()** + +Replace the `async_execute` method in `src/orcapod/core/nodes/function_node.py` +(line 1142). Key changes: +- First positional arg: `input_channel: ReadableChannel[...]` (not `inputs: Sequence[...]`) +- Remove `pipeline_config` parameter +- Add `*, observer=None` keyword +- Replace all `inputs[0]` references with `input_channel` +- Use hardcoded default concurrency (defer to PLT-930) +- Add observer hooks: `on_node_start`/`on_node_end` around the whole method, + `on_packet_start`/`on_packet_end(cached=...)` around each packet in Phase 2 + +```python +async def async_execute( + self, + input_channel: ReadableChannel[tuple[TagProtocol, PacketProtocol]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + *, + observer: Any = None, +) -> None: + """Streaming async execution for FunctionNode. + + When a database is attached, uses two-phase execution: replay cached + results first, then compute missing packets concurrently. Otherwise, + routes each packet through ``_async_process_packet_internal`` directly. + + Args: + input_channel: Single input channel to read from. + output: Output channel to write results to. + observer: Optional execution observer for hooks. + """ + try: + if observer is not None: + observer.on_node_start(self) + + if self._cached_function_pod is not None: + # Two-phase async execution with DB backing + PIPELINE_ENTRY_ID_COL = "__pipeline_entry_id" + existing_entry_ids: set[str] = set() + + taginfo = self._pipeline_database.get_all_records( + self.pipeline_path, + record_id_column=PIPELINE_ENTRY_ID_COL, + ) + results = self._cached_function_pod._result_database.get_all_records( + self._cached_function_pod.record_path, + record_id_column=constants.PACKET_RECORD_ID, + ) + + if taginfo is not None and results is not None: + joined = ( + pl.DataFrame(taginfo) + .join( + pl.DataFrame(results), + on=constants.PACKET_RECORD_ID, + how="inner", + ) + .to_arrow() + ) + if joined.num_rows > 0: + tag_keys = self._input_stream.keys()[0] + existing_entry_ids = set( + cast( + list[str], + joined.column(PIPELINE_ENTRY_ID_COL).to_pylist(), + ) + ) + drop_cols = [ + c + for c in joined.column_names + if c.startswith(constants.META_PREFIX) + or c == PIPELINE_ENTRY_ID_COL + ] + data_table = joined.drop( + [c for c in drop_cols if c in joined.column_names] + ) + existing_stream = ArrowTableStream( + data_table, tag_columns=tag_keys + ) + for tag, packet in existing_stream.iter_packets(): + await output.send((tag, packet)) + + # Phase 2: process new packets concurrently + async def process_one_db( + tag: TagProtocol, packet: PacketProtocol + ) -> None: + try: + if observer is not None: + observer.on_packet_start(self, tag, packet) + ( + tag_out, + result_packet, + ) = await self._async_process_packet_internal(tag, packet) + if observer is not None: + observer.on_packet_end( + self, tag, packet, result_packet, cached=False + ) + if result_packet is not None: + await output.send((tag_out, result_packet)) + finally: + pass + + async with asyncio.TaskGroup() as tg: + async for tag, packet in input_channel: + entry_id = self.compute_pipeline_entry_id(tag, packet) + if entry_id in existing_entry_ids: + if observer is not None: + observer.on_packet_start(self, tag, packet) + observer.on_packet_end( + self, tag, packet, None, cached=True + ) + continue + tg.create_task(process_one_db(tag, packet)) + else: + # Simple async execution without DB + async def process_one( + tag: TagProtocol, packet: PacketProtocol + ) -> None: + if observer is not None: + observer.on_packet_start(self, tag, packet) + ( + tag_out, + result_packet, + ) = await self._async_process_packet_internal(tag, packet) + if observer is not None: + observer.on_packet_end( + self, tag, packet, result_packet, cached=False + ) + if result_packet is not None: + await output.send((tag_out, result_packet)) + + async with asyncio.TaskGroup() as tg: + async for tag, packet in input_channel: + tg.create_task(process_one(tag, packet)) + + if observer is not None: + observer.on_node_end(self) + finally: + await output.close() +``` + +Note: Concurrency limiting (semaphore) is removed for now. PLT-930 will +re-add it as node-level config. + +- [ ] **Step 4: Run new tests** + +Run: `uv run pytest tests/test_pipeline/test_node_protocols.py::TestFunctionNodeAsyncExecute -v` +Expected: PASS + +- [ ] **Step 5: Commit** + +```bash +git add src/orcapod/core/nodes/function_node.py tests/test_pipeline/test_node_protocols.py +git commit -m "refactor(function-node): tighten async_execute signature + observer (PLT-922)" +``` + +### Task 7: Add observer to OperatorNode.execute() + cache check + +**Files:** +- Modify: `src/orcapod/core/nodes/operator_node.py:432-473` +- Test: `tests/test_pipeline/test_node_protocols.py` (extend) + +- [ ] **Step 1: Write failing test** + +Append to `tests/test_pipeline/test_node_protocols.py`: + +```python +from orcapod.core.nodes import OperatorNode +from orcapod.core.operators.join import Join + + +class TestOperatorNodeExecute: + def _make_join_node(self): + table_a = pa.table({ + "key": pa.array(["a", "b"], type=pa.large_string()), + "value": pa.array([10, 20], type=pa.int64()), + }) + table_b = pa.table({ + "key": pa.array(["a", "b"], type=pa.large_string()), + "score": pa.array([100, 200], type=pa.int64()), + }) + src_a = ArrowTableSource(table_a, tag_columns=["key"]) + src_b = ArrowTableSource(table_b, tag_columns=["key"]) + return OperatorNode(Join(), input_streams=[src_a, src_b]) + + def test_execute_with_observer(self): + node = self._make_join_node() + events = [] + + class Obs: + def on_node_start(self, n): + events.append(("node_start", n.node_type)) + def on_node_end(self, n): + events.append(("node_end", n.node_type)) + def on_packet_start(self, n, t, p): + pass + def on_packet_end(self, n, t, ip, op, cached): + pass + + result = node.execute( + *node._input_streams, observer=Obs() + ) + assert len(result) == 2 + assert events == [("node_start", "operator"), ("node_end", "operator")] + + def test_execute_without_observer(self): + node = self._make_join_node() + result = node.execute(*node._input_streams) + assert len(result) == 2 +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `uv run pytest tests/test_pipeline/test_node_protocols.py::TestOperatorNodeExecute -v` +Expected: FAIL — `execute()` doesn't accept `observer`. + +- [ ] **Step 3: Update OperatorNode.execute()** + +Modify `src/orcapod/core/nodes/operator_node.py` `execute` method (line 432): + +```python +def execute( + self, + *input_streams: StreamProtocol, + observer: Any = None, +) -> list[tuple[TagProtocol, PacketProtocol]]: + """Execute input streams: compute, persist, and cache. + + Args: + *input_streams: Input streams to execute. + observer: Optional execution observer for hooks. + + Returns: + Materialized list of (tag, packet) pairs. + """ + if observer is not None: + observer.on_node_start(self) + + # Check REPLAY cache first + cached_output = self.get_cached_output() + if cached_output is not None: + output = list(cached_output.iter_packets()) + if observer is not None: + observer.on_node_end(self) + return output + + # Compute + result_stream = self._operator.process(*input_streams) + + # Materialize + output = list(result_stream.iter_packets()) + + # Cache + if output: + self._cached_output_stream = StaticOutputOperatorPod._materialize_to_stream( + output + ) + else: + self._cached_output_stream = result_stream + + self._update_modified_time() + + # Persist to DB only in LOG mode + if ( + self._pipeline_database is not None + and self._cache_mode == CacheMode.LOG + and self._cached_output_stream is not None + ): + self._store_output_stream(self._cached_output_stream) + + if observer is not None: + observer.on_node_end(self) + return output +``` + +- [ ] **Step 4: Run test to verify it passes** + +Run: `uv run pytest tests/test_pipeline/test_node_protocols.py::TestOperatorNodeExecute -v` +Expected: PASS + +- [ ] **Step 5: Commit** + +```bash +git add src/orcapod/core/nodes/operator_node.py tests/test_pipeline/test_node_protocols.py +git commit -m "feat(operator-node): add observer + cache check to execute() (PLT-922)" +``` + +### Task 8: Add observer to OperatorNode.async_execute() + +**Files:** +- Modify: `src/orcapod/core/nodes/operator_node.py:627-688` +- Test: `tests/test_pipeline/test_node_protocols.py` (extend) + +- [ ] **Step 1: Write failing test** + +Append to `tests/test_pipeline/test_node_protocols.py`: + +```python +class TestOperatorNodeAsyncExecute: + @pytest.mark.asyncio + async def test_async_execute_with_observer(self): + table_a = pa.table({ + "key": pa.array(["a", "b"], type=pa.large_string()), + "value": pa.array([10, 20], type=pa.int64()), + }) + src_a = ArrowTableSource(table_a, tag_columns=["key"]) + from orcapod.core.operators import SelectPacketColumns + op = SelectPacketColumns(columns=["value"]) + op_node = OperatorNode(op, input_streams=[src_a]) + + events = [] + class Obs: + def on_node_start(self, n): events.append("start") + def on_node_end(self, n): events.append("end") + def on_packet_start(self, n, t, p): pass + def on_packet_end(self, n, t, ip, op, cached): pass + + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + for tag, packet in src_a.iter_packets(): + await input_ch.writer.send((tag, packet)) + await input_ch.writer.close() + + await op_node.async_execute([input_ch.reader], output_ch.writer, observer=Obs()) + assert "start" in events + assert "end" in events +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `uv run pytest tests/test_pipeline/test_node_protocols.py::TestOperatorNodeAsyncExecute -v` +Expected: FAIL — `async_execute()` doesn't accept `observer`. + +- [ ] **Step 3: Update OperatorNode.async_execute()** + +Modify `src/orcapod/core/nodes/operator_node.py` `async_execute` (line 627) +to add `*, observer=None` keyword and call hooks: + +```python +async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + *, + observer: Any = None, +) -> None: +``` + +Add `if observer: observer.on_node_start(self)` near the top of the try block, +and `if observer: observer.on_node_end(self)` before the `finally`. + +- [ ] **Step 4: Run tests** + +Run: `uv run pytest tests/test_pipeline/test_node_protocols.py::TestOperatorNodeAsyncExecute -v` +Expected: PASS + +- [ ] **Step 5: Commit** + +```bash +git add src/orcapod/core/nodes/operator_node.py tests/test_pipeline/test_node_protocols.py +git commit -m "feat(operator-node): add observer to async_execute() (PLT-922)" +``` + +## Chunk 3: Orchestrator Refactoring + +### Task 9: Simplify SyncPipelineOrchestrator + +**Files:** +- Modify: `src/orcapod/pipeline/sync_orchestrator.py` +- Modify: `tests/test_pipeline/test_sync_orchestrator.py` + +- [ ] **Step 1: Update SyncPipelineOrchestrator to use node.execute()** + +Rewrite `src/orcapod/pipeline/sync_orchestrator.py`. The `run()` method +calls `node.execute(...)` directly. Remove `_execute_source`, +`_execute_function`, `_execute_operator`. Keep `_materialize_as_stream`, +`_gather_upstream`, `_gather_upstream_multi`, `_gc_buffers`. + +```python +def run( + self, + graph: "nx.DiGraph", + materialize_results: bool = True, +) -> OrchestratorResult: + """Execute the node graph synchronously. + + Args: + graph: A NetworkX DiGraph with GraphNode objects as vertices. + materialize_results: If True, keep all node outputs in memory. + If False, discard buffers after downstream consumption. + + Returns: + OrchestratorResult with node outputs. + """ + import networkx as nx + + topo_order = list(nx.topological_sort(graph)) + buffers: dict[Any, list[tuple[TagProtocol, PacketProtocol]]] = {} + processed: set[Any] = set() + + for node in topo_order: + if is_source_node(node): + buffers[node] = node.execute(observer=self._observer) + elif is_function_node(node): + upstream_buf = self._gather_upstream(node, graph, buffers) + upstream_node = list(graph.predecessors(node))[0] + input_stream = self._materialize_as_stream(upstream_buf, upstream_node) + buffers[node] = node.execute(input_stream, observer=self._observer) + elif is_operator_node(node): + upstream_buffers = self._gather_upstream_multi(node, graph, buffers) + input_streams = [ + self._materialize_as_stream(buf, upstream_node) + for buf, upstream_node in upstream_buffers + ] + buffers[node] = node.execute(*input_streams, observer=self._observer) + else: + raise TypeError( + f"Unknown node type: {getattr(node, 'node_type', None)!r}" + ) + + processed.add(node) + + if not materialize_results: + self._gc_buffers(node, graph, buffers, processed) + + return OrchestratorResult(node_outputs=buffers) +``` + +- [ ] **Step 2: Run existing sync orchestrator tests** + +Run: `uv run pytest tests/test_pipeline/test_sync_orchestrator.py -v` +Expected: All PASS — the simplified orchestrator should produce the same results. + +- [ ] **Step 3: Commit** + +```bash +git add src/orcapod/pipeline/sync_orchestrator.py +git commit -m "refactor(sync-orchestrator): delegate to node.execute(), remove per-packet logic (PLT-922)" +``` + +### Task 10: Refactor AsyncPipelineOrchestrator + +**Files:** +- Modify: `src/orcapod/pipeline/async_orchestrator.py` +- Modify: `tests/test_pipeline/test_orchestrator.py` + +- [ ] **Step 1: Rewrite AsyncPipelineOrchestrator** + +Replace the contents of `src/orcapod/pipeline/async_orchestrator.py`: + +```python +"""Async pipeline orchestrator for push-based channel execution. + +Walks a compiled pipeline's node graph and launches all nodes concurrently +via ``asyncio.TaskGroup``, wiring them together with bounded channels. +Uses TypeGuard dispatch with tightened per-type async_execute signatures. +""" + +from __future__ import annotations + +import asyncio +import logging +from collections import defaultdict +from typing import TYPE_CHECKING, Any + +from orcapod.channels import BroadcastChannel, Channel +from orcapod.pipeline.observer import NoOpObserver +from orcapod.pipeline.result import OrchestratorResult +from orcapod.protocols.node_protocols import ( + is_function_node, + is_operator_node, + is_source_node, +) + +if TYPE_CHECKING: + import networkx as nx + + from orcapod.pipeline.observer import ExecutionObserver + from orcapod.protocols.core_protocols import PacketProtocol, TagProtocol + +logger = logging.getLogger(__name__) + + +class AsyncPipelineOrchestrator: + """Execute a compiled pipeline asynchronously using channels. + + After compilation, the orchestrator: + + 1. Walks the node graph in topological order. + 2. Creates bounded channels (or broadcast channels for fan-out) + between connected nodes. + 3. Launches every node's ``async_execute`` concurrently via + ``asyncio.TaskGroup``, using TypeGuard dispatch for per-type + signatures. + + Args: + observer: Optional execution observer for hooks. + buffer_size: Channel buffer size. Defaults to 64. + """ + + def __init__( + self, + observer: "ExecutionObserver | None" = None, + buffer_size: int = 64, + ) -> None: + self._observer = observer + self._buffer_size = buffer_size + + def run( + self, + graph: "nx.DiGraph", + materialize_results: bool = True, + ) -> OrchestratorResult: + """Synchronous entry point — runs the async pipeline to completion. + + Args: + graph: A NetworkX DiGraph with GraphNode objects as vertices. + materialize_results: If True, collect all node outputs into + the result. If False, return empty node_outputs. + + Returns: + OrchestratorResult with node outputs. + """ + return asyncio.run(self._run_async(graph, materialize_results)) + + async def run_async( + self, + graph: "nx.DiGraph", + materialize_results: bool = True, + ) -> OrchestratorResult: + """Async entry point for callers already inside an event loop. + + Args: + graph: A NetworkX DiGraph with GraphNode objects as vertices. + materialize_results: If True, collect all node outputs. + + Returns: + OrchestratorResult with node outputs. + """ + return await self._run_async(graph, materialize_results) + + async def _run_async( + self, + graph: "nx.DiGraph", + materialize_results: bool, + ) -> OrchestratorResult: + """Core async logic: wire channels, launch tasks, collect results.""" + import networkx as nx + + topo_order = list(nx.topological_sort(graph)) + buf = self._buffer_size + + # Build edge maps + out_edges: dict[Any, list[Any]] = defaultdict(list) + in_edges: dict[Any, list[Any]] = defaultdict(list) + for upstream_node, downstream_node in graph.edges(): + out_edges[upstream_node].append(downstream_node) + in_edges[downstream_node].append(upstream_node) + + # Create channels for each edge + node_output_channels: dict[Any, Channel | BroadcastChannel] = {} + edge_readers: dict[tuple[Any, Any], Any] = {} + + for node, downstreams in out_edges.items(): + if len(downstreams) == 1: + ch = Channel(buffer_size=buf) + node_output_channels[node] = ch + edge_readers[(node, downstreams[0])] = ch.reader + else: + bch = BroadcastChannel(buffer_size=buf) + node_output_channels[node] = bch + for ds in downstreams: + edge_readers[(node, ds)] = bch.add_reader() + + # Terminal nodes need sink channels + terminal_channels: list[Channel] = [] + for node in topo_order: + if node not in node_output_channels: + ch = Channel(buffer_size=buf) + node_output_channels[node] = ch + terminal_channels.append(ch) + + # Result collection: tap each node's output + collectors: dict[Any, list[tuple[TagProtocol, PacketProtocol]]] = {} + if materialize_results: + for node in topo_order: + collectors[node] = [] + + # Launch all nodes concurrently + async with asyncio.TaskGroup() as tg: + for node in topo_order: + writer = node_output_channels[node].writer + + if materialize_results: + # Wrap writer to collect items + collector = collectors[node] + writer = _CollectingWriter(writer, collector) + + if is_source_node(node): + tg.create_task( + node.async_execute(writer, observer=self._observer) + ) + elif is_function_node(node): + input_reader = edge_readers[ + (list(in_edges[node])[0], node) + ] + tg.create_task( + node.async_execute( + input_reader, writer, observer=self._observer + ) + ) + elif is_operator_node(node): + input_readers = [ + edge_readers[(upstream, node)] + for upstream in in_edges.get(node, []) + ] + tg.create_task( + node.async_execute( + input_readers, writer, observer=self._observer + ) + ) + else: + raise TypeError( + f"Unknown node type: {getattr(node, 'node_type', None)!r}" + ) + + # Drain terminal channels + for ch in terminal_channels: + await ch.reader.collect() + + return OrchestratorResult( + node_outputs=collectors if materialize_results else {} + ) + + +class _CollectingWriter: + """Wrapper that collects items while forwarding to real writer.""" + + def __init__(self, writer: Any, collector: list) -> None: + self._writer = writer + self._collector = collector + + async def send(self, item: Any) -> None: + self._collector.append(item) + await self._writer.send(item) + + async def close(self) -> None: + await self._writer.close() +``` + +- [ ] **Step 2: Update async orchestrator tests** + +Update `tests/test_pipeline/test_orchestrator.py` with these mechanical changes +throughout the file: + +**Signature changes (find and replace):** +- `orchestrator.run(pipeline)` → `pipeline.compile(); orchestrator.run(pipeline._node_graph); pipeline.flush()` +- `orchestrator.run(pipeline, config=config)` → `pipeline.compile(); AsyncPipelineOrchestrator(buffer_size=config.channel_buffer_size).run(pipeline._node_graph); pipeline.flush()` +- `await orchestrator.run_async(pipeline)` → `pipeline.compile(); await orchestrator.run_async(pipeline._node_graph); pipeline.flush()` +- `await node.async_execute([], output_ch.writer)` → `await node.async_execute(output_ch.writer)` +- `await node.async_execute([input_ch.reader], output_ch.writer)` → `await node.async_execute(input_ch.reader, output_ch.writer)` + +**Affected test classes and specific changes:** + +`TestSourceNodeAsyncExecute`: Change `await node.async_execute([], output_ch.writer)` to +`await node.async_execute(output_ch.writer)` in both test methods. + +`TestFunctionNodeAsyncExecute`: Change `await node.async_execute([input_ch.reader], output_ch.writer)` +to `await node.async_execute(input_ch.reader, output_ch.writer)`. + +`TestOrchestratorLinearPipeline`: Both tests — add `pipeline.compile()` before +and `pipeline.flush()` after `orchestrator.run(pipeline._node_graph)`. + +`TestOrchestratorOperatorPipeline`: Same compile/run/flush pattern. + +`TestOrchestratorDiamondDag`: Both tests — same pattern. + +`TestOrchestratorRunAsync`: Change `await orchestrator.run_async(pipeline)` to +`pipeline.compile(); await orchestrator.run_async(pipeline._node_graph); pipeline.flush()`. + +`TestPipelineConfigIntegration`: Replace: +```python +config = PipelineConfig(executor=ExecutorType.ASYNC_CHANNELS, channel_buffer_size=4) +orchestrator = AsyncPipelineOrchestrator() +orchestrator.run(pipeline, config=config) +``` +with: +```python +pipeline.compile() +orchestrator = AsyncPipelineOrchestrator(buffer_size=4) +orchestrator.run(pipeline._node_graph) +pipeline.flush() +``` + +- [ ] **Step 3: Run updated tests** + +Run: `uv run pytest tests/test_pipeline/test_orchestrator.py -v` +Expected: All PASS. + +- [ ] **Step 4: Commit** + +```bash +git add src/orcapod/pipeline/async_orchestrator.py tests/test_pipeline/test_orchestrator.py +git commit -m "refactor(async-orchestrator): use node protocols, graph interface, OrchestratorResult (PLT-922)" +``` + +### Task 11: Update Pipeline.run() + +**Files:** +- Modify: `src/orcapod/pipeline/graph.py:359-474` + +- [ ] **Step 1: Update Pipeline.run() and remove _run_async()** + +In `src/orcapod/pipeline/graph.py`: + +1. In the `run()` method (line 412-429), change the async path to: + ```python + if use_async: + from orcapod.pipeline.async_orchestrator import AsyncPipelineOrchestrator + AsyncPipelineOrchestrator( + buffer_size=config.channel_buffer_size, + ).run(self._node_graph) + ``` + And change the explicit orchestrator path (line 413) to also pass the graph: + ```python + if orchestrator is not None: + orchestrator.run(self._node_graph) + ``` + +2. Delete the `_run_async` method (lines 469-474). + +- [ ] **Step 2: Run all pipeline tests** + +Run: `uv run pytest tests/test_pipeline/ -v` +Expected: All PASS. + +- [ ] **Step 3: Run full test suite** + +Run: `uv run pytest tests/ -x -q` +Expected: All PASS. + +- [ ] **Step 4: Commit** + +```bash +git add src/orcapod/pipeline/graph.py +git commit -m "refactor(pipeline): update run() to pass graph to orchestrators, remove _run_async (PLT-922)" +``` + +## Chunk 4: Parity Tests and Cleanup + +### Task 12: Update parity tests + +**Files:** +- Modify: `tests/test_pipeline/test_sync_orchestrator.py` + +- [ ] **Step 1: Update parity test signatures** + +In `tests/test_pipeline/test_sync_orchestrator.py`, class `TestSyncAsyncParity`: + +Update the async orchestrator calls to use the new interface: +```python +# Old: +AsyncPipelineOrchestrator().run(async_pipeline) + +# New: +async_pipeline.compile() +AsyncPipelineOrchestrator().run(async_pipeline._node_graph) +async_pipeline.flush() +``` + +Do this for both `test_linear_pipeline_parity` and `test_diamond_pipeline_parity`. + +- [ ] **Step 2: Run parity tests** + +Run: `uv run pytest tests/test_pipeline/test_sync_orchestrator.py::TestSyncAsyncParity -v` +Expected: PASS + +- [ ] **Step 3: Add materialize_results tests** + +Append to `tests/test_pipeline/test_sync_orchestrator.py`: + +```python +class TestMaterializeResults: + def test_sync_materialize_false_returns_empty(self): + src = _make_source("key", "value", {"key": ["a", "b"], "value": [1, 2]}) + pf = PythonPacketFunction(double_value, output_keys="result") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="mat", pipeline_database=InMemoryArrowDatabase()) + with pipeline: + pod(src, label="doubler") + + orch = SyncPipelineOrchestrator() + result = orch.run(pipeline._node_graph, materialize_results=False) + assert result.node_outputs == {} + + def test_async_materialize_true_collects_all(self): + src = _make_source("key", "value", {"key": ["a", "b"], "value": [1, 2]}) + pf = PythonPacketFunction(double_value, output_keys="result") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="mat_async", pipeline_database=InMemoryArrowDatabase()) + with pipeline: + pod(src, label="doubler") + + pipeline.compile() + orch = AsyncPipelineOrchestrator() + result = orch.run(pipeline._node_graph, materialize_results=True) + assert len(result.node_outputs) > 0 + + def test_async_materialize_false_returns_empty(self): + src = _make_source("key", "value", {"key": ["a", "b"], "value": [1, 2]}) + pf = PythonPacketFunction(double_value, output_keys="result") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="mat_async2", pipeline_database=InMemoryArrowDatabase()) + with pipeline: + pod(src, label="doubler") + + pipeline.compile() + orch = AsyncPipelineOrchestrator() + result = orch.run(pipeline._node_graph, materialize_results=False) + assert result.node_outputs == {} +``` + +- [ ] **Step 4: Run new tests** + +Run: `uv run pytest tests/test_pipeline/test_sync_orchestrator.py::TestMaterializeResults -v` +Expected: PASS + +- [ ] **Step 5: Commit** + +```bash +git add tests/test_pipeline/test_sync_orchestrator.py +git commit -m "test(orchestrator): update parity tests + add materialize_results tests (PLT-922)" +``` + +### Task 13: Add async-specific tests (fan-out, terminal, error) + +**Files:** +- Modify: `tests/test_pipeline/test_orchestrator.py` (extend) + +- [ ] **Step 1: Add fan-out, terminal node, and error propagation tests** + +Append to `tests/test_pipeline/test_orchestrator.py`: + +```python +class TestAsyncOrchestratorFanOut: + """One source fans out to multiple downstream nodes.""" + + def test_fan_out_source_to_two_functions(self): + src = _make_source("key", "value", {"key": ["a", "b"], "value": [1, 2]}) + pf1 = PythonPacketFunction(double_value, output_keys="result") + pod1 = FunctionPod(pf1) + pf2 = PythonPacketFunction(double_value, output_keys="result") + pod2 = FunctionPod(pf2) + + pipeline = Pipeline(name="fanout", pipeline_database=InMemoryArrowDatabase()) + with pipeline: + pod1(src, label="doubler1") + pod2(src, label="doubler2") + + pipeline.compile() + orch = AsyncPipelineOrchestrator() + result = orch.run(pipeline._node_graph, materialize_results=True) + pipeline.flush() + + fn_outputs = [ + v for k, v in result.node_outputs.items() if k.node_type == "function" + ] + assert len(fn_outputs) == 2 + for output in fn_outputs: + values = sorted([pkt.as_dict()["result"] for _, pkt in output]) + assert values == [2, 4] + + +class TestAsyncOrchestratorTerminalNode: + """Terminal nodes with no downstream should work correctly.""" + + def test_single_terminal_source(self): + """A pipeline with just a source (terminal) should work.""" + src = _make_source("key", "value", {"key": ["a"], "value": [1]}) + pipeline = Pipeline(name="terminal", pipeline_database=InMemoryArrowDatabase()) + with pipeline: + # Just register the source, no downstream + pass + + # Manually build a minimal graph with just a source node + import networkx as nx + from orcapod.core.nodes import SourceNode + + node = SourceNode(src) + G = nx.DiGraph() + G.add_node(node) + + orch = AsyncPipelineOrchestrator() + result = orch.run(G, materialize_results=True) + assert len(result.node_outputs) == 1 + + +class TestAsyncOrchestratorErrorPropagation: + """Node failures should propagate correctly.""" + + def test_node_failure_propagates(self): + def failing_fn(value: int) -> int: + raise ValueError("intentional failure") + + src = _make_source("key", "value", {"key": ["a"], "value": [1]}) + pf = PythonPacketFunction(failing_fn, output_keys="result") + pod = FunctionPod(pf) + + pipeline = Pipeline(name="error", pipeline_database=InMemoryArrowDatabase()) + with pipeline: + pod(src, label="failer") + + pipeline.compile() + orch = AsyncPipelineOrchestrator() + + with pytest.raises(ExceptionGroup): + orch.run(pipeline._node_graph) +``` + +- [ ] **Step 2: Run the new tests** + +Run: `uv run pytest tests/test_pipeline/test_orchestrator.py::TestAsyncOrchestratorFanOut tests/test_pipeline/test_orchestrator.py::TestAsyncOrchestratorTerminalNode tests/test_pipeline/test_orchestrator.py::TestAsyncOrchestratorErrorPropagation -v` +Expected: PASS + +- [ ] **Step 3: Commit** + +```bash +git add tests/test_pipeline/test_orchestrator.py +git commit -m "test(async-orchestrator): add fan-out, terminal, and error propagation tests (PLT-922)" +``` + +### Task 14: Full test suite verification and cleanup + +- [ ] **Step 1: Verify no references to removed protocol methods** + +Run these searches: +```bash +uv run grep -r "AsyncExecutableProtocol" src/ tests/ +uv run grep -r "SourceNodeProtocol.*iter_packets" src/ +uv run grep -r "FunctionNodeProtocol.*get_cached_results" src/ +uv run grep -r "FunctionNodeProtocol.*execute_packet" src/ +uv run grep -r "FunctionNodeProtocol.*compute_pipeline_entry_id" src/ +uv run grep -r "OperatorNodeProtocol.*get_cached_output" src/ +``` + +Expected: No matches. + +- [ ] **Step 2: Run full test suite one final time** + +Run: `uv run pytest tests/ -q` +Expected: All pass. + +- [ ] **Step 3: Final commit if needed** + +```bash +git add -u +git commit -m "chore(cleanup): remove stale references to old protocol methods (PLT-922)" +``` diff --git a/superpowers/specs/2026-03-14-node-authority-design.md b/superpowers/specs/2026-03-14-node-authority-design.md new file mode 100644 index 00000000..c9b3f9a0 --- /dev/null +++ b/superpowers/specs/2026-03-14-node-authority-design.md @@ -0,0 +1,234 @@ +# Node Authority — Self-Validating, Self-Persisting Nodes + +## Overview + +Nodes become authoritative executors: they validate inputs, compute, persist, and cache +results internally. The orchestrator no longer calls `store_result` or accesses internal +pods directly. Instead, it calls `process_packet` or `process` on the node, and the node +handles everything. + +This refactoring also introduces a clear **pod vs node** distinction: +- **Pod**: computation definition. `process() → StreamProtocol` (lazy, deferred). +- **Node**: computation executor. `process() → list[tuple[Tag, Packet]]` (eager, + materialized, with schema validation + persistence + caching). + +## Motivation + +The previous design had the orchestrator reaching into `node.operator` to call the pod +directly, then telling the node to `store_result`. This broke encapsulation: the +orchestrator was acting as an intermediary between the node and its own pod. The node +should be the sole authority over its results — it should process input, decide if it's +valid, and handle all persistence internally. + +**Schema validation** is the trust mechanism. Each node knows its expected input schema +(including system tags, which encode pipeline topology). If incoming data matches, the +node can treat the results as its own and persist them. If not, the input doesn't belong +to this node. + +## Changes + +### Remove from all node types + +- `store_result()` — removed entirely. Persistence happens inside `process` / `process_packet`. + +### Remove from protocols + +- `store_result` from `SourceNodeProtocol`, `FunctionNodeProtocol`, `OperatorNodeProtocol` +- `operator` property from `OperatorNodeProtocol` (orchestrator no longer accesses it) + +### FunctionNode + +**`process_packet(tag, packet) → tuple[Tag, Packet | None]`** + +Reverts to the original bundled behavior (pre-split), plus schema validation: + +1. Validate tag + packet schema (including system tags) against expected input schema + from `self._input_stream.output_schema()` / `self._input_stream.keys()`. +2. Compute via `CachedFunctionPod` (function-level memoization in result DB) or raw + `FunctionPod` (no DB). +3. Write pipeline provenance record (via `add_pipeline_record`). +4. Cache result in `_cached_output_packets`. +5. Return `(tag_out, output_packet)`. + +Validation checks: tag column names + types, packet column names + types, system tag +column names. System tags are critical because they encode pipeline topology — mismatched +system tags mean the data came from a different pipeline path. + +**`process(input_stream) → list[tuple[Tag, Packet]]`** (NEW) + +Bulk entry point that validates schema once against the stream, then iterates packets +using an internal unchecked path: + +1. Validate stream schema (tag + packet + system tags) against expected input. +2. For each (tag, packet) in stream: call `_process_packet_internal(tag, packet)` — + same as `process_packet` but skips per-packet schema validation. +3. Return materialized list of `(tag, output_packet)` pairs (excluding None outputs). + +This is more efficient when per-packet observer hooks aren't needed. + +**Internal structure:** + +``` +process_packet(tag, pkt) + → _validate_input_schema(tag, pkt) + → _process_packet_internal(tag, pkt) + +process(input_stream) + → _validate_stream_schema(input_stream) + → for tag, pkt in input_stream.iter_packets(): + _process_packet_internal(tag, pkt) + → return materialized results + +_process_packet_internal(tag, pkt) + → compute (CachedFunctionPod or FunctionPod) + → write pipeline record (if DB attached) + → cache in _cached_output_packets + → return (tag_out, output_packet) +``` + +### OperatorNode + +**`process(*input_streams) → list[tuple[Tag, Packet]]`** (NEW) + +Replaces the orchestrator's direct access to `node.operator`: + +1. Validate each input stream's schema (tag + system tags + packet) against expected + upstream schemas from `self._input_streams`. +2. Compute via `self._operator.process(*input_streams)`. +3. Materialize results. +4. Persist to pipeline DB (if LOG mode, via existing `_store_output_stream`). +5. Cache in `_cached_output_stream`. +6. Return materialized list of `(tag, output_packet)` pairs. + +**Remove**: public `operator` property (orchestrator no longer needs it). + +### SourceNode + +- Remove `store_result()`. The source is authoritative by definition — it produces its + own data via `iter_packets()`. +- The orchestrator materializes the buffer from `iter_packets()`. If caching is needed + for transient sources, `iter_packets()` can cache internally on first call (existing + `_cached_results` field supports this — set during first iteration). + +### Node Protocols (updated) + +```python +class SourceNodeProtocol(Protocol): + node_type: str + def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: ... + +class FunctionNodeProtocol(Protocol): + node_type: str + def get_cached_results(self, entry_ids: list[str]) -> dict[str, tuple[TagProtocol, PacketProtocol]]: ... + def compute_pipeline_entry_id(self, tag: TagProtocol, packet: PacketProtocol) -> str: ... + def process_packet(self, tag: TagProtocol, packet: PacketProtocol) -> tuple[TagProtocol, PacketProtocol | None]: ... + def process(self, input_stream: StreamProtocol) -> list[tuple[TagProtocol, PacketProtocol]]: ... + +class OperatorNodeProtocol(Protocol): + node_type: str + def get_cached_output(self) -> StreamProtocol | None: ... + def process(self, *input_streams: StreamProtocol) -> list[tuple[TagProtocol, PacketProtocol]]: ... +``` + +### SyncPipelineOrchestrator (updated) + +The orchestrator no longer calls `store_result` or accesses `node.operator`. Updated +execution paths: + +**Source execution:** +```python +def _execute_source(self, node): + self._observer.on_node_start(node) + output = list(node.iter_packets()) + self._observer.on_node_end(node) + return output +``` + +**Function execution (with observer hooks):** +```python +def _execute_function(self, node, upstream_buffer): + self._observer.on_node_start(node) + # Phase 1: pipeline-level cache lookup + upstream_entries = [ + (tag, pkt, node.compute_pipeline_entry_id(tag, pkt)) + for tag, pkt in upstream_buffer + ] + cached = node.get_cached_results([eid for _, _, eid in upstream_entries]) + + output = [] + for tag, pkt, entry_id in upstream_entries: + self._observer.on_packet_start(node, tag, pkt) + if entry_id in cached: + tag_out, result = cached[entry_id] + self._observer.on_packet_end(node, tag, pkt, result, cached=True) + output.append((tag_out, result)) + else: + tag_out, result = node.process_packet(tag, pkt) + self._observer.on_packet_end(node, tag, pkt, result, cached=False) + if result is not None: + output.append((tag_out, result)) + + self._observer.on_node_end(node) + return output +``` + +**Operator execution:** +```python +def _execute_operator(self, node, upstream_buffers): + self._observer.on_node_start(node) + cached = node.get_cached_output() + if cached is not None: + output = list(cached.iter_packets()) + else: + input_streams = [ + self._materialize_as_stream(buf, upstream_node) + for buf, upstream_node in upstream_buffers + ] + output = node.process(*input_streams) # node handles everything + self._observer.on_node_end(node) + return output +``` + +### Pipeline.run() + +No changes needed beyond what's already there (previous refactoring already removed +`_apply_results`). + +### Backward Compatibility + +- `FunctionNode.run()` and `iter_packets()` continue to work for the non-orchestrated + pull-based path. They use `_process_packet_internal` (same as the orchestrator path). +- `OperatorNode.run()` continues to work for the non-orchestrated path. +- The existing `_process_and_store_packet` on FunctionNode can be replaced by + `_process_packet_internal` (same semantics). + +## Schema Validation Details + +Validation checks the following against the node's expected input: + +1. **Tag column names** — must match exactly (user-defined tag columns). +2. **Tag column types** — must match the expected Arrow types. +3. **System tag column names** — must match (these encode pipeline topology via + `_tag::source:...` naming with pipeline hash extensions). +4. **Packet column names** — must match. +5. **Packet column types** — must match. + +Validation raises `InputValidationError` (or similar) with a clear message indicating +which columns/types don't match. + +For `process_packet(tag, packet)`: validate from the Tag and Packet datagram objects +directly (they expose `keys()`, `arrow_schema()`, etc.). + +For `process(input_stream)` / `process(*input_streams)`: validate from the stream's +`output_schema()` method, which is cheaper (once per stream, not per packet). + +## Testing + +- Update `process_packet` tests to verify it writes pipeline records again (revert the + "purity" tests — `process_packet` is no longer pure). +- Add schema validation tests (valid schema passes, mismatched schema raises). +- Add `OperatorNode.process()` tests (basic computation, DB persistence, caching). +- Add `FunctionNode.process()` tests (bulk processing, single schema validation). +- Remove all `store_result` tests. +- Verify orchestrator integration tests still pass. +- Verify backward compat: `node.run()` and `iter_packets()` still work. diff --git a/superpowers/specs/2026-03-14-remove-populate-cache-design.md b/superpowers/specs/2026-03-14-remove-populate-cache-design.md new file mode 100644 index 00000000..986aa3da --- /dev/null +++ b/superpowers/specs/2026-03-14-remove-populate-cache-design.md @@ -0,0 +1,82 @@ +# Remove populate_cache — Self-Caching Nodes Design + +## Overview + +Remove the `populate_cache` method from all node types. Instead, nodes build their +in-memory cache as a natural side effect of the orchestrator's existing calls to +`get_cached_results`, `process_packet`, and `store_result`. This simplifies the +node protocol and eliminates the need for `Pipeline._apply_results()`. + +## Motivation + +The orchestrator already calls `get_cached_results`, `process_packet`, and `store_result` +on nodes during execution. Each of these is an opportunity for the node to build its +in-memory cache internally, rather than relying on an external `populate_cache` call. + +The current flow has unnecessary indirection: +``` +orchestrator.run() → returns OrchestratorResult +Pipeline._apply_results() → calls node.populate_cache() for each node +``` + +The simpler flow: +``` +orchestrator.run() → nodes self-cache during execution → done +``` + +## Changes + +### Node Protocol Changes + +Remove `populate_cache` from all three protocols: +- `SourceNodeProtocol`: remove `populate_cache` +- `FunctionNodeProtocol`: remove `populate_cache` +- `OperatorNodeProtocol`: remove `populate_cache` + +### SourceNode + +- Remove `populate_cache()` method +- `store_result(results)`: now sets `self._cached_results = list(results)` in addition + to any future DB persistence. This ensures `iter_packets()` returns from cache after + orchestrated execution. + +### FunctionNode + +- Remove `populate_cache()` method +- `get_cached_results(entry_ids)`: after retrieving cached results, populates + `_cached_output_packets` with the returned entries. +- `store_result(tag, input_packet, output_packet)`: after writing pipeline record, + adds the result to `_cached_output_packets`. Also sets `_needs_iterator = False` + and `_cached_input_iterator = None` to indicate iteration state is managed externally. + +### OperatorNode + +- Remove `populate_cache()` method +- `get_cached_output()`: already sets `_cached_output_stream` via `_replay_from_cache()`. + No change needed. +- `store_result(results)`: now also sets `_cached_output_stream` from the materialized + results (same as what `populate_cache` did). This ensures `iter_packets()` / `as_table()` + work after orchestrated execution. + +### Pipeline.run() + +- Remove `_apply_results()` method (no longer needed) +- `run()` no longer calls `_apply_results()` — nodes are self-cached + +### OrchestratorResult + +- Keep as-is. The orchestrator still returns results for programmatic inspection. + The caller may want to examine what was produced without going through node accessors. + +### SyncPipelineOrchestrator + +- No changes needed. It already calls `store_result` and `get_cached_results` on nodes. + The orchestrator is unaware of caching — that's the node's concern. + +## Testing + +- Remove all `populate_cache` tests +- Update `store_result` tests to verify internal cache is populated +- Update `get_cached_results` tests to verify internal cache is populated +- Verify `iter_packets()` / `as_table()` work after orchestrated execution (existing + integration tests already cover this via `Pipeline.run()`) diff --git a/superpowers/specs/2026-03-14-sync-orchestrator-design.md b/superpowers/specs/2026-03-14-sync-orchestrator-design.md new file mode 100644 index 00000000..2d6af36d --- /dev/null +++ b/superpowers/specs/2026-03-14-sync-orchestrator-design.md @@ -0,0 +1,615 @@ +# Sync Pipeline Orchestrator Design + +## Overview + +Design a synchronous pipeline orchestrator that provides fine-grained control over pipeline +execution, including per-packet observability hooks for function nodes. The orchestrator +operates on a graph of nodes, drives execution externally (rather than relying on pull-based +iteration), and returns computed results for all nodes. + +This design also introduces node protocols that formalize the interface between orchestrator +and nodes, refactors caching/DB logic out of monolithic `run()`/`iter_packets()` methods into +reusable building blocks, and simplifies `Pipeline.run()` to delegate to an orchestrator. + +### Goals + +- Sync orchestrator with per-packet hooks for logging, metrics, and debugging. +- Clean separation: orchestrator controls scheduling; nodes handle computation and persistence. +- Uniform compute/store pattern across all node types. +- Node protocols so the orchestrator is decoupled from concrete node classes. +- Orchestrator operates on a graph of nodes, not on the `Pipeline` object directly. +- Orchestrator returns results; pipeline decides how to apply them to node caches. +- Memory-saving mode that skips result accumulation and only persists to databases. + +### Out of Scope + +- Refactoring `AsyncPipelineOrchestrator` to use the new protocols (deferred until both + implementations exist and can inform a shared protocol). +- Shared orchestrator base/protocol between sync and async. +- Selective cache population policies in `Pipeline._apply_results()`. + +## Design + +### Conceptual Model + +Two orthogonal axes govern pipeline execution: + +| | Pull (consumer drives) | Push (producer drives) | +|--------------|----------------------------|--------------------------------| +| **Sync** | `iter_packets()` iterators | Orchestrator with buffers | +| **Async** | `async for` via `__aiter__`| `async_execute()` with channels| + +The existing codebase has sync-pull (`iter_packets()`) and async-push (`async_execute()`). +This design adds **sync-push** via `SyncPipelineOrchestrator`. + +The pipeline graph's stream wiring declares **topology** (what connects to what). The +orchestrator creates its own **execution transport** (buffers) to control data flow at +runtime. The stream wiring remains valuable for standalone/ad-hoc use without an orchestrator. + +### Uniform Compute/Store Pattern + +All node types follow the same two-step pattern during orchestrated execution: + +``` +1. Compute → produce output data +2. Store → persist pipeline-level records to DB (node decides what to do; no-op if no DB) +``` + +The orchestrator always calls both steps for every node. It does not check whether a node +has a database or any storage capability — that is the node's concern. This keeps the +orchestrator simple and the node protocol uniform. + +| Node Type | Compute | Compute DB effects | Store (pipeline-level) | +|-------------|-------------------------------------------|------------------------------------------|------------------------------------------| +| SourceNode | `iter_packets()` | None | `store_result(results)` — snapshot (future) | +| FunctionNode| `process_packet(tag, packet)` per-packet | Function-level result memoization (result DB) | `store_result(tag, input_packet, output)` — pipeline provenance record | +| OperatorNode| `operator.process(*input_streams)` | None | `store_result(results)` — pipeline cache (LOG mode) | + +**Two levels of persistence for FunctionNode:** + +1. **Function-level result memoization** (result DB, via `CachedFunctionPod`): "This + function, given this input packet, produces this output." This is a property of the + function itself, not of any specific pipeline. It happens inside `process_packet` and + is the function pod's own concern. + +2. **Pipeline-level provenance** (pipeline DB, via `add_pipeline_record`): "In this pipeline + run, this (tag + system_tags + packet) was processed and points to result record Y." This + is pipeline bookkeeping, handled by `store_result`. + +This distinction matters: function-level memoization transcends any single pipeline — if the +same function runs in a different pipeline with the same input, the cached result is reused. +Pipeline provenance is specific to a pipeline execution context. + +### Post-Execution Contract + +- **Orchestrator's job**: execute the graph, invoke `store_result` on every node after + computation, and return computed results for all nodes. +- **Pipeline's job**: receive the orchestrator's results and decide per-node whether to + populate in-memory caches (making `iter_packets()` / `as_table()` work after execution). +- **DB-backed nodes**: results persisted to DB during orchestrator execution AND optionally + cached in memory after (via pipeline's `_apply_results`). +- **Non-DB nodes**: `store_result` is a no-op; results available only if the pipeline + populates their cache from the orchestrator's returned results. +- **`materialize_results=False` mode**: orchestrator discards buffers after all downstream + consumers have read them. `OrchestratorResult.node_outputs` is empty. Only DB-persisted + results survive. This is an explicit trade-off: non-DB nodes lose all data after execution. + +### Component Flow + +``` +Pipeline.run(orchestrator=None) + │ + │ 1. compile() if needed + │ 2. default orchestrator = SyncPipelineOrchestrator() + │ + ▼ +orchestrator.run(node_graph) + │ + │ For each node in topological order: + │ + ├── SourceNode: + │ observer.on_node_start(node) + │ output = materialize node.iter_packets() ← compute + │ node.store_result(output) ← store (no-op if no DB) + │ observer.on_node_end(node) + │ + ├── FunctionNode: + │ observer.on_node_start(node) + │ compute entry_ids from upstream buffer + │ cached = node.get_cached_results(entry_ids) ← pipeline DB read + │ for each (tag, packet) in upstream buffer: + │ observer.on_packet_start(node, tag, packet) + │ if entry_id in cached: + │ use cached result ← pipeline-level cache hit + │ observer.on_packet_end(..., cached=True) + │ else: + │ tag_out, result = node.process_packet(...) ← compute (+ function-level memoization in result DB) + │ node.store_result(tag, packet, result) ← pipeline provenance record only + │ observer.on_packet_end(..., cached=False) + │ observer.on_node_end(node) + │ + └── OperatorNode: + observer.on_node_start(node) + cached = node.get_cached_output() ← DB read (REPLAY mode) + if cached: + output = materialize cached stream + else: + output = node.operator.process(*input_streams) ← compute (pure) + node.store_result(output) ← store (LOG → write; OFF → no-op) + observer.on_node_end(node) + │ + ▼ +returns OrchestratorResult(node_outputs=buffers) + │ + ▼ +Pipeline._apply_results(result) + │ for each node, calls node.populate_cache(outputs) + │ + ▼ +Pipeline.flush() ← flush all databases +``` + +### Node Protocols + +Three protocols formalize the orchestrator-node interface. Each matches a fundamentally +different execution model. The orchestrator dispatches via `TypeGuard` functions that check +the existing `node_type` string attribute — cheap at runtime, with full type narrowing for +static analysis. + +All protocols share `store_result` (persistence) and `populate_cache` (in-memory caching) +as the uniform post-computation interface. The orchestrator always calls `store_result` +after computation — nodes decide internally what to do (write to DB or no-op). + +#### `SourceNodeProtocol` + +Provides data with no computation. The orchestrator materializes its output into a buffer. + +```python +class SourceNodeProtocol(Protocol): + node_type: str # "source" + + def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: ... + def store_result( + self, results: list[tuple[TagProtocol, PacketProtocol]] + ) -> None: ... + def populate_cache( + self, results: list[tuple[TagProtocol, PacketProtocol]] + ) -> None: ... +``` + +`store_result`: persists a snapshot of the source data to the pipeline DB if configured. +No-op if no DB is attached. Useful for transient/updatable sources where you want a record +of what the pipeline actually consumed. + +#### `FunctionNodeProtocol` + +Per-packet computation with optional DB caching. The orchestrator drives iteration externally, +calling `process_packet()` for each input with observer hooks. + +`process_packet()` handles computation and function-level result memoization (via +`CachedFunctionPod` when DB is attached). This memoization is the function pod's own +concern — it transcends any single pipeline. `store_result()` handles pipeline-level +provenance recording only. + +```python +class FunctionNodeProtocol(Protocol): + node_type: str # "function" + + def get_cached_results( + self, entry_ids: list[str] + ) -> dict[str, tuple[TagProtocol, PacketProtocol]]: ... + + def compute_pipeline_entry_id( + self, tag: TagProtocol, packet: PacketProtocol + ) -> str: ... + + def process_packet( + self, tag: TagProtocol, packet: PacketProtocol + ) -> tuple[TagProtocol, PacketProtocol | None]: ... + + def store_result( + self, + tag: TagProtocol, + input_packet: PacketProtocol, + output_packet: PacketProtocol | None, + ) -> None: ... + + def populate_cache( + self, results: list[tuple[TagProtocol, PacketProtocol]] + ) -> None: ... +``` + +`store_result`: adds a pipeline provenance record to the pipeline DB (via +`add_pipeline_record`). Does NOT write to the result DB — that is handled by +`process_packet` via `CachedFunctionPod`. No-op if no DB is attached or output is None. + +#### `OperatorNodeProtocol` + +Whole-stream computation with cache modes. The orchestrator extracts the operator pod and +invokes it with orchestrator-prepared streams. + +```python +class OperatorNodeProtocol(Protocol): + node_type: str # "operator" + operator: OperatorPodProtocol + + def get_cached_output(self) -> StreamProtocol | None: ... + def store_result( + self, results: list[tuple[TagProtocol, PacketProtocol]] + ) -> None: ... + def populate_cache( + self, results: list[tuple[TagProtocol, PacketProtocol]] + ) -> None: ... +``` + +`store_result`: accepts a materialized list of (tag, packet) pairs. If cache mode is LOG, +wraps the list as an `ArrowTableStream` and writes to DB. No-op for OFF mode. This avoids +the double-consumption problem where materializing a stream for the buffer would exhaust it +before storage could read it. + +#### TypeGuard Dispatch + +```python +from typing import TypeGuard + +def is_source_node(node: GraphNode) -> TypeGuard[SourceNodeProtocol]: + return node.node_type == "source" + +def is_function_node(node: GraphNode) -> TypeGuard[FunctionNodeProtocol]: + return node.node_type == "function" + +def is_operator_node(node: GraphNode) -> TypeGuard[OperatorNodeProtocol]: + return node.node_type == "operator" +``` + +The dispatch chain must include an `else` branch that raises `TypeError` for unknown node +types to prevent silent skipping if new node types are added. + +### ExecutionObserver Protocol + +```python +class ExecutionObserver(Protocol): + def on_node_start(self, node: GraphNode) -> None: ... + def on_node_end(self, node: GraphNode) -> None: ... + def on_packet_start( + self, node: GraphNode, tag: TagProtocol, packet: PacketProtocol + ) -> None: ... + def on_packet_end( + self, + node: GraphNode, + tag: TagProtocol, + input_packet: PacketProtocol, + output_packet: PacketProtocol | None, + cached: bool, + ) -> None: ... +``` + +`on_packet_start` / `on_packet_end` are only invoked for function nodes (the only node type +with per-packet granularity). `on_node_start` / `on_node_end` are invoked for all node types. + +The `tag` parameter in `on_packet_start` is the **input** tag. In `on_packet_end`, `tag` is +also the **input** tag (the output tag is available via `output_packet` if needed, but since +function nodes pass tags through unchanged, they are the same). + +Default implementation: `NoOpObserver` with empty method bodies. + +### OrchestratorResult + +```python +@dataclass +class OrchestratorResult: + node_outputs: dict[GraphNode, list[tuple[TagProtocol, PacketProtocol]]] +``` + +When `materialize_results=False`, `node_outputs` is empty. + +### SyncPipelineOrchestrator + +```python +class SyncPipelineOrchestrator: + def __init__(self, observer: ExecutionObserver | None = None) -> None: + self._observer = observer or NoOpObserver() + + def run( + self, + graph: nx.DiGraph, + materialize_results: bool = True, + ) -> OrchestratorResult: + ... +``` + +#### Execution Logic + +```python +def run(self, graph, materialize_results=True): + topo_order = list(nx.topological_sort(graph)) + buffers: dict[GraphNode, list[tuple[Tag, Packet]]] = {} + + for node in topo_order: + if is_source_node(node): + buffers[node] = self._execute_source(node) + elif is_function_node(node): + upstream_buffer = self._gather_upstream(node, graph, buffers) + buffers[node] = self._execute_function(node, upstream_buffer) + elif is_operator_node(node): + upstream_buffers = self._gather_upstream_multi(node, graph, buffers) + buffers[node] = self._execute_operator(node, upstream_buffers) + else: + raise TypeError(f"Unknown node type: {node.node_type!r}") + + # Memory-saving: discard buffers no longer needed by any downstream + if not materialize_results: + self._gc_buffers(node, graph, buffers) + + return OrchestratorResult(node_outputs=buffers) +``` + +#### Error Handling + +Exceptions from node execution (source iteration, `process_packet`, operator `process`) +propagate immediately. The orchestrator does not attempt partial completion or cleanup. +Buffers for already-completed nodes remain in memory; DB-persisted results from completed +nodes survive. This is consistent with the existing sync execution path where an exception +in `node.run()` halts the pipeline. + +#### Source Execution + +```python +def _execute_source(self, node): + self._observer.on_node_start(node) + output = list(node.iter_packets()) + node.store_result(output) # no-op if no DB + self._observer.on_node_end(node) + return output +``` + +#### Function Execution + +```python +def _execute_function(self, node, upstream_buffer): + self._observer.on_node_start(node) + + # Compute entry IDs for current upstream + upstream_entries = [ + (tag, packet, node.compute_pipeline_entry_id(tag, packet)) + for tag, packet in upstream_buffer + ] + entry_ids = [eid for _, _, eid in upstream_entries] + + # Phase 1: targeted cache lookup (DB read) + cached = node.get_cached_results(entry_ids=entry_ids) + + output = [] + for tag, packet, entry_id in upstream_entries: + self._observer.on_packet_start(node, tag, packet) + if entry_id in cached: + tag_out, result = cached[entry_id] + self._observer.on_packet_end(node, tag, packet, result, cached=True) + output.append((tag_out, result)) + else: + # process_packet is pure computation — no DB side effects + tag_out, result = node.process_packet(tag, packet) + # store_result handles all DB writes (result cache + pipeline record) + node.store_result(tag, packet, result) + self._observer.on_packet_end(node, tag, packet, result, cached=False) + if result is not None: + output.append((tag_out, result)) + + self._observer.on_node_end(node) + return output +``` + +#### Operator Execution + +```python +def _execute_operator(self, node, upstream_buffers): + self._observer.on_node_start(node) + + cached = node.get_cached_output() # DB read (REPLAY mode only) + if cached is not None: + output = list(cached.iter_packets()) + else: + input_streams = [ + self._materialize_as_stream(buf, node) for buf in upstream_buffers + ] + result_stream = node.operator.process(*input_streams) + output = list(result_stream.iter_packets()) + node.store_result(output) # LOG → write; OFF → no-op + + self._observer.on_node_end(node) + return output +``` + +Note: operator input validation is not performed by the orchestrator. Validation occurs at +compile time (`OperatorNode.__init__` calls `self._operator.validate_inputs()`). The +orchestrator-prepared streams have the same schema as the original inputs, so revalidation +is unnecessary. + +#### `_materialize_as_stream` + +Wraps a `list[tuple[Tag, Packet]]` buffer as an `ArrowTableStream`. Implementation: + +1. Extract Arrow tables from each Tag and Packet datagram. +2. Horizontal-stack tag columns + packet columns (including source info and system tags). +3. Tag column names come from the **upstream** node's output schema (available via + `upstream.keys()` or from the tag objects themselves). +4. Construct `ArrowTableStream(combined_table, tag_columns=tag_keys)`. + +For operators with multiple inputs, each upstream buffer is materialized separately using +that upstream's tag column names. + +#### Buffer GC (materialize_results=False) + +After processing each node, iterate over its predecessors in the graph. For each predecessor, +check if all of its successors have been processed (i.e., appear earlier in topological +order or are the current node). If so, delete the predecessor's buffer. This is O(edges) +per node in the worst case. + +### Node Refactoring + +#### FunctionNode Changes + +**`process_packet(tag, packet)` — computation + function-level memoization** + +Refactored to remove pipeline record writing. Now delegates to `CachedFunctionPod` +(when DB is attached) or raw `FunctionPod` (when not). `CachedFunctionPod` handles +function-level result memoization internally: it checks if this input packet's content +hash has been computed before, and if not, computes and stores the result in the result +DB. This memoization is the function pod's own concern — it transcends any single pipeline. + +Pipeline provenance recording (which was previously bundled into `process_packet`) is +extracted into `store_result`. + +**`store_result(tag, input_packet, output_packet) -> None`** + +Handles pipeline-level provenance recording only: + +1. Adds a pipeline provenance record to the pipeline DB (tag + system tags + input packet + hash → output packet record ID) via `add_pipeline_record`. +2. Does NOT write to the result DB — that is `CachedFunctionPod`'s concern, handled + during `process_packet`. +3. No-op if no pipeline DB is attached or output is None. + +**`get_cached_results(entry_ids: list[str]) -> dict[str, tuple[Tag, Packet]]`** + +Factored out of `iter_packets()` Phase 1 logic. Implementation: + +1. Fetch all pipeline records from `pipeline_database.get_all_records(pipeline_path)`. +2. Fetch all result records from `result_database.get_all_records(record_path)`. +3. Join on `PACKET_RECORD_ID` (same Polars join as current `iter_packets` Phase 1). +4. Filter to only rows whose pipeline entry ID is in the requested `entry_ids` set. +5. Reconstruct (tag, output_packet) pairs from the filtered rows. +6. Return as `dict[entry_id, (tag, packet)]`. + +When no DB is attached, returns `{}`. + +Note: this method fetches all records and filters in memory. The database protocol does not +currently support filtered lookups by entry ID. If the DB grows very large, a filtered +lookup method could be added to the database protocol in the future, but that is out of +scope for this design. + +**`populate_cache(results: list[tuple[Tag, Packet]]) -> None`** + +Populates `_cached_output_packets` dict from externally-provided results so that +`iter_packets()` / `as_table()` work after orchestrated execution. Sets +`_cached_input_iterator = None` and `_needs_iterator = False` to indicate iteration is +complete. + +#### OperatorNode Changes + +**`get_cached_output() -> StreamProtocol | None`** + +Returns the cached output stream when in REPLAY mode and DB records exist. Returns `None` +otherwise. Factored out of `run()`. Wraps the existing `_replay_from_cache()` logic. + +**`store_result(results: list[tuple[Tag, Packet]]) -> None`** + +Accepts a materialized list of (tag, packet) pairs. If cache mode is LOG, wraps the list +as an `ArrowTableStream` and calls the existing `_store_output_stream()` to write to DB. +No-op for OFF mode. This avoids the double-consumption problem: the orchestrator materializes +the stream into a list first, then passes the list to both the buffer and `store_result`. + +**`populate_cache(results: list[tuple[Tag, Packet]]) -> None`** + +Wraps results as an `ArrowTableStream` and sets `_cached_output_stream` so that +`iter_packets()` / `as_table()` work after orchestrated execution. + +#### SourceNode Changes + +**`store_result(results: list[tuple[Tag, Packet]]) -> None`** + +Persists a snapshot of the source data to the pipeline DB if configured. No-op if no DB +is attached. Useful for transient/updatable sources where you want a record of what the +pipeline actually consumed. + +**`populate_cache(results: list[tuple[Tag, Packet]]) -> None`** + +Adds a `_cached_results: list[tuple[Tag, Packet]] | None` field to `SourceNode`. When +populated, `iter_packets()` returns from this cache instead of delegating to +`self.stream.iter_packets()`. This requires modifying `SourceNode.iter_packets()` to check +the cache first. The existing delegation path remains the default when cache is `None`. + +### Pipeline.run() Changes + +The current `Pipeline.run()` signature is: + +```python +def run(self, config=None, execution_engine=None, execution_engine_opts=None) +``` + +For this iteration, `orchestrator` is added as a new parameter. The existing parameters +are preserved for backward compatibility. When `orchestrator` is provided, it takes +precedence: + +```python +def run(self, orchestrator=None, config=None, execution_engine=None, + execution_engine_opts=None): + if not self._compiled: + self.compile() + + if execution_engine is not None: + self._apply_execution_engine(execution_engine, execution_engine_opts) + + if orchestrator is not None: + result = orchestrator.run(self._node_graph) + self._apply_results(result) + elif use_async: # existing logic based on config/engine + self._run_async(config) + else: + orchestrator = SyncPipelineOrchestrator() + result = orchestrator.run(self._node_graph) + self._apply_results(result) + + self.flush() +``` + +The long-term goal is to simplify this to just `run(self, orchestrator=None)` once the +async orchestrator is also refactored. For now, the existing parameters are preserved to +avoid breaking changes. + +**`_apply_results(result: OrchestratorResult)`**: walks `result.node_outputs` and calls +`node.populate_cache(outputs)` for each node. Initially unconditional; selective policies +can be added later. + +### What Stays Unchanged + +- `iter_packets()` on all node types — standalone pull-based path, untouched (except + `SourceNode.iter_packets()` gains a cache check at the top). +- `async_execute()` on all node types — standalone push-based path, untouched. +- `AsyncPipelineOrchestrator` — untouched for now. +- `FunctionNode.run()` — still works for the non-orchestrated path (consumes `iter_packets()`). +- `OperatorNode.run()` — still works for the non-orchestrated path. + +### File Organization + +- `src/orcapod/pipeline/observer.py` — `ExecutionObserver` protocol, `NoOpObserver` +- `src/orcapod/pipeline/sync_orchestrator.py` — `SyncPipelineOrchestrator` +- `src/orcapod/pipeline/async_orchestrator.py` — existing `AsyncPipelineOrchestrator` + (renamed from `orchestrator.py`) +- `src/orcapod/pipeline/result.py` — `OrchestratorResult` dataclass +- `src/orcapod/protocols/node_protocols.py` — `SourceNodeProtocol`, `FunctionNodeProtocol`, + `OperatorNodeProtocol`, TypeGuard functions + +### Testing Strategy + +- **Unit tests for SyncPipelineOrchestrator**: linear, diamond, fan-out topologies — same + test shapes as existing `test_orchestrator.py`. +- **Observer tests**: verify hooks fire in correct order with correct arguments. +- **Cache population tests**: verify `populate_cache` makes `iter_packets()` / `as_table()` + work after orchestrated execution. +- **materialize_results=False tests**: verify DB persistence without memory accumulation. + Verify that non-DB node data is inaccessible (explicit trade-off). +- **FunctionNode.get_cached_results tests**: targeted lookup returns correct subset; + empty DB returns `{}`; entry IDs not in DB are absent from result. +- **FunctionNode.store_result tests**: verify pipeline provenance record write only; + verify no-op when no DB attached; verify no-op when output is None. +- **FunctionNode.process_packet tests**: verify process_packet does NOT write pipeline + records (only store_result does); verify function-level memoization still works + (CachedFunctionPod caches results in result DB during process_packet). +- **OperatorNode.get_cached_output / store_result tests**: cache mode behavior (OFF → no-op, + LOG → writes, REPLAY → reads). +- **SourceNode.store_result tests**: verify DB snapshot write when configured; no-op otherwise. +- **Sync vs async parity tests**: same pipeline produces same DB results regardless of + orchestrator type. +- **Error propagation tests**: exception in mid-pipeline node halts execution; earlier + nodes' DB results survive. diff --git a/superpowers/specs/2026-03-15-async-orchestrator-refactor-design.md b/superpowers/specs/2026-03-15-async-orchestrator-refactor-design.md new file mode 100644 index 00000000..6ab65a38 --- /dev/null +++ b/superpowers/specs/2026-03-15-async-orchestrator-refactor-design.md @@ -0,0 +1,396 @@ +# Async Orchestrator Refactor Design + +PLT-922: Refactor AsyncPipelineOrchestrator to use node protocols and orchestrator interface. + +## Context + +The `SyncPipelineOrchestrator` (PLT-921) introduced node protocols +(`SourceNodeProtocol`, `FunctionNodeProtocol`, `OperatorNodeProtocol`) with +TypeGuard dispatch. The sync orchestrator currently drives per-packet execution +for function nodes — calling `get_cached_results`, `execute_packet`, and firing +observer hooks from the orchestrator side. + +The `AsyncPipelineOrchestrator` uses a different pattern: it calls a uniform +`node.async_execute(inputs, output)` on all nodes, letting nodes handle +everything internally via channels. + +This refactor aligns both orchestrators on a common design where: + +- Nodes own their execution (caching, per-packet logic, persistence). +- Orchestrators are topology schedulers that call `execute` / `async_execute`. +- Observability is achieved via observer injection, not orchestrator-driven hooks. + +## Design Decisions + +### Slim node protocols + +The three node protocols expose only `execute` (sync) and `async_execute` +(async). All per-packet methods (`get_cached_results`, `execute_packet`, +`compute_pipeline_entry_id`) are removed from the protocol surface — they remain +as internal methods on the node classes. + +### Observer injection via parameter + +Both `execute` and `async_execute` accept an optional `observer` keyword +argument. Nodes call the observer hooks internally (`on_node_start`, +`on_node_end`, `on_packet_start`, `on_packet_end`). The `ExecutionObserver` +protocol itself is unchanged. + +### Orchestrators are topology schedulers + +Neither orchestrator inspects packet content, manages caches, or drives +per-packet loops. They: + +1. Walk the graph in topological order. +2. Call `execute` or `async_execute` on each node with the correct inputs. +3. Collect results into `OrchestratorResult`. + +### Tightened async signatures per node type + +Instead of a uniform `async_execute(inputs: Sequence[ReadableChannel], output)` +for all nodes, each protocol has a signature matching its arity: + +- Source: `async_execute(output)` — no inputs +- Function: `async_execute(input_channel, output)` — single input +- Operator: `async_execute(inputs: Sequence[ReadableChannel], output)` — N inputs + +### Deferred: prefer_async and concurrency config + +Two features are deferred to follow-up issues: + +- **PLT-929**: `prefer_async` flag on `FunctionNode` — allows sync `execute()` + to internally use the async execution path when the pod/executor supports it. +- **PLT-930**: Move async concurrency config to node-level construction — + currently `FunctionNode.async_execute` receives `pipeline_config` for + `max_concurrency`. In this refactor, `pipeline_config` is removed from the + `async_execute` signature entirely. Nodes use their existing default + concurrency until PLT-930 adds proper node-level config. + +## Revised Node Protocols + +```python +@runtime_checkable +class SourceNodeProtocol(Protocol): + node_type: str # == "source" + + def execute( + self, *, observer: ExecutionObserver | None = None + ) -> list[tuple[TagProtocol, PacketProtocol]]: ... + + async def async_execute( + self, + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + *, + observer: ExecutionObserver | None = None, + ) -> None: ... + + +@runtime_checkable +class FunctionNodeProtocol(Protocol): + node_type: str # == "function" + + def execute( + self, + input_stream: StreamProtocol, + *, + observer: ExecutionObserver | None = None, + ) -> list[tuple[TagProtocol, PacketProtocol]]: ... + + async def async_execute( + self, + input_channel: ReadableChannel[tuple[TagProtocol, PacketProtocol]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + *, + observer: ExecutionObserver | None = None, + ) -> None: ... + + +@runtime_checkable +class OperatorNodeProtocol(Protocol): + node_type: str # == "operator" + + def execute( + self, + *input_streams: StreamProtocol, + observer: ExecutionObserver | None = None, + ) -> list[tuple[TagProtocol, PacketProtocol]]: ... + + async def async_execute( + self, + inputs: Sequence[ReadableChannel[tuple[TagProtocol, PacketProtocol]]], + output: WritableChannel[tuple[TagProtocol, PacketProtocol]], + *, + observer: ExecutionObserver | None = None, + ) -> None: ... +``` + +TypeGuard dispatch functions (`is_source_node`, `is_function_node`, +`is_operator_node`) remain unchanged. + +## Sync Orchestrator + +The orchestrator simplifies to a pure topology scheduler: + +```python +class SyncPipelineOrchestrator: + def __init__(self, observer=None): + self._observer = observer + + def run(self, graph, materialize_results=True) -> OrchestratorResult: + for node in topological_sort(graph): + if is_source_node(node): + buffers[node] = node.execute(observer=self._observer) + elif is_function_node(node): + stream = self._materialize_as_stream(buffers[pred], pred) + buffers[node] = node.execute(stream, observer=self._observer) + elif is_operator_node(node): + streams = [self._materialize_as_stream(buffers[p], p) + for p in sorted_preds] + buffers[node] = node.execute(*streams, observer=self._observer) + return OrchestratorResult(node_outputs=buffers) +``` + +`_materialize_as_stream` is retained — operators need `StreamProtocol` inputs. +`_gather_upstream`, `_gather_upstream_multi`, `_gc_buffers` helpers are retained. +The per-packet methods (`_execute_source`, `_execute_function`, +`_execute_operator`) are removed — each becomes a single `node.execute()` call. + +## Async Orchestrator + +The async orchestrator preserves the channel-based concurrent execution model +but uses TypeGuard dispatch with tightened per-type signatures. + +```python +class AsyncPipelineOrchestrator: + def __init__(self, observer=None, buffer_size=64): + self._observer = observer + self._buffer_size = buffer_size + + def run(self, graph, materialize_results=True) -> OrchestratorResult: + return asyncio.run(self._run_async(graph, materialize_results)) + + async def run_async(self, graph, materialize_results=True) -> OrchestratorResult: + return await self._run_async(graph, materialize_results) + + async def _run_async(self, graph, materialize_results) -> OrchestratorResult: + # Wire channels between nodes (same logic as today) + # For materialize_results=True: tee each output channel to collect items + + async with asyncio.TaskGroup() as tg: + for node in topo_order: + if is_source_node(node): + tg.create_task( + node.async_execute(writer, observer=self._observer) + ) + elif is_function_node(node): + tg.create_task( + node.async_execute( + input_reader, writer, observer=self._observer + ) + ) + elif is_operator_node(node): + tg.create_task( + node.async_execute( + input_readers, writer, observer=self._observer + ) + ) + + return OrchestratorResult(node_outputs=collected if materialize else {}) +``` + +Key changes from current implementation: + +- Takes `graph: nx.DiGraph` instead of `Pipeline` + `PipelineConfig`. +- Returns `OrchestratorResult` instead of `None`. +- TypeGuard dispatch with per-type signatures instead of uniform call. +- Observer injection via constructor + parameter forwarding. +- `buffer_size` is a constructor parameter (not from `PipelineConfig`). +- `materialize_results` controls whether intermediate outputs are collected. + +### Result collection + +When `materialize_results=True`, each node's output channel is tapped to collect +items into a list as they flow through. This uses a lightweight wrapper that +appends each item to a per-node list before forwarding to downstream readers. +When `materialize_results=False`, no collection occurs and `OrchestratorResult` +has empty `node_outputs`. Terminal sink channels are still drained regardless of +this flag, since nodes write to them unconditionally. + +### Channel wiring + +Channel wiring logic is preserved from the current implementation: + +- Single downstream: plain `Channel(buffer_size=self._buffer_size)`. +- Fan-out (multiple downstreams): `BroadcastChannel` with a reader per + downstream. +- Terminal nodes (no outgoing edges): sink `Channel` so `async_execute` has + somewhere to write. Drained after execution. + +### Existing node `run()` methods + +The existing `run()` method on `SourceNode`, `FunctionNode`, and `OperatorNode` +(the non-orchestrator pull-based execution path) is left intact. It serves a +different purpose — standalone node execution outside of pipeline orchestration. + +### Error handling + +If a node raises during `async_execute`, `asyncio.TaskGroup` cancels all +sibling tasks and propagates the exception. Observer hooks (`on_node_end`) are +not guaranteed to fire on failure — this matches the sync orchestrator's +behavior where an exception in `node.execute()` also skips `on_node_end`. + +## Node Internal Changes + +### SourceNode + +**New method:** `execute(*, observer=None) -> list[(tag, packet)]` + +- Calls `observer.on_node_start(self)` if observer provided. +- Materializes `self.iter_packets()` into a list. +- Populates `_cached_results` so subsequent `iter_packets()` calls return the + cached version. +- Calls `observer.on_node_end(self)`. +- Returns the list. + +**Signature change:** `async_execute(output, *, observer=None) -> None` + +- Tightened from `async_execute(inputs, output)` — no `inputs` parameter + (source has no upstream). +- Adds observer `on_node_start` / `on_node_end` hooks internally. + +**Removed from protocol:** `iter_packets()` — replaced by `execute()`. The +method remains on the class for internal use and backward compatibility, but it +is no longer part of `SourceNodeProtocol`. + +### FunctionNode + +**Signature change:** `execute(input_stream, *, observer=None) -> list[(tag, packet)]` + +- The existing `execute` method already takes `input_stream: StreamProtocol` and + returns `list[(tag, packet)]`. +- Adds `observer` parameter. Internally calls `on_node_start` / `on_node_end` + and per-packet `on_packet_start` / `on_packet_end(cached=...)` hooks. +- Internally uses `get_cached_results`, `compute_pipeline_entry_id`, and + `execute_packet` — these are implementation details, not protocol surface. + +**Signature change:** `async_execute(input_channel, output, *, observer=None) -> None` + +- Tightened from `async_execute(inputs, output, pipeline_config)` — single + `input_channel` instead of `Sequence[ReadableChannel]`. +- `pipeline_config` parameter removed entirely. Node uses its existing default + concurrency. Proper node-level concurrency config deferred to PLT-930. +- Adds observer hooks internally. + +**Internal execution logic in `execute()`:** The current +`SyncPipelineOrchestrator._execute_function` drives a per-packet loop with +cache lookup (`get_cached_results`) and observer hooks. This logic moves inside +`FunctionNode.execute()`: iterate over the input stream's packets, call +`compute_pipeline_entry_id` to check the pipeline DB, call `execute_packet` for +misses, and fire `on_packet_start` / `on_packet_end(cached=...)` around each +packet. The node's internal `CachedFunctionPod` handles function-level +memoization as before. + +**Removed from protocol (kept as class methods):** +`get_cached_results`, `execute_packet`, `compute_pipeline_entry_id`. + +### OperatorNode + +**Signature change:** `execute(*input_streams, observer=None) -> list[(tag, packet)]` + +- The existing `execute` method already takes `*input_streams: StreamProtocol` + and returns `list[(tag, packet)]`. +- Adds `observer` parameter. Internally calls `on_node_start` / `on_node_end`. +- Internally calls `get_cached_output()` first — if it returns a stream + (REPLAY mode), materializes it and returns without computing. Otherwise + delegates to the operator's `process()` and handles persistence. + +**Signature change:** `async_execute(inputs, output, *, observer=None) -> None` + +- Signature already takes `Sequence[ReadableChannel]` — no arity change. +- Adds observer hooks internally. + +**Removed from protocol (kept as class method):** `get_cached_output`. + +## Pipeline.run() Changes + +```python +def run(self, orchestrator=None, config=None, ...): + if not self._compiled: + self.compile() + if effective_engine is not None: + self._apply_execution_engine(effective_engine, effective_opts) + + if orchestrator is not None: + orchestrator.run(self._node_graph) + else: + use_async = ... # same logic as today + if use_async: + AsyncPipelineOrchestrator( + buffer_size=config.channel_buffer_size, + ).run(self._node_graph) + else: + SyncPipelineOrchestrator().run(self._node_graph) + + self.flush() +``` + +The `_run_async()` helper method is removed. Both orchestrators receive +`self._node_graph` directly. The default async path instantiates +`AsyncPipelineOrchestrator` inline and calls `.run(self._node_graph)`, matching +the sync path pattern. + +## File-level Change Summary + +| File | Changes | +|------|---------| +| `protocols/node_protocols.py` | Remove `get_cached_results`, `execute_packet`, `compute_pipeline_entry_id` from `FunctionNodeProtocol`. Remove `get_cached_output` from `OperatorNodeProtocol`. Remove `iter_packets` from `SourceNodeProtocol`. Add `execute` and `async_execute` with observer param to all three protocols. | +| `protocols/core_protocols/async_executable.py` | Delete entire file. | +| `protocols/core_protocols/__init__.py` | Remove `AsyncExecutableProtocol` re-export. | +| `core/nodes/source_node.py` | Add `execute(observer=None)` method. Change `async_execute` signature (remove `inputs` param, add `observer`). Add observer hooks. | +| `core/nodes/function_node.py` | Add `observer` param to `execute`. Change `async_execute` signature (single `input_channel`, remove `pipeline_config`, add `observer`). Move per-packet cache lookup + observer hook calls inside `execute()`. | +| `core/nodes/operator_node.py` | Add `observer` param to `execute`. Add `observer` param to `async_execute`. Move observer hook calls inside both methods. | +| `pipeline/sync_orchestrator.py` | Remove `_execute_source`, `_execute_function`, `_execute_operator`. Simplify `run()` to call `node.execute(...)` directly. Pass observer to nodes. | +| `pipeline/async_orchestrator.py` | Change `run` / `run_async` to take `graph` + `materialize_results`. Add TypeGuard dispatch. Tighten per-node `async_execute` calls. Add `buffer_size` constructor param. Add observer support. Return `OrchestratorResult`. | +| `pipeline/graph.py` | Remove `_run_async()`. Update default async path to instantiate `AsyncPipelineOrchestrator` and call `.run(self._node_graph)`. | +| `tests/test_pipeline/test_sync_orchestrator.py` | Update tests for new `node.execute()` path. Observer tests verify hooks fire from inside nodes. | +| `tests/test_pipeline/test_orchestrator.py` | Update async tests for new signature (`graph` instead of `Pipeline`). Add `materialize_results` tests. Add fan-out and terminal node tests. | + +## Removals + +**From node protocols:** + +- `SourceNodeProtocol.iter_packets()` +- `FunctionNodeProtocol.get_cached_results()` +- `FunctionNodeProtocol.compute_pipeline_entry_id()` +- `FunctionNodeProtocol.execute_packet()` +- `OperatorNodeProtocol.get_cached_output()` + +**From protocols package:** + +- `AsyncExecutableProtocol` (entire file `async_executable.py`) + +**From Pipeline:** + +- `_run_async()` helper method + +**From SyncPipelineOrchestrator:** + +- `_execute_source()`, `_execute_function()`, `_execute_operator()` methods + +## Testing Strategy + +- **Sync orchestrator tests**: Update existing tests to verify `node.execute()` + is called (not the removed per-packet orchestrator logic). Observer tests + verify hooks fire from inside nodes with the same events and order. +- **Async orchestrator tests**: Update for new signature (`graph` instead of + `Pipeline`). Verify `OrchestratorResult` is returned. +- **Sync/async parity tests**: Both orchestrators should produce identical DB + results. Existing parity tests updated for new signatures. +- **`materialize_results` tests**: Verify `True` collects all node outputs, + `False` returns empty `node_outputs` (both orchestrators). +- **Fan-out tests**: Verify `BroadcastChannel` wiring when one node fans out to + multiple downstreams (async orchestrator). +- **Terminal node tests**: Verify sink channels are created and drained for nodes + with no outgoing edges (async orchestrator). +- **Error propagation tests**: Verify that a node failure in `TaskGroup` + propagates correctly and doesn't hang. diff --git a/tests/test_data/__init__.py b/test-objective/__init__.py similarity index 100% rename from tests/test_data/__init__.py rename to test-objective/__init__.py diff --git a/test-objective/conftest.py b/test-objective/conftest.py new file mode 100644 index 00000000..b55669b6 --- /dev/null +++ b/test-objective/conftest.py @@ -0,0 +1,265 @@ +"""Shared fixtures and helpers for specification-derived objective tests. + +These tests are derived from design documents, protocol definitions, and +interface contracts — NOT from reading implementation code. +""" + +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.core.datagrams.datagram import Datagram +from orcapod.core.datagrams.tag_packet import Packet, Tag +from orcapod.core.function_pod import FunctionPod +from orcapod.core.nodes import FunctionNode +from orcapod.core.operators import ( + Batch, + DropPacketColumns, + DropTagColumns, + Join, + MapPackets, + MapTags, + MergeJoin, + PolarsFilter, + SelectPacketColumns, + SelectTagColumns, + SemiJoin, +) +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.core.sources import ArrowTableSource, DictSource, ListSource +from orcapod.core.streams import ArrowTableStream +from orcapod.databases import InMemoryArrowDatabase, NoOpArrowDatabase +from orcapod.types import ColumnConfig, ContentHash, Schema + + +# --------------------------------------------------------------------------- +# Helper functions for packet functions +# --------------------------------------------------------------------------- + + +def double_value(x: int) -> int: + """Double an integer value.""" + return x * 2 + + +def add_values(x: int, y: int) -> int: + """Add two integer values.""" + return x + y + + +def to_uppercase(name: str) -> str: + """Convert a string to uppercase.""" + return name.upper() + + +def negate(x: int) -> int: + """Negate an integer.""" + return -x + + +def square(x: int) -> int: + """Square an integer.""" + return x * x + + +def concat_fields(first: str, last: str) -> str: + """Concatenate two strings with a space.""" + return f"{first} {last}" + + +def return_none(x: int) -> int | None: + """Always returns None (for testing None propagation).""" + return None + + +# --------------------------------------------------------------------------- +# Arrow table factories +# --------------------------------------------------------------------------- + + +def make_simple_table(n: int = 3) -> pa.Table: + """Table with tag=id (int), packet=value (int).""" + return pa.table( + { + "id": pa.array(list(range(n)), type=pa.int64()), + "value": pa.array([i * 10 for i in range(n)], type=pa.int64()), + } + ) + + +def make_two_packet_col_table(n: int = 3) -> pa.Table: + """Table with tag=id, packet={x, y}.""" + return pa.table( + { + "id": pa.array(list(range(n)), type=pa.int64()), + "x": pa.array(list(range(n)), type=pa.int64()), + "y": pa.array([i * 10 for i in range(n)], type=pa.int64()), + } + ) + + +def make_string_table(n: int = 3) -> pa.Table: + """Table with tag=id, packet=name (str).""" + names = ["alice", "bob", "charlie"][:n] + return pa.table( + { + "id": pa.array(list(range(n)), type=pa.int64()), + "name": pa.array(names, type=pa.large_string()), + } + ) + + +def make_joinable_tables() -> tuple[pa.Table, pa.Table]: + """Two tables with shared tag=id, non-overlapping packet columns.""" + left = pa.table( + { + "id": pa.array([1, 2, 3], type=pa.int64()), + "age": pa.array([25, 30, 35], type=pa.int64()), + } + ) + right = pa.table( + { + "id": pa.array([2, 3, 4], type=pa.int64()), + "score": pa.array([85, 90, 95], type=pa.int64()), + } + ) + return left, right + + +def make_overlapping_packet_tables() -> tuple[pa.Table, pa.Table]: + """Two tables with shared tag=id AND overlapping packet column 'value'.""" + left = pa.table( + { + "id": pa.array([1, 2, 3], type=pa.int64()), + "value": pa.array([10, 20, 30], type=pa.int64()), + } + ) + right = pa.table( + { + "id": pa.array([2, 3, 4], type=pa.int64()), + "value": pa.array([200, 300, 400], type=pa.int64()), + } + ) + return left, right + + +# --------------------------------------------------------------------------- +# Fixtures: Arrow tables +# --------------------------------------------------------------------------- + + +@pytest.fixture +def simple_table() -> pa.Table: + return make_simple_table() + + +@pytest.fixture +def two_col_table() -> pa.Table: + return make_two_packet_col_table() + + +@pytest.fixture +def string_table() -> pa.Table: + return make_string_table() + + +# --------------------------------------------------------------------------- +# Fixtures: Streams +# --------------------------------------------------------------------------- + + +@pytest.fixture +def simple_stream() -> ArrowTableStream: + """Stream with tag=id, packet=value.""" + return ArrowTableStream(make_simple_table(), tag_columns=["id"]) + + +@pytest.fixture +def two_col_stream() -> ArrowTableStream: + """Stream with tag=id, packet={x, y}.""" + return ArrowTableStream(make_two_packet_col_table(), tag_columns=["id"]) + + +@pytest.fixture +def string_stream() -> ArrowTableStream: + """Stream with tag=id, packet=name.""" + return ArrowTableStream(make_string_table(), tag_columns=["id"]) + + +@pytest.fixture +def joinable_streams() -> tuple[ArrowTableStream, ArrowTableStream]: + """Two streams with shared tag=id, non-overlapping packet columns.""" + left, right = make_joinable_tables() + return ( + ArrowTableStream(left, tag_columns=["id"]), + ArrowTableStream(right, tag_columns=["id"]), + ) + + +# --------------------------------------------------------------------------- +# Fixtures: Sources +# --------------------------------------------------------------------------- + + +@pytest.fixture +def simple_source() -> ArrowTableSource: + return ArrowTableSource(make_simple_table(), tag_columns=["id"]) + + +@pytest.fixture +def dict_source() -> DictSource: + return DictSource( + {"id": [1, 2, 3], "value": [10, 20, 30]}, + tag_columns=["id"], + ) + + +# --------------------------------------------------------------------------- +# Fixtures: Packet functions +# --------------------------------------------------------------------------- + + +@pytest.fixture +def double_pf() -> PythonPacketFunction: + return PythonPacketFunction(double_value, output_keys="result") + + +@pytest.fixture +def add_pf() -> PythonPacketFunction: + return PythonPacketFunction(add_values, output_keys="result") + + +@pytest.fixture +def uppercase_pf() -> PythonPacketFunction: + return PythonPacketFunction(to_uppercase, output_keys="result") + + +# --------------------------------------------------------------------------- +# Fixtures: Pods +# --------------------------------------------------------------------------- + + +@pytest.fixture +def double_pod(double_pf) -> FunctionPod: + return FunctionPod(packet_function=double_pf) + + +@pytest.fixture +def add_pod(add_pf) -> FunctionPod: + return FunctionPod(packet_function=add_pf) + + +# --------------------------------------------------------------------------- +# Fixtures: Databases +# --------------------------------------------------------------------------- + + +@pytest.fixture +def inmemory_db() -> InMemoryArrowDatabase: + return InMemoryArrowDatabase() + + +@pytest.fixture +def noop_db() -> NoOpArrowDatabase: + return NoOpArrowDatabase() diff --git a/test-objective/integration/__init__.py b/test-objective/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test-objective/integration/test_caching_flows.py b/test-objective/integration/test_caching_flows.py new file mode 100644 index 00000000..8418dd74 --- /dev/null +++ b/test-objective/integration/test_caching_flows.py @@ -0,0 +1,237 @@ +"""Specification-derived integration tests for DB-backed caching flows. + +Tests FunctionNode and OperatorNode caching behavior +as documented in the design specification. +""" + +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.core.function_pod import FunctionPod +from orcapod.core.nodes import ( + FunctionNode, + OperatorNode, +) +from orcapod.core.operators import Join +from orcapod.core.packet_function import CachedPacketFunction, PythonPacketFunction +from orcapod.core.sources import ArrowTableSource, DerivedSource +from orcapod.core.streams import ArrowTableStream +from orcapod.databases import InMemoryArrowDatabase +from orcapod.types import CacheMode + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _double(x: int) -> int: + return x * 2 + + +def _make_source(n: int = 3) -> ArrowTableSource: + table = pa.table( + { + "id": pa.array(list(range(n)), type=pa.int64()), + "x": pa.array(list(range(n)), type=pa.int64()), + } + ) + return ArrowTableSource(table, tag_columns=["id"]) + + +# =================================================================== +# FunctionNode caching +# =================================================================== + + +class TestFunctionNodeCaching: + """Per design: first run computes and stores; second run replays cached.""" + + def test_first_run_computes_all(self): + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + source = _make_source(3) + pipeline_db = InMemoryArrowDatabase() + result_db = InMemoryArrowDatabase() + node = FunctionNode( + function_pod=pod, + input_stream=source, + pipeline_database=pipeline_db, + result_database=result_db, + ) + node.run() + records = node.get_all_records() + assert records is not None + assert records.num_rows == 3 + + def test_second_run_uses_cache(self): + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + source = _make_source(3) + pipeline_db = InMemoryArrowDatabase() + result_db = InMemoryArrowDatabase() + + # First run + node1 = FunctionNode( + function_pod=pod, + input_stream=source, + pipeline_database=pipeline_db, + result_database=result_db, + ) + node1.run() + + # Second run with same inputs — should use cached results + node2 = FunctionNode( + function_pod=pod, + input_stream=source, + pipeline_database=pipeline_db, + result_database=result_db, + ) + packets = list(node2.iter_packets()) + assert len(packets) == 3 + + +class TestDerivedSourceReingestion: + """Per design: FunctionNode → DerivedSource → further pipeline.""" + + def test_derived_source_as_pipeline_input(self): + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + source = _make_source(3) + pipeline_db = InMemoryArrowDatabase() + result_db = InMemoryArrowDatabase() + + node = FunctionNode( + function_pod=pod, + input_stream=source, + pipeline_database=pipeline_db, + result_database=result_db, + ) + node.run() + + # Create DerivedSource from the node's results + derived = node.as_source() + assert isinstance(derived, DerivedSource) + + # Should be able to iterate packets from derived source + packets = list(derived.iter_packets()) + assert len(packets) == 3 + + +# =================================================================== +# OperatorNode caching +# =================================================================== + + +class TestOperatorNodeCaching: + """Per design: CacheMode.LOG stores results, REPLAY loads from DB.""" + + def test_log_mode_stores_results(self): + source_a = ArrowTableSource( + pa.table( + { + "id": pa.array([1, 2, 3], type=pa.int64()), + "age": pa.array([25, 30, 35], type=pa.int64()), + } + ), + tag_columns=["id"], + ) + source_b = ArrowTableSource( + pa.table( + { + "id": pa.array([2, 3, 4], type=pa.int64()), + "score": pa.array([85, 90, 95], type=pa.int64()), + } + ), + tag_columns=["id"], + ) + join = Join() + db = InMemoryArrowDatabase() + node = OperatorNode( + operator=join, + input_streams=[source_a, source_b], + pipeline_database=db, + cache_mode=CacheMode.LOG, + ) + node.run() + records = node.get_all_records() + assert records is not None + assert records.num_rows == 2 + + def test_replay_mode_loads_from_db(self): + source_a = ArrowTableSource( + pa.table( + { + "id": pa.array([1, 2, 3], type=pa.int64()), + "age": pa.array([25, 30, 35], type=pa.int64()), + } + ), + tag_columns=["id"], + ) + source_b = ArrowTableSource( + pa.table( + { + "id": pa.array([2, 3, 4], type=pa.int64()), + "score": pa.array([85, 90, 95], type=pa.int64()), + } + ), + tag_columns=["id"], + ) + join = Join() + db = InMemoryArrowDatabase() + + # First: LOG + node1 = OperatorNode( + operator=join, + input_streams=[source_a, source_b], + pipeline_database=db, + cache_mode=CacheMode.LOG, + ) + node1.run() + + # Second: REPLAY + node2 = OperatorNode( + operator=join, + input_streams=[source_a, source_b], + pipeline_database=db, + cache_mode=CacheMode.REPLAY, + ) + node2.run() + table = node2.as_table() + assert table.num_rows == 2 + + +# =================================================================== +# CachedPacketFunction end-to-end +# =================================================================== + + +class TestCachedPacketFunctionEndToEnd: + """End-to-end test of CachedPacketFunction with InMemoryArrowDatabase.""" + + def test_full_caching_flow(self): + db = InMemoryArrowDatabase() + inner_pf = PythonPacketFunction(_double, output_keys="result") + cached_pf = CachedPacketFunction(inner_pf, result_database=db) + cached_pf.set_auto_flush(True) + + from orcapod.core.datagrams.tag_packet import Packet + + # Process multiple packets + for x in [1, 2, 3]: + result = cached_pf.call(Packet({"x": x})) + assert result is not None + assert result["result"] == x * 2 + + # All should be cached + all_outputs = cached_pf.get_all_cached_outputs() + assert all_outputs is not None + assert all_outputs.num_rows == 3 + + # Re-calling should use cache + for x in [1, 2, 3]: + result = cached_pf.call(Packet({"x": x})) + assert result is not None + assert result["result"] == x * 2 diff --git a/test-objective/integration/test_column_config_filtering.py b/test-objective/integration/test_column_config_filtering.py new file mode 100644 index 00000000..722d4165 --- /dev/null +++ b/test-objective/integration/test_column_config_filtering.py @@ -0,0 +1,198 @@ +"""Specification-derived integration tests for ColumnConfig filtering across components. + +Tests that ColumnConfig consistently controls column visibility across +Datagram, Tag, Packet, Stream, and Source components. +""" + +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.core.datagrams.datagram import Datagram +from orcapod.core.datagrams.tag_packet import Packet, Tag +from orcapod.core.sources import ArrowTableSource +from orcapod.core.streams import ArrowTableStream +from orcapod.system_constants import constants +from orcapod.types import ColumnConfig + +# Use the actual system tag prefix from constants +_SYS_TAG_KEY = f"{constants.SYSTEM_TAG_PREFIX}source:abc" + + +# =================================================================== +# Datagram ColumnConfig +# =================================================================== + + +class TestDatagramColumnConfig: + """Per design, ColumnConfig controls which column groups are visible.""" + + def test_data_only_excludes_meta(self): + d = Datagram( + {"name": "alice", "age": 30}, + meta_info={"pipeline": "test"}, + ) + keys = d.keys() + assert "name" in keys + assert "age" in keys + # Meta columns should not be visible + for k in keys: + assert not k.startswith(constants.META_PREFIX) + + def test_meta_true_includes_meta(self): + d = Datagram( + {"name": "alice"}, + meta_info={"pipeline": "test"}, + ) + keys_default = d.keys() + keys_with_meta = d.keys(columns=ColumnConfig(meta=True)) + # With meta=True, there should be more keys than default + assert len(keys_with_meta) > len(keys_default) + assert "pipeline" in keys_with_meta + + def test_all_info_includes_everything(self): + d = Datagram( + {"name": "alice"}, + meta_info={"pipeline": "test"}, + ) + keys_all = d.keys(all_info=True) + keys_default = d.keys() + assert len(keys_all) >= len(keys_default) + + +# =================================================================== +# Tag ColumnConfig +# =================================================================== + + +class TestTagColumnConfig: + """Per design, system_tags=True includes _tag_ columns in Tag.""" + + def test_system_tags_excluded_by_default(self): + t = Tag( + {"id": 1}, + system_tags={_SYS_TAG_KEY: "rec1"}, + ) + keys = t.keys() + assert _SYS_TAG_KEY not in keys + + def test_system_tags_included_with_config(self): + t = Tag( + {"id": 1}, + system_tags={_SYS_TAG_KEY: "rec1"}, + ) + keys_default = t.keys() + keys_with_tags = t.keys(columns=ColumnConfig(system_tags=True)) + assert len(keys_with_tags) > len(keys_default) + assert _SYS_TAG_KEY in keys_with_tags + + def test_all_info_includes_system_tags(self): + t = Tag( + {"id": 1}, + system_tags={_SYS_TAG_KEY: "rec1"}, + ) + keys = t.keys(all_info=True) + assert _SYS_TAG_KEY in keys + + +# =================================================================== +# Packet ColumnConfig +# =================================================================== + + +class TestPacketColumnConfig: + """Per design, source=True includes _source_ columns in Packet.""" + + def test_source_excluded_by_default(self): + p = Packet( + {"value": 42}, + source_info={"value": "src1:rec1"}, + ) + keys = p.keys() + for k in keys: + assert not k.startswith(constants.SOURCE_PREFIX) + + def test_source_included_with_config(self): + p = Packet( + {"value": 42}, + source_info={"value": "src1:rec1"}, + ) + keys = p.keys(columns=ColumnConfig(source=True)) + source_keys = [k for k in keys if k.startswith(constants.SOURCE_PREFIX)] + assert len(source_keys) > 0 + + def test_all_info_includes_source(self): + p = Packet( + {"value": 42}, + source_info={"value": "src1:rec1"}, + ) + keys = p.keys(all_info=True) + source_keys = [k for k in keys if k.startswith(constants.SOURCE_PREFIX)] + assert len(source_keys) > 0 + + +# =================================================================== +# Stream ColumnConfig consistency +# =================================================================== + + +class TestStreamColumnConfigConsistency: + """Per design, keys(), output_schema(), and as_table() should all + respect the same ColumnConfig consistently.""" + + def test_keys_schema_table_consistency_default(self): + source = ArrowTableSource( + pa.table( + { + "id": pa.array([1, 2], type=pa.int64()), + "value": pa.array([10, 20], type=pa.int64()), + } + ), + tag_columns=["id"], + ) + tag_keys, packet_keys = source.keys() + tag_schema, packet_schema = source.output_schema() + table = source.as_table() + + # keys and schema should have same field names + assert set(tag_keys) == set(tag_schema.keys()) + assert set(packet_keys) == set(packet_schema.keys()) + + # Table should have all key columns + all_keys = set(tag_keys) | set(packet_keys) + assert all_keys.issubset(set(table.column_names)) + + def test_keys_schema_table_consistency_all_info(self): + source = ArrowTableSource( + pa.table( + { + "id": pa.array([1, 2], type=pa.int64()), + "value": pa.array([10, 20], type=pa.int64()), + } + ), + tag_columns=["id"], + ) + tag_keys, packet_keys = source.keys(all_info=True) + tag_schema, packet_schema = source.output_schema(all_info=True) + table = source.as_table(all_info=True) + + assert set(tag_keys) == set(tag_schema.keys()) + assert set(packet_keys) == set(packet_schema.keys()) + + all_keys = set(tag_keys) | set(packet_keys) + assert all_keys.issubset(set(table.column_names)) + + def test_all_info_has_more_columns_than_default(self): + source = ArrowTableSource( + pa.table( + { + "id": pa.array([1, 2], type=pa.int64()), + "value": pa.array([10, 20], type=pa.int64()), + } + ), + tag_columns=["id"], + ) + default_table = source.as_table() + all_info_table = source.as_table(all_info=True) + assert all_info_table.num_columns >= default_table.num_columns diff --git a/test-objective/integration/test_hash_invariants.py b/test-objective/integration/test_hash_invariants.py new file mode 100644 index 00000000..1e3d2819 --- /dev/null +++ b/test-objective/integration/test_hash_invariants.py @@ -0,0 +1,169 @@ +"""Specification-derived integration tests for hash stability and Merkle chain properties. + +Tests the two parallel identity chains documented in the design spec: +1. content_hash() — data-inclusive, changes when data changes +2. pipeline_hash() — schema+topology only, ignores data content +""" + +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.core.function_pod import FunctionPod +from orcapod.core.operators import Join, SemiJoin +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.core.sources import ArrowTableSource +from orcapod.core.streams import ArrowTableStream + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _double(x: int) -> int: + return x * 2 + + +def _make_source(data: dict, tag_columns: list[str]) -> ArrowTableSource: + table = pa.table(data) + return ArrowTableSource(table, tag_columns=tag_columns) + + +# =================================================================== +# Content hash stability +# =================================================================== + + +class TestContentHashStability: + """Per design: content_hash is deterministic — identical data produces + identical hash across runs.""" + + def test_same_data_same_hash(self): + s1 = ArrowTableStream( + pa.table({"id": [1, 2], "x": [10, 20]}), tag_columns=["id"] + ) + s2 = ArrowTableStream( + pa.table({"id": [1, 2], "x": [10, 20]}), tag_columns=["id"] + ) + assert s1.content_hash() == s2.content_hash() + + def test_different_data_different_hash(self): + s1 = ArrowTableStream( + pa.table({"id": [1, 2], "x": [10, 20]}), tag_columns=["id"] + ) + s2 = ArrowTableStream( + pa.table({"id": [1, 2], "x": [10, 99]}), tag_columns=["id"] + ) + assert s1.content_hash() != s2.content_hash() + + +# =================================================================== +# Pipeline hash properties +# =================================================================== + + +class TestPipelineHashProperties: + """Per design: pipeline_hash is schema+topology only, ignoring data content.""" + + def test_same_schema_different_data_same_pipeline_hash(self): + """Same schema, different data → same pipeline_hash.""" + s1 = _make_source( + {"id": pa.array([1, 2], type=pa.int64()), "x": pa.array([10, 20], type=pa.int64())}, + ["id"], + ) + s2 = _make_source( + {"id": pa.array([3, 4], type=pa.int64()), "x": pa.array([30, 40], type=pa.int64())}, + ["id"], + ) + assert s1.pipeline_hash() == s2.pipeline_hash() + + def test_different_schema_different_pipeline_hash(self): + """Different schema → different pipeline_hash.""" + s1 = _make_source( + {"id": pa.array([1], type=pa.int64()), "x": pa.array([10], type=pa.int64())}, + ["id"], + ) + s2 = _make_source( + {"id": pa.array([1], type=pa.int64()), "y": pa.array(["a"], type=pa.large_string())}, + ["id"], + ) + assert s1.pipeline_hash() != s2.pipeline_hash() + + +# =================================================================== +# Merkle chain properties +# =================================================================== + + +class TestMerkleChain: + """Per design: each downstream node's pipeline hash commits to its own + identity plus the pipeline hashes of its upstreams.""" + + def test_downstream_hash_depends_on_upstream(self): + """Different upstream sources with different schemas produce different + downstream pipeline hashes even with the same operator/pod.""" + source_a = _make_source( + {"id": pa.array([1, 2], type=pa.int64()), "x": pa.array([10, 20], type=pa.int64())}, + ["id"], + ) + # Different schema: tag=category instead of tag=id + source_b = _make_source( + {"category": pa.array([1, 2], type=pa.int64()), "x": pa.array([10, 20], type=pa.int64())}, + ["category"], + ) + + pf_a = PythonPacketFunction(_double, output_keys="result") + pod_a = FunctionPod(packet_function=pf_a) + pf_b = PythonPacketFunction(_double, output_keys="result") + pod_b = FunctionPod(packet_function=pf_b) + + stream_a = pod_a.process(source_a) + stream_b = pod_b.process(source_b) + + # Different upstream schemas → different downstream pipeline hashes + assert stream_a.pipeline_hash() != stream_b.pipeline_hash() + + +# =================================================================== +# Commutativity of join pipeline hash +# =================================================================== + + +class TestJoinPipelineHashCommutativity: + """Per design: commutative operators produce the same pipeline_hash + regardless of input order.""" + + def test_commutative_join_order_independent(self): + sa = _make_source( + {"id": pa.array([1, 2], type=pa.int64()), "a": pa.array([10, 20], type=pa.int64())}, + ["id"], + ) + sb = _make_source( + {"id": pa.array([1, 2], type=pa.int64()), "b": pa.array([30, 40], type=pa.int64())}, + ["id"], + ) + + join = Join() + result_ab = join.process(sa, sb) + result_ba = join.process(sb, sa) + + assert result_ab.pipeline_hash() == result_ba.pipeline_hash() + + def test_non_commutative_semijoin_order_dependent(self): + sa = _make_source( + {"id": pa.array([1, 2], type=pa.int64()), "a": pa.array([10, 20], type=pa.int64())}, + ["id"], + ) + sb = _make_source( + {"id": pa.array([1, 2], type=pa.int64()), "b": pa.array([30, 40], type=pa.int64())}, + ["id"], + ) + + semi = SemiJoin() + result_ab = semi.process(sa, sb) + result_ba = semi.process(sb, sa) + + # SemiJoin is non-commutative, so pipeline hashes should differ + assert result_ab.pipeline_hash() != result_ba.pipeline_hash() diff --git a/test-objective/integration/test_pipeline_flows.py b/test-objective/integration/test_pipeline_flows.py new file mode 100644 index 00000000..6a67c466 --- /dev/null +++ b/test-objective/integration/test_pipeline_flows.py @@ -0,0 +1,301 @@ +"""Specification-derived integration tests for end-to-end pipeline flows. + +Tests complete pipeline scenarios as described in the design specification: +Source → Stream → [Operator / FunctionPod] → Stream → ... +""" + +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.core.function_pod import FunctionPod +from orcapod.core.operators import ( + Batch, + DropPacketColumns, + Join, + MapTags, + MergeJoin, + PolarsFilter, + SelectPacketColumns, + SemiJoin, +) +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.core.sources import ArrowTableSource, DictSource +from orcapod.core.streams import ArrowTableStream + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _double(x: int) -> int: + return x * 2 + + +def _negate(x: int) -> int: + return -x + + +def _square(x: int) -> int: + return x * x + + +def _square_doubled(doubled: int) -> int: + return doubled * doubled + + +def _make_source(tag_data: dict, packet_data: dict, tag_columns: list[str]): + all_data = {**tag_data, **packet_data} + table = pa.table(all_data) + return ArrowTableSource(table, tag_columns=tag_columns) + + +# =================================================================== +# Single operator pipelines +# =================================================================== + + +class TestSourceToFilter: + """Source → PolarsFilter → Stream.""" + + def test_filter_reduces_rows(self): + source = _make_source( + {"id": pa.array([1, 2, 3, 4, 5], type=pa.int64())}, + {"value": pa.array([10, 20, 30, 40, 50], type=pa.int64())}, + ["id"], + ) + filt = PolarsFilter(constraints={"id": 3}) + result = filt.process(source) + table = result.as_table() + assert table.num_rows == 1 + assert table.column("id").to_pylist() == [3] + + +class TestSourceToFunctionPod: + """Source → FunctionPod → Stream with transformed packets.""" + + def test_function_pod_transforms_all_packets(self): + source = _make_source( + {"id": pa.array([0, 1, 2], type=pa.int64())}, + {"x": pa.array([10, 20, 30], type=pa.int64())}, + ["id"], + ) + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + result = pod.process(source) + packets = list(result.iter_packets()) + assert len(packets) == 3 + results = [p["result"] for _, p in packets] + assert sorted(results) == [20, 40, 60] + + +class TestMultiSourceJoin: + """Two sources → Join → Stream with combined data.""" + + def test_join_combines_matching_rows(self): + source_a = _make_source( + {"id": pa.array([1, 2, 3], type=pa.int64())}, + {"name": pa.array(["alice", "bob", "charlie"], type=pa.large_string())}, + ["id"], + ) + source_b = _make_source( + {"id": pa.array([2, 3, 4], type=pa.int64())}, + {"score": pa.array([85, 90, 95], type=pa.int64())}, + ["id"], + ) + join = Join() + result = join.process(source_a, source_b) + table = result.as_table() + assert table.num_rows == 2 # id=2, id=3 + assert "name" in table.column_names + assert "score" in table.column_names + + +# =================================================================== +# Chained operator pipelines +# =================================================================== + + +class TestChainedOperators: + """Source → Filter → Select → MapTags → Stream.""" + + def test_chain_of_three_operators(self): + source = _make_source( + { + "id": pa.array([1, 2, 3, 4, 5], type=pa.int64()), + "group": pa.array(["a", "b", "a", "b", "a"], type=pa.large_string()), + }, + {"value": pa.array([10, 20, 30, 40, 50], type=pa.int64())}, + ["id", "group"], + ) + # Step 1: Filter to group="a" + filt = PolarsFilter(constraints={"group": "a"}) + filtered = filt.process(source) + + # Step 2: Select only relevant packet columns + select = SelectPacketColumns(columns=["value"]) + selected = select.process(filtered) + + # Step 3: Rename tag + mapper = MapTags(name_map={"id": "item_id"}) + result = mapper.process(selected) + + table = result.as_table() + assert table.num_rows == 3 # group="a" has 3 rows + assert "item_id" in table.column_names + assert "id" not in table.column_names + + +class TestFunctionPodThenOperator: + """Source → FunctionPod → PolarsFilter → Stream.""" + + def test_transform_then_filter(self): + source = _make_source( + {"id": pa.array([0, 1, 2, 3, 4], type=pa.int64())}, + {"x": pa.array([1, 2, 3, 4, 5], type=pa.int64())}, + ["id"], + ) + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + transformed = pod.process(source) + + # Filter to only results >= 6 (i.e., x >= 3 → result >= 6) + # We can filter on tag id >= 3 + filt = PolarsFilter(constraints={"id": 3}) + result = filt.process(transformed) + table = result.as_table() + assert table.num_rows == 1 + + +class TestJoinThenBatch: + """Two sources → Join → Batch → Stream.""" + + def test_join_then_batch(self): + source_a = _make_source( + {"group": pa.array(["x", "x", "y"], type=pa.large_string())}, + {"a": pa.array([1, 2, 3], type=pa.int64())}, + ["group"], + ) + source_b = _make_source( + {"group": pa.array(["x", "x", "y"], type=pa.large_string())}, + {"b": pa.array([10, 20, 30], type=pa.int64())}, + ["group"], + ) + join = Join() + joined = join.process(source_a, source_b) + + batch = Batch() + result = batch.process(joined) + table = result.as_table() + # After join and batch, rows should be grouped by tag + assert table.num_rows >= 1 + + +class TestSemiJoinFilters: + """Source A semi-joined with Source B.""" + + def test_semijoin_keeps_matching_left(self): + source_a = _make_source( + {"id": pa.array([1, 2, 3, 4, 5], type=pa.int64())}, + {"value": pa.array([10, 20, 30, 40, 50], type=pa.int64())}, + ["id"], + ) + source_b = _make_source( + {"id": pa.array([2, 4], type=pa.int64())}, + {"dummy": pa.array([0, 0], type=pa.int64())}, + ["id"], + ) + semi = SemiJoin() + result = semi.process(source_a, source_b) + table = result.as_table() + assert table.num_rows == 2 + assert sorted(table.column("id").to_pylist()) == [2, 4] + + +class TestMergeJoinCombines: + """Two sources with overlapping columns → MergeJoin.""" + + def test_merge_join_merges_columns(self): + source_a = _make_source( + {"id": pa.array([1, 2], type=pa.int64())}, + {"score": pa.array([80, 90], type=pa.int64())}, + ["id"], + ) + source_b = _make_source( + {"id": pa.array([1, 2], type=pa.int64())}, + {"score": pa.array([85, 95], type=pa.int64())}, + ["id"], + ) + merge = MergeJoin() + result = merge.process(source_a, source_b) + table = result.as_table() + assert table.num_rows == 2 + # score should now be list type + score_type = table.schema.field("score").type + assert pa.types.is_list(score_type) or pa.types.is_large_list(score_type) + + +# =================================================================== +# Diamond pipeline +# =================================================================== + + +class TestDiamondPipeline: + """Source → [branch A, branch B] → Join → Stream.""" + + def test_diamond_topology(self): + source = _make_source( + {"id": pa.array([1, 2, 3], type=pa.int64())}, + {"x": pa.array([10, 20, 30], type=pa.int64())}, + ["id"], + ) + # Branch A: double x + pf_a = PythonPacketFunction(_double, output_keys="doubled") + pod_a = FunctionPod(packet_function=pf_a) + branch_a = pod_a.process(source) + + # Branch B: negate x + pf_b = PythonPacketFunction(_negate, output_keys="negated") + pod_b = FunctionPod(packet_function=pf_b) + branch_b = pod_b.process(source) + + # Join branches + join = Join() + result = join.process(branch_a, branch_b) + table = result.as_table() + assert table.num_rows == 3 + assert "doubled" in table.column_names + assert "negated" in table.column_names + + +# =================================================================== +# Multiple function pods chained +# =================================================================== + + +class TestChainedFunctionPods: + """Source → FunctionPod1 → FunctionPod2 → Stream.""" + + def test_two_sequential_transformations(self): + source = _make_source( + {"id": pa.array([1, 2, 3], type=pa.int64())}, + {"x": pa.array([2, 3, 4], type=pa.int64())}, + ["id"], + ) + # First: double + pf1 = PythonPacketFunction(_double, output_keys="doubled") + pod1 = FunctionPod(packet_function=pf1) + step1 = pod1.process(source) + + # Second: square the doubled value + pf2 = PythonPacketFunction(_square_doubled, output_keys="squared") + pod2 = FunctionPod(packet_function=pf2) + step2 = pod2.process(step1) + + packets = list(step2.iter_packets()) + assert len(packets) == 3 + # x=2 → doubled=4 → squared=16 + results = sorted([p["squared"] for _, p in packets]) + assert results == [16, 36, 64] diff --git a/test-objective/integration/test_provenance.py b/test-objective/integration/test_provenance.py new file mode 100644 index 00000000..f73a9930 --- /dev/null +++ b/test-objective/integration/test_provenance.py @@ -0,0 +1,250 @@ +"""Specification-derived integration tests for system tag lineage tracking. + +Tests the three system tag evolution rules from the design specification: +1. Name-preserving — single-stream ops (filter, select, map) +2. Name-extending — multi-input ops (join, merge join) +3. Type-evolving — aggregation ops (batch) +""" + +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.core.operators import Batch, Join, MapTags, PolarsFilter, SelectPacketColumns +from orcapod.core.sources import ArrowTableSource +from orcapod.core.streams import ArrowTableStream +from orcapod.system_constants import constants +from orcapod.types import ColumnConfig + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_source(tag_data: dict, packet_data: dict, tag_columns: list[str]): + all_data = {**tag_data, **packet_data} + table = pa.table(all_data) + return ArrowTableSource(table, tag_columns=tag_columns) + + +def _get_system_tag_columns(table: pa.Table) -> list[str]: + return [c for c in table.column_names if c.startswith(constants.SYSTEM_TAG_PREFIX)] + + +# =================================================================== +# Source creates system tag column +# =================================================================== + + +class TestSourceSystemTags: + """Per design: each source adds a system tag column encoding provenance.""" + + def test_source_creates_system_tag_column(self): + source = _make_source( + {"id": pa.array([1, 2], type=pa.int64())}, + {"value": pa.array([10, 20], type=pa.int64())}, + ["id"], + ) + table = source.as_table(all_info=True) + tag_cols = _get_system_tag_columns(table) + assert len(tag_cols) >= 1, "Source should add at least one system tag column" + + +# =================================================================== +# Name-preserving (single-stream ops) +# =================================================================== + + +class TestNamePreserving: + """Per design: single-stream ops preserve system tag column names and values.""" + + def test_filter_preserves_system_tags(self): + source = _make_source( + {"id": pa.array([1, 2, 3], type=pa.int64())}, + {"value": pa.array([10, 20, 30], type=pa.int64())}, + ["id"], + ) + source_table = source.as_table(all_info=True) + source_tag_cols = _get_system_tag_columns(source_table) + + filt = PolarsFilter(constraints={"id": 2}) + result = filt.process(source) + result_table = result.as_table(all_info=True) + result_tag_cols = _get_system_tag_columns(result_table) + + # Column names should be identical + assert set(source_tag_cols) == set(result_tag_cols) + + def test_select_preserves_system_tags(self): + source = _make_source( + {"id": pa.array([1, 2], type=pa.int64())}, + {"a": pa.array([10, 20], type=pa.int64()), "b": pa.array([30, 40], type=pa.int64())}, + ["id"], + ) + source_table = source.as_table(all_info=True) + source_tag_cols = _get_system_tag_columns(source_table) + + select = SelectPacketColumns(columns=["a"]) + result = select.process(source) + result_table = result.as_table(all_info=True) + result_tag_cols = _get_system_tag_columns(result_table) + + assert set(source_tag_cols) == set(result_tag_cols) + + def test_map_preserves_system_tags(self): + source = _make_source( + { + "id": pa.array([1, 2], type=pa.int64()), + "group": pa.array(["a", "b"], type=pa.large_string()), + }, + {"value": pa.array([10, 20], type=pa.int64())}, + ["id", "group"], + ) + source_table = source.as_table(all_info=True) + source_tag_cols = _get_system_tag_columns(source_table) + + mapper = MapTags(name_map={"id": "item_id"}) + result = mapper.process(source) + result_table = result.as_table(all_info=True) + result_tag_cols = _get_system_tag_columns(result_table) + + assert set(source_tag_cols) == set(result_tag_cols) + + +# =================================================================== +# Name-extending (multi-input ops) +# =================================================================== + + +class TestNameExtending: + """Per design: multi-input ops extend system tag column names with + ::pipeline_hash:canonical_position.""" + + def test_join_extends_system_tag_names(self): + source_a = _make_source( + {"id": pa.array([1, 2], type=pa.int64())}, + {"a": pa.array([10, 20], type=pa.int64())}, + ["id"], + ) + source_b = _make_source( + {"id": pa.array([1, 2], type=pa.int64())}, + {"b": pa.array([30, 40], type=pa.int64())}, + ["id"], + ) + + # Get original system tag column names + a_tags = _get_system_tag_columns(source_a.as_table(all_info=True)) + b_tags = _get_system_tag_columns(source_b.as_table(all_info=True)) + + join = Join() + result = join.process(source_a, source_b) + result_table = result.as_table(all_info=True) + result_tags = _get_system_tag_columns(result_table) + + # After join, system tag columns should be extended (longer names) + # Each input contributes system tag columns with extended names + assert len(result_tags) >= len(a_tags) + len(b_tags) + + def test_join_sorts_system_tag_values_for_commutativity(self): + """Per design: commutative ops sort paired tag values per row.""" + source_a = _make_source( + {"id": pa.array([1, 2], type=pa.int64())}, + {"a": pa.array([10, 20], type=pa.int64())}, + ["id"], + ) + source_b = _make_source( + {"id": pa.array([1, 2], type=pa.int64())}, + {"b": pa.array([30, 40], type=pa.int64())}, + ["id"], + ) + + join = Join() + result_ab = join.process(source_a, source_b) + result_ba = join.process(source_b, source_a) + + table_ab = result_ab.as_table(all_info=True) + table_ba = result_ba.as_table(all_info=True) + + # System tag column names should be identical for commutative join + tags_ab = sorted(_get_system_tag_columns(table_ab)) + tags_ba = sorted(_get_system_tag_columns(table_ba)) + assert tags_ab == tags_ba + + +# =================================================================== +# Type-evolving (aggregation ops) +# =================================================================== + + +class TestTypeEvolving: + """Per design: batch operation changes system tag type from str to list[str].""" + + def test_batch_evolves_system_tag_type(self): + source = _make_source( + {"group": pa.array(["a", "a", "b"], type=pa.large_string())}, + {"value": pa.array([1, 2, 3], type=pa.int64())}, + ["group"], + ) + source_table = source.as_table(all_info=True) + source_tag_cols = _get_system_tag_columns(source_table) + + batch = Batch() + result = batch.process(source) + result_table = result.as_table(all_info=True) + result_tag_cols = _get_system_tag_columns(result_table) + + # System tag columns should exist in output + assert len(result_tag_cols) == len(source_tag_cols) + + # The type should have evolved to list + for col_name in result_tag_cols: + col_type = result_table.schema.field(col_name).type + assert pa.types.is_list(col_type) or pa.types.is_large_list( + col_type + ), f"Expected list type for {col_name} after batch, got {col_type}" + + +# =================================================================== +# Full pipeline provenance chain +# =================================================================== + + +class TestFullProvenanceChain: + """End-to-end: source → join → filter → batch with all rules applied.""" + + def test_full_chain(self): + source_a = _make_source( + {"group": pa.array(["x", "x", "y"], type=pa.large_string())}, + {"a": pa.array([1, 2, 3], type=pa.int64())}, + ["group"], + ) + source_b = _make_source( + {"group": pa.array(["x", "y", "y"], type=pa.large_string())}, + {"b": pa.array([10, 20, 30], type=pa.int64())}, + ["group"], + ) + + # Step 1: Join (name-extending) + join = Join() + joined = join.process(source_a, source_b) + + # Step 2: Filter (name-preserving) + filt = PolarsFilter(constraints={"group": "x"}) + filtered = filt.process(joined) + + # Step 3: Batch (type-evolving) + batch = Batch() + batched = batch.process(filtered) + + table = batched.as_table(all_info=True) + tag_cols = _get_system_tag_columns(table) + + # After all three stages, system tags should exist + assert len(tag_cols) > 0 + + # After batch, types should be lists + for col_name in tag_cols: + col_type = table.schema.field(col_name).type + assert pa.types.is_list(col_type) or pa.types.is_large_list(col_type) diff --git a/test-objective/property/__init__.py b/test-objective/property/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test-objective/property/test_hash_properties.py b/test-objective/property/test_hash_properties.py new file mode 100644 index 00000000..0031b1c5 --- /dev/null +++ b/test-objective/property/test_hash_properties.py @@ -0,0 +1,93 @@ +"""Property-based tests for hashing determinism and ContentHash roundtrips. + +Tests that hashing invariants hold for any valid input. +""" + +from __future__ import annotations + +import pytest +from hypothesis import given, settings +from hypothesis import strategies as st + +from orcapod.contexts import get_default_context +from orcapod.types import ContentHash + + +# --------------------------------------------------------------------------- +# Strategies +# --------------------------------------------------------------------------- + +# Primitives that the hasher should handle +hashable_primitives = st.one_of( + st.integers(min_value=-10000, max_value=10000), + st.floats(allow_nan=False, allow_infinity=False), + st.text(min_size=0, max_size=50), + st.booleans(), + st.none(), +) + +# ContentHash strategy +content_hashes = st.builds( + ContentHash, + method=st.text(min_size=1, max_size=20).filter(lambda s: ":" not in s), + digest=st.binary(min_size=4, max_size=32), +) + + +# =================================================================== +# Hash determinism +# =================================================================== + + +class TestHashDeterminism: + """Per design: hash(X) == hash(X) for any X.""" + + @given(hashable_primitives) + @settings(max_examples=50) + def test_same_input_same_hash(self, value): + ctx = get_default_context() + hasher = ctx.semantic_hasher + h1 = hasher.hash_object(value) + h2 = hasher.hash_object(value) + assert h1 == h2 + + +# =================================================================== +# ContentHash string roundtrip +# =================================================================== + + +class TestContentHashStringRoundtrip: + """Per design: from_string(to_string(h)) == h.""" + + @given(content_hashes) + @settings(max_examples=50) + def test_roundtrip(self, h): + s = h.to_string() + recovered = ContentHash.from_string(s) + assert recovered.method == h.method + assert recovered.digest == h.digest + + +class TestContentHashHexConsistency: + """to_hex() truncation should be consistent.""" + + @given(content_hashes, st.integers(min_value=1, max_value=64)) + @settings(max_examples=50) + def test_truncation_is_prefix(self, h, length): + full_hex = h.to_hex() + truncated = h.to_hex(length) + assert full_hex.startswith(truncated) + + +class TestContentHashEquality: + """Equal ContentHash objects have equal conversions.""" + + @given(content_hashes) + @settings(max_examples=50) + def test_equal_hashes_equal_conversions(self, h): + h2 = ContentHash(h.method, h.digest) + assert h.to_hex() == h2.to_hex() + assert h.to_int() == h2.to_int() + assert h.to_uuid() == h2.to_uuid() + assert h.to_base64() == h2.to_base64() diff --git a/test-objective/property/test_operator_algebra.py b/test-objective/property/test_operator_algebra.py new file mode 100644 index 00000000..203cda11 --- /dev/null +++ b/test-objective/property/test_operator_algebra.py @@ -0,0 +1,208 @@ +"""Property-based tests for operator algebraic properties. + +Tests mathematical properties that operators must satisfy: +- Join commutativity +- Join associativity +- Filter idempotency +- Select composition +- Drop composition +""" + +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.core.operators import ( + DropPacketColumns, + Join, + PolarsFilter, + SelectPacketColumns, +) +from orcapod.core.streams import ArrowTableStream + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _sorted_rows(table: pa.Table, sort_col: str = "id") -> list[dict]: + """Extract rows as sorted list of dicts for comparison.""" + df = table.to_pydict() + rows = [] + n = table.num_rows + for i in range(n): + row = {k: df[k][i] for k in df if not k.startswith("_")} + rows.append(row) + return sorted(rows, key=lambda r: r.get(sort_col, 0)) + + +def _make_stream(tag_data: dict, packet_data: dict, tag_cols: list[str]) -> ArrowTableStream: + all_data = {**tag_data, **packet_data} + return ArrowTableStream(pa.table(all_data), tag_columns=tag_cols) + + +# =================================================================== +# Join commutativity +# =================================================================== + + +class TestJoinCommutativity: + """Per design: Join is commutative — join(A, B) produces same data as join(B, A).""" + + def test_two_way_commutativity(self): + sa = _make_stream( + {"id": pa.array([1, 2, 3], type=pa.int64())}, + {"a": pa.array([10, 20, 30], type=pa.int64())}, + ["id"], + ) + sb = _make_stream( + {"id": pa.array([2, 3, 4], type=pa.int64())}, + {"b": pa.array([200, 300, 400], type=pa.int64())}, + ["id"], + ) + join = Join() + result_ab = join.process(sa, sb) + result_ba = join.process(sb, sa) + + rows_ab = _sorted_rows(result_ab.as_table()) + rows_ba = _sorted_rows(result_ba.as_table()) + assert rows_ab == rows_ba + + +# =================================================================== +# Join associativity +# =================================================================== + + +class TestJoinAssociativity: + """Per design: join(join(A,B),C) should produce same data as join(A,join(B,C)) + when all have non-overlapping packet columns.""" + + def test_three_way_associativity(self): + sa = _make_stream( + {"id": pa.array([1, 2], type=pa.int64())}, + {"a": pa.array([10, 20], type=pa.int64())}, + ["id"], + ) + sb = _make_stream( + {"id": pa.array([1, 2], type=pa.int64())}, + {"b": pa.array([100, 200], type=pa.int64())}, + ["id"], + ) + sc = _make_stream( + {"id": pa.array([1, 2], type=pa.int64())}, + {"c": pa.array([1000, 2000], type=pa.int64())}, + ["id"], + ) + + join = Join() + + # (A join B) join C + ab = join.process(sa, sb) + abc_left = join.process(ab, sc) + + # A join (B join C) + bc = join.process(sb, sc) + abc_right = join.process(sa, bc) + + rows_left = _sorted_rows(abc_left.as_table()) + rows_right = _sorted_rows(abc_right.as_table()) + assert rows_left == rows_right + + +# =================================================================== +# Filter idempotency +# =================================================================== + + +class TestFilterIdempotency: + """filter(filter(S, P), P) == filter(S, P) — filtering twice with + the same predicate is the same as filtering once.""" + + def test_filter_idempotent(self): + stream = _make_stream( + {"id": pa.array([1, 2, 3, 4, 5], type=pa.int64())}, + {"value": pa.array([10, 20, 30, 40, 50], type=pa.int64())}, + ["id"], + ) + + filt = PolarsFilter(constraints={"id": 3}) + once = filt.process(stream) + twice = filt.process(once) + + table_once = once.as_table() + table_twice = twice.as_table() + assert table_once.num_rows == table_twice.num_rows + assert _sorted_rows(table_once) == _sorted_rows(table_twice) + + +# =================================================================== +# Select composition +# =================================================================== + + +class TestSelectComposition: + """select(select(S, X), Y) == select(S, X∩Y).""" + + def test_select_then_select_is_intersection(self): + stream = _make_stream( + {"id": pa.array([1, 2], type=pa.int64())}, + { + "a": pa.array([10, 20], type=pa.int64()), + "b": pa.array([30, 40], type=pa.int64()), + "c": pa.array([50, 60], type=pa.int64()), + }, + ["id"], + ) + + # select(S, {a, b}) then select(result, {b, c}) → should keep only {b} + sel1 = SelectPacketColumns(columns=["a", "b"]) + step1 = sel1.process(stream) + + sel2 = SelectPacketColumns(columns=["b"]) + step2 = sel2.process(step1) + + # Direct intersection: select {a,b} ∩ {b,c} = {b} + sel_direct = SelectPacketColumns(columns=["b"]) + direct = sel_direct.process(stream) + + _, step2_keys = step2.keys() + _, direct_keys = direct.keys() + assert set(step2_keys) == set(direct_keys) + + +# =================================================================== +# Drop composition +# =================================================================== + + +class TestDropComposition: + """drop(drop(S, X), Y) == drop(S, X∪Y).""" + + def test_drop_then_drop_is_union(self): + stream = _make_stream( + {"id": pa.array([1, 2], type=pa.int64())}, + { + "a": pa.array([10, 20], type=pa.int64()), + "b": pa.array([30, 40], type=pa.int64()), + "c": pa.array([50, 60], type=pa.int64()), + }, + ["id"], + ) + + # drop(S, {a}) then drop(result, {b}) → should drop {a, b} + drop1 = DropPacketColumns(columns=["a"]) + step1 = drop1.process(stream) + + drop2 = DropPacketColumns(columns=["b"]) + step2 = drop2.process(step1) + + # Direct: drop {a} ∪ {b} = drop {a, b} + drop_direct = DropPacketColumns(columns=["a", "b"]) + direct = drop_direct.process(stream) + + _, step2_keys = step2.keys() + _, direct_keys = direct.keys() + assert set(step2_keys) == set(direct_keys) diff --git a/test-objective/property/test_schema_properties.py b/test-objective/property/test_schema_properties.py new file mode 100644 index 00000000..f9af47e0 --- /dev/null +++ b/test-objective/property/test_schema_properties.py @@ -0,0 +1,124 @@ +"""Property-based tests for Schema algebra using Hypothesis. + +Tests algebraic properties that must hold for any valid input, +not just hand-picked examples. +""" + +from __future__ import annotations + +import pytest +from hypothesis import given, settings +from hypothesis import strategies as st + +from orcapod.types import Schema + +# --------------------------------------------------------------------------- +# Strategies +# --------------------------------------------------------------------------- + +# Simple Python types that Schema supports +simple_types = st.sampled_from([int, float, str, bool, bytes]) + +# Field name strategy +field_names = st.text( + alphabet=st.characters(whitelist_categories=("L", "N"), whitelist_characters="_"), + min_size=1, + max_size=10, +).filter(lambda s: s[0].isalpha()) + +# Schema strategy: dict of 1-5 fields +schema_dicts = st.dictionaries(field_names, simple_types, min_size=1, max_size=5) + + +def make_schema(d: dict) -> Schema: + return Schema(d) + + +# =================================================================== +# Schema merge commutativity +# =================================================================== + + +class TestSchemaMergeCommutativity: + """merge(A, B) == merge(B, A) when schemas are compatible.""" + + @given(schema_dicts, schema_dicts) + @settings(max_examples=50) + def test_merge_commutative_when_compatible(self, d1, d2): + s1 = make_schema(d1) + s2 = make_schema(d2) + + # Check if they're compatible (no type conflicts) + conflicts = {k for k in d2 if k in d1 and d1[k] != d2[k]} + if conflicts: + return # Skip incompatible schemas + + merged_ab = s1.merge(s2) + merged_ba = s2.merge(s1) + assert dict(merged_ab) == dict(merged_ba) + + +# =================================================================== +# Schema is_compatible_with is reflexive +# =================================================================== + + +class TestSchemaCompatibilityReflexive: + """A.is_compatible_with(A) should always be True.""" + + @given(schema_dicts) + @settings(max_examples=50) + def test_reflexive(self, d): + s = make_schema(d) + assert s.is_compatible_with(s) + + +# =================================================================== +# Schema select/drop complementarity +# =================================================================== + + +class TestSchemaSelectDropComplementary: + """select(X) ∪ drop(X) should recover the original schema's fields.""" + + @given(schema_dicts) + @settings(max_examples=50) + def test_select_drop_complementary(self, d): + s = make_schema(d) + if len(s) < 2: + return # Need at least 2 fields + + fields = list(s.keys()) + mid = len(fields) // 2 + selected_fields = fields[:mid] + dropped_fields = fields[:mid] + + selected = s.select(*selected_fields) + dropped = s.drop(*dropped_fields) + + # Union of selected and dropped should cover all fields + all_keys = set(selected.keys()) | set(dropped.keys()) + assert all_keys == set(s.keys()) + + +# =================================================================== +# Schema optional_fields is subset of all fields +# =================================================================== + + +class TestSchemaOptionalFieldsSubset: + """optional_fields should always be a subset of all field names.""" + + @given(schema_dicts, st.lists(field_names, max_size=3)) + @settings(max_examples=50) + def test_optional_subset(self, d, optional_candidates): + # Only use candidates that are actual fields + valid_optional = [f for f in optional_candidates if f in d] + s = Schema(d, optional_fields=valid_optional) + assert s.optional_fields.issubset(set(s.keys())) + + @given(schema_dicts) + @settings(max_examples=50) + def test_required_plus_optional_equals_all(self, d): + s = make_schema(d) + assert s.required_fields | s.optional_fields == set(s.keys()) diff --git a/test-objective/unit/__init__.py b/test-objective/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test-objective/unit/test_arrow_data_utils.py b/test-objective/unit/test_arrow_data_utils.py new file mode 100644 index 00000000..d9877b31 --- /dev/null +++ b/test-objective/unit/test_arrow_data_utils.py @@ -0,0 +1,196 @@ +"""Specification-derived tests for arrow_data_utils. + +Tests system tag manipulation, source info, and column helper functions +based on documented behavior in the design specification. +""" + +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.system_constants import constants +from orcapod.utils.arrow_data_utils import ( + add_source_info, + add_system_tag_columns, + append_to_system_tags, + drop_columns_with_prefix, + drop_system_columns, + sort_system_tag_values, +) + + +# --------------------------------------------------------------------------- +# add_system_tag_columns +# --------------------------------------------------------------------------- + + +class TestAddSystemTagColumns: + """Per the design spec, system tag columns are prefixed with _tag_ and + track per-row provenance (source_id, record_id pairs).""" + + def test_adds_system_tag_columns(self): + table = pa.table({"id": [1, 2], "value": [10, 20]}) + result = add_system_tag_columns( + table, + schema_hash="abc123", + source_ids="src1", + record_ids=["rec1", "rec2"], + ) + # Should have original columns plus new system tag columns + assert result.num_rows == 2 + tag_cols = [ + c for c in result.column_names if c.startswith(constants.SYSTEM_TAG_PREFIX) + ] + assert len(tag_cols) > 0 + + def test_empty_table_returns_empty(self): + table = pa.table({"id": pa.array([], type=pa.int64())}) + result = add_system_tag_columns( + table, + schema_hash="abc", + source_ids="src1", + record_ids=[], + ) + assert result.num_rows == 0 + + def test_length_mismatch_raises(self): + table = pa.table({"id": [1, 2, 3]}) + with pytest.raises(ValueError): + add_system_tag_columns( + table, + schema_hash="abc", + source_ids=["s1", "s2"], # 2 source_ids for 3 rows + record_ids=["r1", "r2", "r3"], + ) + + +# --------------------------------------------------------------------------- +# append_to_system_tags +# --------------------------------------------------------------------------- + + +class TestAppendToSystemTags: + """Per design, appends a value to existing system tag columns.""" + + def test_appends_value_to_system_tags(self): + # Create a table that already has system tag columns + table = pa.table({"id": [1, 2], "value": [10, 20]}) + table_with_tags = add_system_tag_columns( + table, + schema_hash="abc", + source_ids="src1", + record_ids=["r1", "r2"], + ) + result = append_to_system_tags(table_with_tags, value="::extra:0") + # System tag column names should have changed (appended) + tag_cols_before = [ + c + for c in table_with_tags.column_names + if c.startswith(constants.SYSTEM_TAG_PREFIX) + ] + tag_cols_after = [ + c for c in result.column_names if c.startswith(constants.SYSTEM_TAG_PREFIX) + ] + # The column names should be extended + assert len(tag_cols_after) == len(tag_cols_before) + + def test_empty_table_returns_empty(self): + table = pa.table( + {"id": pa.array([], type=pa.int64()), "value": pa.array([], type=pa.int64())} + ) + result = append_to_system_tags(table, value="::extra:0") + assert result.num_rows == 0 + + +# --------------------------------------------------------------------------- +# sort_system_tag_values +# --------------------------------------------------------------------------- + + +class TestSortSystemTagValues: + """Per design, system tag values must be sorted for commutativity in + multi-input operators. Paired (source_id, record_id) tuples are sorted + together per row.""" + + def test_sorts_system_tag_values(self): + # This is a structural test — ensure the function runs without error + # and produces a table with the same shape + table = pa.table({"id": [1, 2], "value": [10, 20]}) + table_with_tags = add_system_tag_columns( + table, + schema_hash="abc", + source_ids="src1", + record_ids=["r1", "r2"], + ) + result = sort_system_tag_values(table_with_tags) + assert result.num_rows == table_with_tags.num_rows + + +# --------------------------------------------------------------------------- +# add_source_info +# --------------------------------------------------------------------------- + + +class TestAddSourceInfo: + """Per design, source info columns are prefixed with _source_ and track + provenance tokens per packet column.""" + + def test_adds_source_info_columns(self): + table = pa.table({"id": [1, 2], "value": [10, 20]}) + result = add_source_info(table, source_info="src_token") + source_cols = [ + c for c in result.column_names if c.startswith(constants.SOURCE_PREFIX) + ] + assert len(source_cols) > 0 + + def test_source_info_length_mismatch_raises(self): + table = pa.table({"id": [1, 2], "value": [10, 20]}) + with pytest.raises(ValueError): + add_source_info(table, source_info=["a", "b", "c"]) # Wrong count + + +# --------------------------------------------------------------------------- +# drop_columns_with_prefix +# --------------------------------------------------------------------------- + + +class TestDropColumnsWithPrefix: + """Removes all columns matching a given prefix.""" + + def test_drops_columns_with_matching_prefix(self): + table = pa.table({"__meta_a": [1], "__meta_b": [2], "data": [3]}) + result = drop_columns_with_prefix(table, "__meta") + assert "data" in result.column_names + assert "__meta_a" not in result.column_names + assert "__meta_b" not in result.column_names + + def test_no_match_returns_unchanged(self): + table = pa.table({"a": [1], "b": [2]}) + result = drop_columns_with_prefix(table, "__nonexistent") + assert result.column_names == table.column_names + + def test_tuple_of_prefixes(self): + table = pa.table({"__a": [1], "_src_b": [2], "data": [3]}) + result = drop_columns_with_prefix(table, ("__", "_src_")) + assert result.column_names == ["data"] + + +# --------------------------------------------------------------------------- +# drop_system_columns +# --------------------------------------------------------------------------- + + +class TestDropSystemColumns: + """Removes columns with system prefixes (__ and datagram prefix).""" + + def test_drops_system_columns(self): + table = pa.table({"__meta": [1], "data": [2]}) + result = drop_system_columns(table) + assert "data" in result.column_names + assert "__meta" not in result.column_names + + def test_preserves_non_system_columns(self): + table = pa.table({"name": ["alice"], "age": [30]}) + result = drop_system_columns(table) + assert result.column_names == ["name", "age"] diff --git a/test-objective/unit/test_arrow_utils.py b/test-objective/unit/test_arrow_utils.py new file mode 100644 index 00000000..1f929d98 --- /dev/null +++ b/test-objective/unit/test_arrow_utils.py @@ -0,0 +1,322 @@ +"""Tests for Arrow utility functions. + +Specification-derived tests covering schema selection/dropping, type +normalization, row/column conversion, table stacking, schema compatibility +checking, and column group splitting. +""" + +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.utils.arrow_utils import ( + check_arrow_schema_compatibility, + hstack_tables, + normalize_to_large_types, + pydict_to_pylist, + pylist_to_pydict, + schema_drop, + schema_select, + split_by_column_groups, +) + + +# =========================================================================== +# schema_select +# =========================================================================== + + +class TestSchemaSelect: + """Selects subset; KeyError for missing columns.""" + + def test_select_subset(self) -> None: + schema = pa.schema( + [ + pa.field("a", pa.int64()), + pa.field("b", pa.string()), + pa.field("c", pa.float64()), + ] + ) + result = schema_select(schema, ["a", "c"]) + assert result.names == ["a", "c"] + assert result.field("a").type == pa.int64() + assert result.field("c").type == pa.float64() + + def test_select_all(self) -> None: + schema = pa.schema([pa.field("x", pa.int64()), pa.field("y", pa.string())]) + result = schema_select(schema, ["x", "y"]) + assert result.names == ["x", "y"] + + def test_select_missing_column_raises(self) -> None: + schema = pa.schema([pa.field("a", pa.int64())]) + with pytest.raises(KeyError, match="Missing columns"): + schema_select(schema, ["a", "nonexistent"]) + + def test_select_missing_with_ignore(self) -> None: + schema = pa.schema([pa.field("a", pa.int64()), pa.field("b", pa.string())]) + result = schema_select(schema, ["a", "nonexistent"], ignore_missing=True) + assert result.names == ["a"] + + +# =========================================================================== +# schema_drop +# =========================================================================== + + +class TestSchemaDrop: + """Drops specified columns; KeyError if missing and not ignore_missing.""" + + def test_drop_columns(self) -> None: + schema = pa.schema( + [ + pa.field("a", pa.int64()), + pa.field("b", pa.string()), + pa.field("c", pa.float64()), + ] + ) + result = schema_drop(schema, ["b"]) + assert result.names == ["a", "c"] + + def test_drop_missing_raises(self) -> None: + schema = pa.schema([pa.field("a", pa.int64())]) + with pytest.raises(KeyError, match="Missing columns"): + schema_drop(schema, ["nonexistent"]) + + def test_drop_missing_with_ignore(self) -> None: + schema = pa.schema([pa.field("a", pa.int64()), pa.field("b", pa.string())]) + result = schema_drop(schema, ["nonexistent"], ignore_missing=True) + assert result.names == ["a", "b"] + + +# =========================================================================== +# normalize_to_large_types +# =========================================================================== + + +class TestNormalizeToLargeTypes: + """string -> large_string, binary -> large_binary, list -> large_list.""" + + def test_string_to_large_string(self) -> None: + assert normalize_to_large_types(pa.string()) == pa.large_string() + + def test_binary_to_large_binary(self) -> None: + assert normalize_to_large_types(pa.binary()) == pa.large_binary() + + def test_list_to_large_list(self) -> None: + result = normalize_to_large_types(pa.list_(pa.string())) + assert pa.types.is_large_list(result) + # Inner type should also be normalized. + assert result.value_type == pa.large_string() + + def test_large_string_unchanged(self) -> None: + assert normalize_to_large_types(pa.large_string()) == pa.large_string() + + def test_int64_unchanged(self) -> None: + assert normalize_to_large_types(pa.int64()) == pa.int64() + + def test_float64_unchanged(self) -> None: + assert normalize_to_large_types(pa.float64()) == pa.float64() + + def test_nested_struct_normalized(self) -> None: + struct_type = pa.struct([pa.field("name", pa.string())]) + result = normalize_to_large_types(struct_type) + assert pa.types.is_struct(result) + assert result[0].type == pa.large_string() + + def test_null_to_large_string(self) -> None: + assert normalize_to_large_types(pa.null()) == pa.large_string() + + +# =========================================================================== +# pylist_to_pydict +# =========================================================================== + + +class TestPylistToPydict: + """Row-oriented -> column-oriented conversion.""" + + def test_basic_conversion(self) -> None: + rows = [{"a": 1, "b": 2}, {"a": 3, "b": 4}] + result = pylist_to_pydict(rows) + assert result == {"a": [1, 3], "b": [2, 4]} + + def test_missing_keys_filled_with_none(self) -> None: + rows = [{"a": 1, "b": 2}, {"a": 3, "c": 4}] + result = pylist_to_pydict(rows) + assert result["a"] == [1, 3] + assert result["b"] == [2, None] + assert result["c"] == [None, 4] + + def test_empty_list(self) -> None: + result = pylist_to_pydict([]) + assert result == {} + + def test_single_row(self) -> None: + result = pylist_to_pydict([{"x": 10}]) + assert result == {"x": [10]} + + +# =========================================================================== +# pydict_to_pylist +# =========================================================================== + + +class TestPydictToPylist: + """Column-oriented -> row-oriented; ValueError on inconsistent lengths.""" + + def test_basic_conversion(self) -> None: + data = {"a": [1, 3], "b": [2, 4]} + result = pydict_to_pylist(data) + assert result == [{"a": 1, "b": 2}, {"a": 3, "b": 4}] + + def test_empty_dict(self) -> None: + result = pydict_to_pylist({}) + assert result == [] + + def test_inconsistent_lengths_raises(self) -> None: + data = {"a": [1, 2], "b": [3]} + with pytest.raises(ValueError, match="Inconsistent"): + pydict_to_pylist(data) + + def test_single_column(self) -> None: + result = pydict_to_pylist({"x": [10, 20]}) + assert result == [{"x": 10}, {"x": 20}] + + +# =========================================================================== +# hstack_tables +# =========================================================================== + + +class TestHstackTables: + """Horizontal concat; ValueError for different row counts or duplicate columns.""" + + def test_basic_hstack(self) -> None: + t1 = pa.table({"a": [1, 2]}) + t2 = pa.table({"b": ["x", "y"]}) + result = hstack_tables(t1, t2) + assert result.column_names == ["a", "b"] + assert result.num_rows == 2 + + def test_single_table(self) -> None: + t1 = pa.table({"a": [1]}) + result = hstack_tables(t1) + assert result.column_names == ["a"] + + def test_different_row_counts_raises(self) -> None: + t1 = pa.table({"a": [1, 2]}) + t2 = pa.table({"b": [3]}) + with pytest.raises(ValueError, match="same number of rows"): + hstack_tables(t1, t2) + + def test_duplicate_columns_raises(self) -> None: + t1 = pa.table({"a": [1]}) + t2 = pa.table({"a": [2]}) + with pytest.raises(ValueError, match="Duplicate column name"): + hstack_tables(t1, t2) + + def test_no_tables_raises(self) -> None: + with pytest.raises(ValueError, match="At least one table"): + hstack_tables() + + def test_three_tables(self) -> None: + t1 = pa.table({"a": [1]}) + t2 = pa.table({"b": [2]}) + t3 = pa.table({"c": [3]}) + result = hstack_tables(t1, t2, t3) + assert result.column_names == ["a", "b", "c"] + assert result.num_rows == 1 + + +# =========================================================================== +# check_arrow_schema_compatibility +# =========================================================================== + + +class TestCheckArrowSchemaCompatibility: + """Returns (is_compatible, errors).""" + + def test_compatible_schemas(self) -> None: + s1 = pa.schema([pa.field("a", pa.int64()), pa.field("b", pa.string())]) + s2 = pa.schema([pa.field("a", pa.int64()), pa.field("b", pa.string())]) + is_compat, errors = check_arrow_schema_compatibility(s1, s2) + assert is_compat is True + assert errors == [] + + def test_missing_field_incompatible(self) -> None: + incoming = pa.schema([pa.field("a", pa.int64())]) + target = pa.schema([pa.field("a", pa.int64()), pa.field("b", pa.string())]) + is_compat, errors = check_arrow_schema_compatibility(incoming, target) + assert is_compat is False + assert any("Missing field" in e for e in errors) + + def test_type_mismatch(self) -> None: + incoming = pa.schema([pa.field("a", pa.string())]) + target = pa.schema([pa.field("a", pa.int64())]) + is_compat, errors = check_arrow_schema_compatibility(incoming, target) + assert is_compat is False + assert any("Type mismatch" in e for e in errors) + + def test_extra_fields_allowed_non_strict(self) -> None: + incoming = pa.schema( + [pa.field("a", pa.int64()), pa.field("extra", pa.string())] + ) + target = pa.schema([pa.field("a", pa.int64())]) + is_compat, errors = check_arrow_schema_compatibility(incoming, target) + assert is_compat is True + assert errors == [] + + def test_extra_fields_rejected_strict(self) -> None: + incoming = pa.schema( + [pa.field("a", pa.int64()), pa.field("extra", pa.string())] + ) + target = pa.schema([pa.field("a", pa.int64())]) + is_compat, errors = check_arrow_schema_compatibility( + incoming, target, strict=True + ) + assert is_compat is False + assert any("Unexpected field" in e for e in errors) + + +# =========================================================================== +# split_by_column_groups +# =========================================================================== + + +class TestSplitByColumnGroups: + """Splits table by column groups.""" + + def test_basic_split(self) -> None: + table = pa.table({"a": [1], "b": [2], "c": [3], "d": [4]}) + result = split_by_column_groups(table, ["a", "b"], ["c"]) + # result[0] = remaining (d), result[1] = group1 (a,b), result[2] = group2 (c) + assert result[0] is not None + assert result[0].column_names == ["d"] + assert result[1] is not None + assert set(result[1].column_names) == {"a", "b"} + assert result[2] is not None + assert result[2].column_names == ["c"] + + def test_no_groups_returns_full_table(self) -> None: + table = pa.table({"a": [1], "b": [2]}) + result = split_by_column_groups(table) + assert len(result) == 1 + assert result[0].column_names == ["a", "b"] + + def test_all_columns_in_groups(self) -> None: + table = pa.table({"a": [1], "b": [2]}) + result = split_by_column_groups(table, ["a"], ["b"]) + # remaining should be None + assert result[0] is None + assert result[1] is not None + assert result[2] is not None + + def test_empty_group_returns_none(self) -> None: + table = pa.table({"a": [1], "b": [2]}) + result = split_by_column_groups(table, ["nonexistent"]) + # Group with nonexistent columns returns None + assert result[1] is None + # Remaining should have both columns + assert result[0] is not None + assert set(result[0].column_names) == {"a", "b"} diff --git a/test-objective/unit/test_contexts.py b/test-objective/unit/test_contexts.py new file mode 100644 index 00000000..f696146c --- /dev/null +++ b/test-objective/unit/test_contexts.py @@ -0,0 +1,73 @@ +"""Specification-derived tests for DataContext resolution and validation. + +Tests based on the documented context management API. +""" + +from __future__ import annotations + +import pytest + +from orcapod.contexts import ( + get_available_contexts, + get_default_context, + resolve_context, +) +from orcapod.contexts.core import ContextResolutionError, DataContext + + +class TestResolveContext: + """Per the documented API, resolve_context handles None, str, and + DataContext inputs.""" + + def test_none_returns_default(self): + ctx = resolve_context(None) + assert isinstance(ctx, DataContext) + + def test_string_version_resolves(self): + ctx = resolve_context("v0.1") + assert isinstance(ctx, DataContext) + assert "v0.1" in ctx.context_key + + def test_datacontext_passthrough(self): + original = get_default_context() + result = resolve_context(original) + assert result is original + + def test_invalid_version_raises(self): + with pytest.raises((ContextResolutionError, KeyError, ValueError)): + resolve_context("v999.999") + + +class TestGetAvailableContexts: + """Per the documented API, returns sorted list of version strings.""" + + def test_returns_list(self): + contexts = get_available_contexts() + assert isinstance(contexts, list) + assert len(contexts) > 0 + + def test_includes_v01(self): + contexts = get_available_contexts() + assert "v0.1" in contexts + + +class TestDefaultContextComponents: + """Per the design, the default context has type_converter, arrow_hasher, + and semantic_hasher.""" + + def test_has_type_converter(self): + ctx = get_default_context() + assert ctx.type_converter is not None + + def test_has_arrow_hasher(self): + ctx = get_default_context() + assert ctx.arrow_hasher is not None + + def test_has_semantic_hasher(self): + ctx = get_default_context() + assert ctx.semantic_hasher is not None + + def test_has_context_key(self): + ctx = get_default_context() + assert isinstance(ctx.context_key, str) + assert len(ctx.context_key) > 0 diff --git a/test-objective/unit/test_databases.py b/test-objective/unit/test_databases.py new file mode 100644 index 00000000..04def11e --- /dev/null +++ b/test-objective/unit/test_databases.py @@ -0,0 +1,290 @@ +"""Tests for InMemoryArrowDatabase, NoOpArrowDatabase, and DeltaTableDatabase. + +Specification-derived tests covering record CRUD, pending-batch semantics, +duplicate handling, and database-specific behaviors. +""" + +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.databases import DeltaTableDatabase, InMemoryArrowDatabase, NoOpArrowDatabase + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_record(value: int = 1) -> pa.Table: + """Create a simple single-row Arrow table for testing.""" + return pa.table({"col_a": [value], "col_b": [f"val_{value}"]}) + + +def _make_records(n: int = 3) -> pa.Table: + """Create a multi-row Arrow table for testing.""" + return pa.table( + {"col_a": list(range(n)), "col_b": [f"val_{i}" for i in range(n)]} + ) + + +# =========================================================================== +# InMemoryArrowDatabase +# =========================================================================== + + +class TestInMemoryArrowDatabaseRoundtrip: + """add_record + get_record_by_id roundtrip.""" + + def test_add_and_get_single_record(self) -> None: + db = InMemoryArrowDatabase() + path = ("test", "table") + record = _make_record(42) + + db.add_record(path, "rec_1", record, flush=True) + result = db.get_record_by_id(path, "rec_1") + + assert result is not None + assert result.num_rows == 1 + assert result["col_a"].to_pylist() == [42] + + def test_add_and_get_preserves_data(self) -> None: + db = InMemoryArrowDatabase() + path = ("data",) + record = pa.table({"x": [10], "y": ["hello"]}) + + db.add_record(path, "id1", record, flush=True) + result = db.get_record_by_id(path, "id1") + + assert result is not None + assert result["x"].to_pylist() == [10] + assert result["y"].to_pylist() == ["hello"] + + +class TestInMemoryArrowDatabaseBatchAdd: + """add_records batch with multiple rows.""" + + def test_add_records_multiple_rows(self) -> None: + db = InMemoryArrowDatabase() + path = ("batch",) + records = pa.table( + { + "__record_id": ["a", "b", "c"], + "value": [1, 2, 3], + } + ) + + db.add_records(path, records, record_id_column="__record_id", flush=True) + all_records = db.get_all_records(path) + + assert all_records is not None + assert all_records.num_rows == 3 + + +class TestInMemoryArrowDatabaseGetAll: + """get_all_records returns all at path.""" + + def test_get_all_records(self) -> None: + db = InMemoryArrowDatabase() + path = ("multi",) + + db.add_record(path, "r1", _make_record(1), flush=True) + db.add_record(path, "r2", _make_record(2), flush=True) + + all_records = db.get_all_records(path) + assert all_records is not None + assert all_records.num_rows == 2 + + +class TestInMemoryArrowDatabaseGetByIds: + """get_records_by_ids returns subset.""" + + def test_get_records_by_ids_subset(self) -> None: + db = InMemoryArrowDatabase() + path = ("subset",) + + for i in range(5): + db.add_record(path, f"id_{i}", _make_record(i)) + db.flush() + + result = db.get_records_by_ids(path, ["id_1", "id_3"]) + assert result is not None + assert result.num_rows == 2 + + +class TestInMemoryArrowDatabaseSkipDuplicates: + """skip_duplicates=True doesn't raise on duplicate.""" + + def test_skip_duplicates_no_error(self) -> None: + db = InMemoryArrowDatabase() + path = ("dup",) + + db.add_record(path, "same_id", _make_record(1), flush=True) + # Adding same ID again with skip_duplicates=True should not raise. + db.add_record(path, "same_id", _make_record(2), skip_duplicates=True, flush=True) + + result = db.get_record_by_id(path, "same_id") + assert result is not None + # Original record should be preserved (duplicate was skipped). + assert result.num_rows == 1 + + +class TestInMemoryArrowDatabasePendingBatch: + """Pending batch semantics: records not visible until flush().""" + + def test_records_accessible_before_flush(self) -> None: + db = InMemoryArrowDatabase() + path = ("pending",) + + db.add_record(path, "p1", _make_record(1)) + # Records should be accessible via public API even before flush + result = db.get_record_by_id(path, "p1") + assert result is not None, "Record should be accessible via get_record_by_id before flush" + + def test_flush_makes_records_visible(self) -> None: + db = InMemoryArrowDatabase() + path = ("pending",) + + db.add_record(path, "p1", _make_record(1)) + db.flush() + + # After flush, records should still be accessible via public API + result = db.get_record_by_id(path, "p1") + assert result is not None, "Record should be accessible after flush" + + all_records = db.get_all_records(path) + assert all_records is not None + assert all_records.num_rows >= 1 + + +class TestInMemoryArrowDatabaseFlush: + """flush() makes records visible.""" + + def test_flush_commits_pending(self) -> None: + db = InMemoryArrowDatabase() + path = ("flush_test",) + + db.add_record(path, "f1", _make_record(10)) + db.add_record(path, "f2", _make_record(20)) + db.flush() + + all_records = db.get_all_records(path) + assert all_records is not None + assert all_records.num_rows == 2 + + +class TestInMemoryArrowDatabaseInvalidPath: + """Invalid path (empty tuple) raises ValueError.""" + + def test_empty_path_raises(self) -> None: + db = InMemoryArrowDatabase() + with pytest.raises(ValueError, match="cannot be empty"): + db.add_record((), "id", _make_record()) + + +class TestInMemoryArrowDatabaseNonexistentPath: + """get_* on nonexistent path returns None.""" + + def test_get_record_by_id_nonexistent(self) -> None: + db = InMemoryArrowDatabase() + result = db.get_record_by_id(("no", "such"), "missing_id") + assert result is None + + def test_get_all_records_nonexistent(self) -> None: + db = InMemoryArrowDatabase() + result = db.get_all_records(("no", "such")) + assert result is None + + def test_get_records_by_ids_nonexistent(self) -> None: + db = InMemoryArrowDatabase() + result = db.get_records_by_ids(("no", "such"), ["a", "b"]) + assert result is None + + +# =========================================================================== +# NoOpArrowDatabase +# =========================================================================== + + +class TestNoOpArrowDatabaseWrites: + """All writes silently discarded (no errors).""" + + def test_add_record_no_error(self) -> None: + db = NoOpArrowDatabase() + db.add_record(("path",), "id", _make_record()) + + def test_add_records_no_error(self) -> None: + db = NoOpArrowDatabase() + db.add_records(("path",), _make_records()) + + +class TestNoOpArrowDatabaseReads: + """All reads return None.""" + + def test_get_record_by_id_returns_none(self) -> None: + db = NoOpArrowDatabase() + db.add_record(("path",), "id", _make_record()) + assert db.get_record_by_id(("path",), "id") is None + + def test_get_all_records_returns_none(self) -> None: + db = NoOpArrowDatabase() + assert db.get_all_records(("path",)) is None + + def test_get_records_by_ids_returns_none(self) -> None: + db = NoOpArrowDatabase() + assert db.get_records_by_ids(("path",), ["a"]) is None + + def test_get_records_with_column_value_returns_none(self) -> None: + db = NoOpArrowDatabase() + assert db.get_records_with_column_value(("path",), {"col": "val"}) is None + + +class TestNoOpArrowDatabaseFlush: + """flush() is a no-op (no errors).""" + + def test_flush_no_error(self) -> None: + db = NoOpArrowDatabase() + db.flush() # Should not raise + + +# =========================================================================== +# DeltaTableDatabase (slow tests) +# =========================================================================== + + +@pytest.mark.slow +class TestDeltaTableDatabaseRoundtrip: + """add_record + get_record_by_id roundtrip (uses tmp_path fixture).""" + + def test_add_and_get_record(self, tmp_path: object) -> None: + db = DeltaTableDatabase(base_path=tmp_path) + path = ("delta", "test") + record = _make_record(99) + + db.add_record(path, "d1", record, flush=True) + result = db.get_record_by_id(path, "d1") + + assert result is not None + assert result.num_rows == 1 + assert result["col_a"].to_pylist() == [99] + + +@pytest.mark.slow +class TestDeltaTableDatabaseFlush: + """flush writes to disk.""" + + def test_flush_persists_to_disk(self, tmp_path: object) -> None: + db = DeltaTableDatabase(base_path=tmp_path) + path = ("persist",) + record = _make_record(7) + + db.add_record(path, "p1", record) + db.flush() + + # Create a new database instance pointing at the same path to verify + # data was persisted. + db2 = DeltaTableDatabase(base_path=tmp_path, create_base_path=False) + result = db2.get_record_by_id(path, "p1") + assert result is not None + assert result["col_a"].to_pylist() == [7] diff --git a/test-objective/unit/test_datagram.py b/test-objective/unit/test_datagram.py new file mode 100644 index 00000000..ba4b4525 --- /dev/null +++ b/test-objective/unit/test_datagram.py @@ -0,0 +1,281 @@ +"""Specification-derived tests for Datagram.""" + +import pyarrow as pa +import pytest + +from orcapod.core.datagrams.datagram import Datagram +from orcapod.types import ColumnConfig + + +# --------------------------------------------------------------------------- +# Helper to create a DataContext (needed for Arrow conversions) +# --------------------------------------------------------------------------- + +def _make_context(): + """Create a DataContext for tests that need Arrow conversion.""" + from orcapod.contexts import resolve_context + return resolve_context(None) + + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + +class TestDatagramConstruction: + """Datagram can be constructed from dict, pa.Table, pa.RecordBatch.""" + + def test_construct_from_dict(self): + dg = Datagram({"x": 1, "y": "hello"}, data_context=_make_context()) + assert "x" in dg + assert dg["x"] == 1 + + def test_construct_from_arrow_table(self): + table = pa.table({"x": [1], "y": ["hello"]}) + dg = Datagram(table, data_context=_make_context()) + assert "x" in dg + assert "y" in dg + + def test_construct_from_record_batch(self): + batch = pa.record_batch({"x": [1], "y": ["hello"]}) + dg = Datagram(batch, data_context=_make_context()) + assert "x" in dg + assert "y" in dg + + +# --------------------------------------------------------------------------- +# Dict-like access +# --------------------------------------------------------------------------- + +class TestDatagramDictAccess: + """Dict-like access: __getitem__, __contains__, __iter__, get().""" + + def _make_datagram(self): + return Datagram({"a": 10, "b": "text"}, data_context=_make_context()) + + def test_getitem(self): + dg = self._make_datagram() + assert dg["a"] == 10 + + def test_contains(self): + dg = self._make_datagram() + assert "a" in dg + assert "missing" not in dg + + def test_iter(self): + dg = self._make_datagram() + keys = list(dg) + assert set(keys) == {"a", "b"} + + def test_get_existing(self): + dg = self._make_datagram() + assert dg.get("a") == 10 + + def test_get_missing_returns_default(self): + dg = self._make_datagram() + assert dg.get("missing", 42) == 42 + + +# --------------------------------------------------------------------------- +# Lazy conversion +# --------------------------------------------------------------------------- + +class TestDatagramLazyConversion: + """Dict access uses dict backing; as_table() triggers Arrow conversion.""" + + def test_dict_constructed_datagram_dict_access_no_arrow(self): + """Accessing a dict-constructed datagram by key should work without Arrow.""" + dg = Datagram({"x": 1}, data_context=_make_context()) + assert dg["x"] == 1 + + def test_as_table_returns_arrow_table(self): + dg = Datagram({"x": 1, "y": "hello"}, data_context=_make_context()) + table = dg.as_table() + assert isinstance(table, pa.Table) + + def test_arrow_constructed_as_dict_returns_dict(self): + table = pa.table({"x": [1], "y": ["hello"]}) + dg = Datagram(table, data_context=_make_context()) + d = dg.as_dict() + assert isinstance(d, dict) + assert "x" in d + + +# --------------------------------------------------------------------------- +# Round-trip +# --------------------------------------------------------------------------- + +class TestDatagramRoundTrip: + """dict->Arrow->dict and Arrow->dict->Arrow preserve data.""" + + def test_dict_to_arrow_to_dict(self): + ctx = _make_context() + original = {"x": 1, "y": "hello"} + dg = Datagram(original, data_context=ctx) + # Force Arrow conversion + _ = dg.as_table() + # Convert back to dict + result = dg.as_dict() + assert result["x"] == original["x"] + assert result["y"] == original["y"] + + def test_arrow_to_dict_to_arrow(self): + ctx = _make_context() + table = pa.table({"x": [42], "y": ["world"]}) + dg = Datagram(table, data_context=ctx) + # Force dict conversion + _ = dg.as_dict() + # Convert back to Arrow + result = dg.as_table() + assert isinstance(result, pa.Table) + assert result.column("x").to_pylist() == [42] + + +# --------------------------------------------------------------------------- +# Schema methods +# --------------------------------------------------------------------------- + +class TestDatagramSchemaMethods: + """keys(), schema(), arrow_schema() with ColumnConfig.""" + + def test_keys_returns_field_names(self): + dg = Datagram({"x": 1, "y": "hello"}, data_context=_make_context()) + assert set(dg.keys()) == {"x", "y"} + + def test_schema_returns_schema_object(self): + from orcapod.types import Schema + dg = Datagram({"x": 1, "y": "hello"}, data_context=_make_context()) + s = dg.schema() + assert isinstance(s, Schema) + assert "x" in s + assert "y" in s + + +# --------------------------------------------------------------------------- +# Format conversions +# --------------------------------------------------------------------------- + +class TestDatagramFormatConversions: + """as_dict(), as_table(), as_arrow_compatible_dict().""" + + def test_as_dict_returns_dict(self): + dg = Datagram({"x": 1}, data_context=_make_context()) + assert isinstance(dg.as_dict(), dict) + + def test_as_table_returns_table(self): + dg = Datagram({"x": 1}, data_context=_make_context()) + assert isinstance(dg.as_table(), pa.Table) + + def test_as_arrow_compatible_dict(self): + dg = Datagram({"x": 1, "y": "hello"}, data_context=_make_context()) + result = dg.as_arrow_compatible_dict() + assert isinstance(result, dict) + assert "x" in result + + +# --------------------------------------------------------------------------- +# Immutable operations +# --------------------------------------------------------------------------- + +class TestDatagramImmutableOperations: + """select, drop, rename, update, with_columns return NEW instances.""" + + def _make_datagram(self): + return Datagram({"a": 1, "b": 2, "c": 3}, data_context=_make_context()) + + def test_select_returns_new_instance(self): + dg = self._make_datagram() + selected = dg.select("a", "b") + assert selected is not dg + assert "a" in selected + assert "c" not in selected + + def test_drop_returns_new_instance(self): + dg = self._make_datagram() + dropped = dg.drop("c") + assert dropped is not dg + assert "c" not in dropped + assert "a" in dropped + + def test_rename_returns_new_instance(self): + dg = self._make_datagram() + renamed = dg.rename({"a": "alpha"}) + assert renamed is not dg + assert "alpha" in renamed + assert "a" not in renamed + + def test_update_returns_new_instance(self): + dg = self._make_datagram() + updated = dg.update(a=99) + assert updated is not dg + assert updated["a"] == 99 + assert dg["a"] == 1 # original unchanged + + def test_with_columns_returns_new_instance(self): + dg = self._make_datagram() + extended = dg.with_columns(d=4) + assert extended is not dg + assert "d" in extended + assert "d" not in dg + + def test_original_unchanged_after_select(self): + dg = self._make_datagram() + dg.select("a") + assert "b" in dg + assert "c" in dg + + +# --------------------------------------------------------------------------- +# Meta operations +# --------------------------------------------------------------------------- + +class TestDatagramMetaOperations: + """get_meta_value (auto-prefixed), with_meta_columns, drop_meta_columns.""" + + def test_with_meta_columns_adds_prefixed_columns(self): + dg = Datagram({"x": 1}, data_context=_make_context()) + with_meta = dg.with_meta_columns(my_meta="value") + assert with_meta is not dg + + def test_get_meta_value_retrieves_by_unprefixed_name(self): + dg = Datagram({"x": 1}, data_context=_make_context()) + with_meta = dg.with_meta_columns(my_meta="value") + val = with_meta.get_meta_value("my_meta") + assert val == "value" + + def test_drop_meta_columns(self): + dg = Datagram({"x": 1}, data_context=_make_context()) + with_meta = dg.with_meta_columns(my_meta="value") + dropped = with_meta.drop_meta_columns("my_meta") + assert dropped is not with_meta + + +# --------------------------------------------------------------------------- +# Content hashing +# --------------------------------------------------------------------------- + +class TestDatagramContentHashing: + """Content hashing is deterministic, changes with data, equality by content.""" + + def test_hashing_is_deterministic(self): + ctx = _make_context() + dg1 = Datagram({"x": 1, "y": "a"}, data_context=ctx) + dg2 = Datagram({"x": 1, "y": "a"}, data_context=ctx) + assert dg1.content_hash() == dg2.content_hash() + + def test_hash_changes_with_data(self): + ctx = _make_context() + dg1 = Datagram({"x": 1}, data_context=ctx) + dg2 = Datagram({"x": 2}, data_context=ctx) + assert dg1.content_hash() != dg2.content_hash() + + def test_equality_by_content(self): + ctx = _make_context() + dg1 = Datagram({"x": 1, "y": "a"}, data_context=ctx) + dg2 = Datagram({"x": 1, "y": "a"}, data_context=ctx) + assert dg1 == dg2 + + def test_inequality_by_content(self): + ctx = _make_context() + dg1 = Datagram({"x": 1}, data_context=ctx) + dg2 = Datagram({"x": 2}, data_context=ctx) + assert dg1 != dg2 diff --git a/test-objective/unit/test_function_pod.py b/test-objective/unit/test_function_pod.py new file mode 100644 index 00000000..a7ff704b --- /dev/null +++ b/test-objective/unit/test_function_pod.py @@ -0,0 +1,228 @@ +"""Specification-derived tests for FunctionPod and FunctionPodStream. + +Tests based on FunctionPodProtocol and documented behaviors: +- FunctionPod wraps a PacketFunction for per-packet transformation +- Never inspects or modifies tags +- Exactly one input stream +- output_schema() prediction matches actual output +""" + +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.core.datagrams.tag_packet import Packet, Tag +from orcapod.core.function_pod import FunctionPod, FunctionPodStream, function_pod +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.core.streams import ArrowTableStream +from orcapod.databases import InMemoryArrowDatabase +from orcapod.types import Schema + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _double(x: int) -> int: + return x * 2 + + +def _add(x: int, y: int) -> int: + return x + y + + +def _make_stream(n: int = 3) -> ArrowTableStream: + table = pa.table( + { + "id": pa.array(list(range(n)), type=pa.int64()), + "x": pa.array(list(range(n)), type=pa.int64()), + } + ) + return ArrowTableStream(table, tag_columns=["id"]) + + +def _make_two_col_stream(n: int = 3) -> ArrowTableStream: + table = pa.table( + { + "id": pa.array(list(range(n)), type=pa.int64()), + "x": pa.array(list(range(n)), type=pa.int64()), + "y": pa.array([i * 10 for i in range(n)], type=pa.int64()), + } + ) + return ArrowTableStream(table, tag_columns=["id"]) + + +# --------------------------------------------------------------------------- +# FunctionPod construction and processing +# --------------------------------------------------------------------------- + + +class TestFunctionPodProcess: + """Per FunctionPodProtocol, process() accepts exactly one stream and + returns a FunctionPodStream.""" + + def test_process_returns_function_pod_stream(self): + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + stream = _make_stream() + result = pod.process(stream) + assert isinstance(result, FunctionPodStream) + + def test_callable_alias(self): + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + stream = _make_stream() + result = pod(stream) + assert isinstance(result, FunctionPodStream) + + def test_validate_inputs_rejects_multiple_streams(self): + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + s1 = _make_stream() + s2 = _make_stream() + with pytest.raises(Exception): + pod.validate_inputs(s1, s2) + + def test_validate_inputs_accepts_single_stream(self): + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + stream = _make_stream() + pod.validate_inputs(stream) # Should not raise + + +class TestFunctionPodTagInvariant: + """Per the strict boundary: function pods NEVER inspect or modify tags.""" + + def test_tags_pass_through_unchanged(self): + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + stream = _make_stream() + result = pod.process(stream) + + input_tags = [tag for tag, _ in stream.iter_packets()] + output_tags = [tag for tag, _ in result.iter_packets()] + + for in_tag, out_tag in zip(input_tags, output_tags): + # Tag data columns should be identical + assert in_tag.keys() == out_tag.keys() + for key in in_tag.keys(): + assert in_tag[key] == out_tag[key] + + def test_packets_are_transformed(self): + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + stream = _make_stream() + result = pod.process(stream) + + for tag, packet in result.iter_packets(): + assert "result" in packet.keys() + + +class TestFunctionPodOutputSchema: + """Per PodProtocol, output_schema() must match the actual output.""" + + def test_output_schema_matches_actual(self): + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + stream = _make_stream() + + predicted_tag_schema, predicted_packet_schema = pod.output_schema(stream) + result = pod.process(stream) + actual_tag_schema, actual_packet_schema = result.output_schema() + + # Tag schemas should match + assert set(predicted_tag_schema.keys()) == set(actual_tag_schema.keys()) + # Packet schemas should match + assert set(predicted_packet_schema.keys()) == set(actual_packet_schema.keys()) + + +# --------------------------------------------------------------------------- +# FunctionPodStream +# --------------------------------------------------------------------------- + + +class TestFunctionPodStream: + """Per design, FunctionPodStream is lazy — computation happens on iteration.""" + + def test_producer_is_function_pod(self): + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + stream = _make_stream() + result = pod.process(stream) + assert result.producer is pod + + def test_upstreams_contains_input_stream(self): + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + stream = _make_stream() + result = pod.process(stream) + assert stream in result.upstreams + + def test_keys_matches_output_schema(self): + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + stream = _make_stream() + result = pod.process(stream) + tag_keys, packet_keys = result.keys() + tag_schema, packet_schema = result.output_schema() + assert set(tag_keys) == set(tag_schema.keys()) + assert set(packet_keys) == set(packet_schema.keys()) + + def test_as_table_materialization(self): + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + stream = _make_stream(3) + result = pod.process(stream) + table = result.as_table() + assert isinstance(table, pa.Table) + assert table.num_rows == 3 + + def test_iter_packets_yields_correct_count(self): + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + stream = _make_stream(5) + result = pod.process(stream) + packets = list(result.iter_packets()) + assert len(packets) == 5 + + def test_clear_cache_forces_recompute(self): + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + stream = _make_stream() + result = pod.process(stream) + # Materialize + list(result.iter_packets()) + # Clear and re-iterate + result.clear_cache() + packets = list(result.iter_packets()) + assert len(packets) == 3 + + +# --------------------------------------------------------------------------- +# @function_pod decorator +# --------------------------------------------------------------------------- + + +class TestFunctionPodDecorator: + """Per design, the @function_pod decorator adds a .pod attribute.""" + + def test_decorator_creates_pod_attribute(self): + @function_pod(output_keys="result") + def my_double(x: int) -> int: + return x * 2 + + assert hasattr(my_double, "pod") + assert isinstance(my_double.pod, FunctionPod) + + def test_decorated_function_still_callable(self): + @function_pod(output_keys="result") + def my_double(x: int) -> int: + return x * 2 + + # The pod can process streams + stream = _make_stream() + result = my_double.pod.process(stream) + packets = list(result.iter_packets()) + assert len(packets) == 3 diff --git a/test-objective/unit/test_hashing.py b/test-objective/unit/test_hashing.py new file mode 100644 index 00000000..c2083c21 --- /dev/null +++ b/test-objective/unit/test_hashing.py @@ -0,0 +1,447 @@ +"""Tests for BaseSemanticHasher and TypeHandlerRegistry. + +Specification-derived tests covering deterministic hashing of primitives, +structures, ContentHash pass-through, identity_structure resolution, +strict-mode errors, collision resistance, and registry operations. +""" + +from __future__ import annotations + +import threading +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from orcapod.hashing.semantic_hashing.semantic_hasher import BaseSemanticHasher +from orcapod.hashing.semantic_hashing.type_handler_registry import ( + BuiltinTypeHandlerRegistry, + TypeHandlerRegistry, +) +from orcapod.types import ContentHash + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def registry() -> TypeHandlerRegistry: + """An empty TypeHandlerRegistry.""" + return TypeHandlerRegistry() + + +@pytest.fixture +def hasher(registry: TypeHandlerRegistry) -> BaseSemanticHasher: + """A strict BaseSemanticHasher backed by an empty registry.""" + return BaseSemanticHasher( + hasher_id="test_v1", + type_handler_registry=registry, + strict=True, + ) + + +@pytest.fixture +def lenient_hasher(registry: TypeHandlerRegistry) -> BaseSemanticHasher: + """A non-strict BaseSemanticHasher backed by an empty registry.""" + return BaseSemanticHasher( + hasher_id="test_v1", + type_handler_registry=registry, + strict=False, + ) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class _FakeHandler: + """Minimal object satisfying TypeHandlerProtocol for testing.""" + + def __init__(self, return_value: Any = "handled") -> None: + self._return_value = return_value + + def handle(self, obj: Any, hasher: BaseSemanticHasher) -> Any: + return self._return_value + + +class _IdentityObj: + """Object implementing identity_structure() for hashing.""" + + def __init__(self, structure: Any) -> None: + self._structure = structure + + def identity_structure(self) -> Any: + return self._structure + + def content_hash(self, hasher: Any = None) -> ContentHash: + if hasher is not None: + return hasher.hash_object(self.identity_structure()) + h = BaseSemanticHasher( + "test_v1", type_handler_registry=TypeHandlerRegistry(), strict=False + ) + return h.hash_object(self.identity_structure()) + + +# =================================================================== +# BaseSemanticHasher -- primitive hashing +# =================================================================== + + +class TestBaseSemanticHasherPrimitives: + """Primitives (int, str, float, bool, None) are hashed deterministically.""" + + @pytest.mark.parametrize( + "value", + [0, 1, -42, 3.14, -0.0, "", "hello", True, False, None], + ids=lambda v: f"{type(v).__name__}({v!r})", + ) + def test_primitive_produces_content_hash( + self, hasher: BaseSemanticHasher, value: Any + ) -> None: + result = hasher.hash_object(value) + assert isinstance(result, ContentHash) + + @pytest.mark.parametrize("value", [42, "hello", 3.14, True, None]) + def test_primitive_deterministic( + self, hasher: BaseSemanticHasher, value: Any + ) -> None: + """Same input always produces the same hash.""" + h1 = hasher.hash_object(value) + h2 = hasher.hash_object(value) + assert h1 == h2 + + def test_different_primitives_differ(self, hasher: BaseSemanticHasher) -> None: + """Different inputs produce different hashes (collision resistance).""" + h_int = hasher.hash_object(42) + h_str = hasher.hash_object("42") + assert h_int != h_str + + +# =================================================================== +# BaseSemanticHasher -- structures +# =================================================================== + + +class TestBaseSemanticHasherStructures: + """Structures (list, dict, tuple, set) are expanded and hashed.""" + + def test_list_hashed(self, hasher: BaseSemanticHasher) -> None: + result = hasher.hash_object([1, 2, 3]) + assert isinstance(result, ContentHash) + + def test_dict_hashed(self, hasher: BaseSemanticHasher) -> None: + result = hasher.hash_object({"a": 1, "b": 2}) + assert isinstance(result, ContentHash) + + def test_tuple_hashed(self, hasher: BaseSemanticHasher) -> None: + result = hasher.hash_object((1, 2, 3)) + assert isinstance(result, ContentHash) + + def test_set_hashed(self, hasher: BaseSemanticHasher) -> None: + result = hasher.hash_object({1, 2, 3}) + assert isinstance(result, ContentHash) + + def test_list_and_tuple_differ(self, hasher: BaseSemanticHasher) -> None: + """list and tuple with same elements produce different hashes.""" + h_list = hasher.hash_object([1, 2, 3]) + h_tuple = hasher.hash_object((1, 2, 3)) + assert h_list != h_tuple + + def test_set_order_independent(self, hasher: BaseSemanticHasher) -> None: + """Sets with the same elements hash identically regardless of insertion order.""" + h1 = hasher.hash_object({3, 1, 2}) + h2 = hasher.hash_object({1, 2, 3}) + assert h1 == h2 + + def test_dict_key_order_independent(self, hasher: BaseSemanticHasher) -> None: + """Dicts with the same key-value pairs hash identically regardless of order.""" + h1 = hasher.hash_object({"b": 2, "a": 1}) + h2 = hasher.hash_object({"a": 1, "b": 2}) + assert h1 == h2 + + def test_nested_structures(self, hasher: BaseSemanticHasher) -> None: + """Nested structures are hashed correctly.""" + nested = {"key": [1, (2, 3)], "other": {"inner": True}} + result = hasher.hash_object(nested) + assert isinstance(result, ContentHash) + # Determinism + assert result == hasher.hash_object(nested) + + def test_different_structures_differ(self, hasher: BaseSemanticHasher) -> None: + h1 = hasher.hash_object([1, 2]) + h2 = hasher.hash_object([1, 2, 3]) + assert h1 != h2 + + +# =================================================================== +# BaseSemanticHasher -- ContentHash passthrough +# =================================================================== + + +class TestBaseSemanticHasherContentHash: + """ContentHash inputs are returned as-is (terminal).""" + + def test_content_hash_passthrough(self, hasher: BaseSemanticHasher) -> None: + ch = ContentHash(method="sha256", digest=b"\x00" * 32) + result = hasher.hash_object(ch) + assert result is ch + + +# =================================================================== +# BaseSemanticHasher -- identity_structure resolution +# =================================================================== + + +class TestBaseSemanticHasherIdentityStructure: + """Objects implementing identity_structure() are resolved via it.""" + + def test_identity_structure_object(self, hasher: BaseSemanticHasher) -> None: + obj = _IdentityObj(structure={"name": "test", "version": 1}) + result = hasher.hash_object(obj) + assert isinstance(result, ContentHash) + + def test_identity_structure_deterministic( + self, hasher: BaseSemanticHasher + ) -> None: + obj1 = _IdentityObj(structure=[1, 2, 3]) + obj2 = _IdentityObj(structure=[1, 2, 3]) + assert hasher.hash_object(obj1) == hasher.hash_object(obj2) + + def test_different_identity_structures_differ( + self, hasher: BaseSemanticHasher + ) -> None: + obj1 = _IdentityObj(structure="alpha") + obj2 = _IdentityObj(structure="beta") + assert hasher.hash_object(obj1) != hasher.hash_object(obj2) + + +# =================================================================== +# BaseSemanticHasher -- strict mode +# =================================================================== + + +class TestBaseSemanticHasherStrictMode: + """Unknown type in strict mode raises TypeError.""" + + def test_unknown_type_strict_raises(self, hasher: BaseSemanticHasher) -> None: + class Unknown: + pass + + with pytest.raises(TypeError, match="no TypeHandlerProtocol registered"): + hasher.hash_object(Unknown()) + + def test_unknown_type_lenient_succeeds( + self, lenient_hasher: BaseSemanticHasher + ) -> None: + class Unknown: + pass + + result = lenient_hasher.hash_object(Unknown()) + assert isinstance(result, ContentHash) + + +# =================================================================== +# BaseSemanticHasher -- collision resistance +# =================================================================== + + +class TestBaseSemanticHasherCollisionResistance: + """Different inputs produce different hashes.""" + + def test_int_vs_string(self, hasher: BaseSemanticHasher) -> None: + assert hasher.hash_object(1) != hasher.hash_object("1") + + def test_empty_list_vs_empty_tuple(self, hasher: BaseSemanticHasher) -> None: + assert hasher.hash_object([]) != hasher.hash_object(()) + + def test_empty_dict_vs_empty_list(self, hasher: BaseSemanticHasher) -> None: + assert hasher.hash_object({}) != hasher.hash_object([]) + + def test_none_vs_string_none(self, hasher: BaseSemanticHasher) -> None: + assert hasher.hash_object(None) != hasher.hash_object("None") + + def test_true_vs_one(self, hasher: BaseSemanticHasher) -> None: + """bool True and int 1 produce different hashes due to JSON encoding.""" + h_true = hasher.hash_object(True) + h_one = hasher.hash_object(1) + assert h_true != h_one + + +# =================================================================== +# TypeHandlerRegistry -- register/get_handler roundtrip +# =================================================================== + + +class TestTypeHandlerRegistryBasics: + """register() + get_handler() roundtrip.""" + + def test_register_and_get_handler(self, registry: TypeHandlerRegistry) -> None: + handler = _FakeHandler() + registry.register(int, handler) + assert registry.get_handler(42) is handler + + def test_get_handler_returns_none_for_unregistered( + self, registry: TypeHandlerRegistry + ) -> None: + assert registry.get_handler("hello") is None + + +# =================================================================== +# TypeHandlerRegistry -- MRO-aware lookup +# =================================================================== + + +class TestTypeHandlerRegistryMRO: + """MRO-aware lookup: handler for parent class matches subclass.""" + + def test_subclass_inherits_parent_handler( + self, registry: TypeHandlerRegistry + ) -> None: + class Base: + pass + + class Child(Base): + pass + + handler = _FakeHandler() + registry.register(Base, handler) + assert registry.get_handler(Child()) is handler + + def test_specific_handler_overrides_parent( + self, registry: TypeHandlerRegistry + ) -> None: + class Base: + pass + + class Child(Base): + pass + + parent_handler = _FakeHandler("parent") + child_handler = _FakeHandler("child") + registry.register(Base, parent_handler) + registry.register(Child, child_handler) + assert registry.get_handler(Child()) is child_handler + assert registry.get_handler(Base()) is parent_handler + + +# =================================================================== +# TypeHandlerRegistry -- unregister +# =================================================================== + + +class TestTypeHandlerRegistryUnregister: + """unregister() removes handler.""" + + def test_unregister_existing(self, registry: TypeHandlerRegistry) -> None: + handler = _FakeHandler() + registry.register(int, handler) + result = registry.unregister(int) + assert result is True + assert registry.get_handler(42) is None + + def test_unregister_nonexistent(self, registry: TypeHandlerRegistry) -> None: + result = registry.unregister(float) + assert result is False + + +# =================================================================== +# TypeHandlerRegistry -- has_handler +# =================================================================== + + +class TestTypeHandlerRegistryHasHandler: + """has_handler() boolean check.""" + + def test_has_handler_true(self, registry: TypeHandlerRegistry) -> None: + registry.register(int, _FakeHandler()) + assert registry.has_handler(int) is True + + def test_has_handler_false(self, registry: TypeHandlerRegistry) -> None: + assert registry.has_handler(str) is False + + def test_has_handler_via_mro(self, registry: TypeHandlerRegistry) -> None: + class Base: + pass + + class Child(Base): + pass + + registry.register(Base, _FakeHandler()) + assert registry.has_handler(Child) is True + + +# =================================================================== +# TypeHandlerRegistry -- registered_types +# =================================================================== + + +class TestTypeHandlerRegistryRegisteredTypes: + """registered_types() lists types.""" + + def test_registered_types_empty(self, registry: TypeHandlerRegistry) -> None: + assert registry.registered_types() == [] + + def test_registered_types_populated(self, registry: TypeHandlerRegistry) -> None: + registry.register(int, _FakeHandler()) + registry.register(str, _FakeHandler()) + types = registry.registered_types() + assert set(types) == {int, str} + + +# =================================================================== +# TypeHandlerRegistry -- thread safety +# =================================================================== + + +class TestTypeHandlerRegistryThreadSafety: + """Concurrent register/lookup doesn't crash.""" + + def test_concurrent_register_lookup(self, registry: TypeHandlerRegistry) -> None: + errors: list[Exception] = [] + + def register_types(start: int, count: int) -> None: + try: + for i in range(start, start + count): + t = type(f"Type{i}", (), {}) + registry.register(t, _FakeHandler(f"handler_{i}")) + except Exception as exc: + errors.append(exc) + + def lookup_types() -> None: + try: + for _ in range(100): + registry.get_handler(42) + registry.registered_types() + registry.has_handler(int) + except Exception as exc: + errors.append(exc) + + threads = [] + for i in range(5): + threads.append( + threading.Thread(target=register_types, args=(i * 20, 20)) + ) + threads.append(threading.Thread(target=lookup_types)) + + for t in threads: + t.start() + for t in threads: + t.join(timeout=10) + + assert errors == [], f"Concurrent operations raised: {errors}" + + +# =================================================================== +# BuiltinTypeHandlerRegistry +# =================================================================== + + +class TestBuiltinTypeHandlerRegistry: + """BuiltinTypeHandlerRegistry is pre-populated with built-in handlers.""" + + def test_construction(self) -> None: + reg = BuiltinTypeHandlerRegistry() + assert len(reg.registered_types()) > 0 diff --git a/test-objective/unit/test_lazy_module.py b/test-objective/unit/test_lazy_module.py new file mode 100644 index 00000000..72209ca3 --- /dev/null +++ b/test-objective/unit/test_lazy_module.py @@ -0,0 +1,55 @@ +"""Specification-derived tests for LazyModule. + +Tests based on documented behavior: deferred import until first attribute access. +""" + +from __future__ import annotations + +import pytest + +from orcapod.utils.lazy_module import LazyModule + + +class TestLazyModule: + """Per design, LazyModule defers import until first attribute access.""" + + def test_not_loaded_initially(self): + lazy = LazyModule("json") + assert lazy.is_loaded is False + + def test_loads_on_attribute_access(self): + lazy = LazyModule("json") + # Accessing an attribute should trigger the import + _ = lazy.dumps + assert lazy.is_loaded is True + + def test_attribute_access_works(self): + lazy = LazyModule("json") + # Should be able to use the module's functions + result = lazy.dumps({"key": "value"}) + assert isinstance(result, str) + + def test_force_load(self): + lazy = LazyModule("json") + mod = lazy.force_load() + assert lazy.is_loaded is True + assert mod is not None + + def test_invalid_module_raises(self): + lazy = LazyModule("nonexistent_module_xyz_12345") + with pytest.raises(ModuleNotFoundError): + _ = lazy.dumps + + def test_module_name_property(self): + lazy = LazyModule("json") + assert lazy.module_name == "json" + + def test_repr(self): + lazy = LazyModule("json") + r = repr(lazy) + assert "json" in r + + def test_str(self): + lazy = LazyModule("json") + s = str(lazy) + assert "json" in s diff --git a/test-objective/unit/test_nodes.py b/test-objective/unit/test_nodes.py new file mode 100644 index 00000000..f8cc78e7 --- /dev/null +++ b/test-objective/unit/test_nodes.py @@ -0,0 +1,302 @@ +"""Specification-derived tests for FunctionNode, OperatorNode, and +Persistent variants. + +Tests based on design specification: +- FunctionNode: in-memory function pod execution as a stream +- FunctionNode: two-phase iteration (cached first, compute missing) +- OperatorNode: operator execution as a stream +- OperatorNode: CacheMode behavior (OFF/LOG/REPLAY) +""" + +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.core.function_pod import FunctionPod +from orcapod.core.nodes import ( + FunctionNode, + OperatorNode, +) +from orcapod.core.operators import Join +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.core.sources import DerivedSource +from orcapod.core.streams import ArrowTableStream +from orcapod.databases import InMemoryArrowDatabase +from orcapod.types import CacheMode + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _double(x: int) -> int: + return x * 2 + + +def _make_stream(n: int = 3) -> ArrowTableStream: + table = pa.table( + { + "id": pa.array(list(range(n)), type=pa.int64()), + "x": pa.array(list(range(n)), type=pa.int64()), + } + ) + return ArrowTableStream(table, tag_columns=["id"]) + + +def _make_joinable_streams() -> tuple[ArrowTableStream, ArrowTableStream]: + left = pa.table( + { + "id": pa.array([1, 2, 3], type=pa.int64()), + "age": pa.array([25, 30, 35], type=pa.int64()), + } + ) + right = pa.table( + { + "id": pa.array([2, 3, 4], type=pa.int64()), + "score": pa.array([85, 90, 95], type=pa.int64()), + } + ) + return ( + ArrowTableStream(left, tag_columns=["id"]), + ArrowTableStream(right, tag_columns=["id"]), + ) + + +# =================================================================== +# FunctionNode +# =================================================================== + + +class TestFunctionNode: + """Per design, FunctionNode wraps a FunctionPod for stream-based execution.""" + + def test_iter_packets(self): + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + stream = _make_stream(3) + node = FunctionNode(function_pod=pod, input_stream=stream) + packets = list(node.iter_packets()) + assert len(packets) == 3 + for tag, packet in packets: + assert "result" in packet.keys() + + def test_process_packet(self): + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + stream = _make_stream() + node = FunctionNode(function_pod=pod, input_stream=stream) + # Get first tag/packet from input + tag, packet = next(iter(stream.iter_packets())) + out_tag, out_packet = node.process_packet(tag, packet) + assert out_packet is not None + assert "result" in out_packet.keys() + + def test_producer_is_function_pod(self): + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + stream = _make_stream() + node = FunctionNode(function_pod=pod, input_stream=stream) + assert node.producer is pod + + def test_upstreams(self): + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + stream = _make_stream() + node = FunctionNode(function_pod=pod, input_stream=stream) + assert stream in node.upstreams + + def test_clear_cache(self): + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + stream = _make_stream() + node = FunctionNode(function_pod=pod, input_stream=stream) + list(node.iter_packets()) + node.clear_cache() + # Should be able to iterate again after clearing + packets = list(node.iter_packets()) + assert len(packets) == 3 + + +# =================================================================== +# FunctionNode +# =================================================================== + + +class TestFunctionNode: + """Per design: two-phase iteration — Phase 1 returns cached records, + Phase 2 computes missing. Uses pipeline_hash for DB path scoping.""" + + def test_caches_computed_results(self): + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + stream = _make_stream(3) + pipeline_db = InMemoryArrowDatabase() + result_db = InMemoryArrowDatabase() + node = FunctionNode( + function_pod=pod, + input_stream=stream, + pipeline_database=pipeline_db, + result_database=result_db, + ) + # First iteration computes all + packets = list(node.iter_packets()) + assert len(packets) == 3 + + def test_run_eagerly_processes_all(self): + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + stream = _make_stream(3) + pipeline_db = InMemoryArrowDatabase() + result_db = InMemoryArrowDatabase() + node = FunctionNode( + function_pod=pod, + input_stream=stream, + pipeline_database=pipeline_db, + result_database=result_db, + ) + node.run() + # After run, results should be in DB + records = node.get_all_records() + assert records is not None + assert records.num_rows == 3 + + def test_as_source_returns_derived_source(self): + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + stream = _make_stream(3) + pipeline_db = InMemoryArrowDatabase() + result_db = InMemoryArrowDatabase() + node = FunctionNode( + function_pod=pod, + input_stream=stream, + pipeline_database=pipeline_db, + result_database=result_db, + ) + node.run() + source = node.as_source() + assert isinstance(source, DerivedSource) + + def test_pipeline_path_uses_pipeline_hash(self): + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + stream = _make_stream() + pipeline_db = InMemoryArrowDatabase() + result_db = InMemoryArrowDatabase() + node = FunctionNode( + function_pod=pod, + input_stream=stream, + pipeline_database=pipeline_db, + result_database=result_db, + ) + path = node.pipeline_path + assert isinstance(path, tuple) + assert len(path) > 0 + + +# =================================================================== +# OperatorNode +# =================================================================== + + +class TestOperatorNode: + """Per design, OperatorNode wraps an operator for stream-based execution.""" + + def test_delegates_to_operator(self): + join = Join() + s1, s2 = _make_joinable_streams() + node = OperatorNode(operator=join, input_streams=[s1, s2]) + node.run() + table = node.as_table() + assert table.num_rows == 2 # Inner join on id=2, id=3 + + def test_clear_cache(self): + join = Join() + s1, s2 = _make_joinable_streams() + node = OperatorNode(operator=join, input_streams=[s1, s2]) + node.run() + node.clear_cache() + # Should be able to run again + node.run() + table = node.as_table() + assert table.num_rows == 2 + + +# =================================================================== +# OperatorNode +# =================================================================== + + +class TestOperatorNode: + """Per design, supports CacheMode: OFF (always compute), LOG (compute+store), + REPLAY (load from DB).""" + + def test_cache_mode_off(self): + join = Join() + s1, s2 = _make_joinable_streams() + db = InMemoryArrowDatabase() + node = OperatorNode( + operator=join, + input_streams=[s1, s2], + pipeline_database=db, + cache_mode=CacheMode.OFF, + ) + node.run() + table = node.as_table() + assert table.num_rows == 2 + + def test_cache_mode_log(self): + join = Join() + s1, s2 = _make_joinable_streams() + db = InMemoryArrowDatabase() + node = OperatorNode( + operator=join, + input_streams=[s1, s2], + pipeline_database=db, + cache_mode=CacheMode.LOG, + ) + node.run() + # Results should be stored in DB + records = node.get_all_records() + assert records is not None + assert records.num_rows == 2 + + def test_cache_mode_replay(self): + join = Join() + s1, s2 = _make_joinable_streams() + db = InMemoryArrowDatabase() + + # First: LOG to populate DB + node1 = OperatorNode( + operator=join, + input_streams=[s1, s2], + pipeline_database=db, + cache_mode=CacheMode.LOG, + ) + node1.run() + + # Second: REPLAY to load from DB + node2 = OperatorNode( + operator=join, + input_streams=[s1, s2], + pipeline_database=db, + cache_mode=CacheMode.REPLAY, + ) + node2.run() + table = node2.as_table() + assert table.num_rows == 2 + + def test_as_source_returns_derived_source(self): + join = Join() + s1, s2 = _make_joinable_streams() + db = InMemoryArrowDatabase() + node = OperatorNode( + operator=join, + input_streams=[s1, s2], + pipeline_database=db, + cache_mode=CacheMode.LOG, + ) + node.run() + source = node.as_source() + assert isinstance(source, DerivedSource) diff --git a/test-objective/unit/test_operators.py b/test-objective/unit/test_operators.py new file mode 100644 index 00000000..4835c03e --- /dev/null +++ b/test-objective/unit/test_operators.py @@ -0,0 +1,513 @@ +"""Specification-derived tests for all operators. + +Tests based on the design specification's operator semantics: +- Operators inspect tags, never packet content +- Operators can rename columns but never synthesize new values +- System tag evolution rules: name-preserving, name-extending, type-evolving +""" + +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.core.operators import ( + Batch, + DropPacketColumns, + DropTagColumns, + Join, + MapPackets, + MapTags, + MergeJoin, + PolarsFilter, + SelectPacketColumns, + SelectTagColumns, + SemiJoin, +) +from orcapod.core.sources import ArrowTableSource +from orcapod.core.streams import ArrowTableStream +from orcapod.errors import InputValidationError +from orcapod.system_constants import constants +from orcapod.types import ColumnConfig + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_stream( + tag_data: dict, packet_data: dict, tag_columns: list[str] +) -> ArrowTableStream: + all_data = {**tag_data, **packet_data} + table = pa.table(all_data) + return ArrowTableStream(table, tag_columns=tag_columns) + + +def _stream_a() -> ArrowTableStream: + """Stream with tag=id, packet=age.""" + return _make_stream( + {"id": pa.array([1, 2, 3], type=pa.int64())}, + {"age": pa.array([25, 30, 35], type=pa.int64())}, + ["id"], + ) + + +def _stream_b() -> ArrowTableStream: + """Stream with tag=id, packet=score (overlaps with A on id=2,3).""" + return _make_stream( + {"id": pa.array([2, 3, 4], type=pa.int64())}, + {"score": pa.array([85, 90, 95], type=pa.int64())}, + ["id"], + ) + + +def _stream_b_overlapping_packet() -> ArrowTableStream: + """Stream with tag=id, packet=age (same packet col name as A).""" + return _make_stream( + {"id": pa.array([2, 3, 4], type=pa.int64())}, + {"age": pa.array([40, 45, 50], type=pa.int64())}, + ["id"], + ) + + +def _stream_with_two_tags() -> ArrowTableStream: + """Stream with tag={id, group}, packet=value.""" + return _make_stream( + { + "id": pa.array([1, 2, 3], type=pa.int64()), + "group": pa.array(["a", "a", "b"], type=pa.large_string()), + }, + {"value": pa.array([10, 20, 30], type=pa.int64())}, + ["id", "group"], + ) + + +# =================================================================== +# Join (N-ary, commutative) +# =================================================================== + + +class TestJoin: + """Per design: N-ary inner join on shared tag columns. Requires + non-overlapping packet columns. Commutative. System tags: name-extending.""" + + def test_two_streams_on_common_tags(self): + join = Join() + result = join.process(_stream_a(), _stream_b()) + table = result.as_table() + # Inner join on id: should have rows for id=2, id=3 + assert table.num_rows == 2 + assert "age" in table.column_names + assert "score" in table.column_names + + def test_non_overlapping_packet_columns_required(self): + join = Join() + with pytest.raises(InputValidationError): + join.validate_inputs(_stream_a(), _stream_b_overlapping_packet()) + + def test_commutative(self): + """join(A, B) should produce the same data as join(B, A).""" + join = Join() + result_ab = join.process(_stream_a(), _stream_b()) + result_ba = join.process(_stream_b(), _stream_a()) + + table_ab = result_ab.as_table() + table_ba = result_ba.as_table() + + # Same number of rows + assert table_ab.num_rows == table_ba.num_rows + + # Same data (check by sorting by id and comparing values) + ab_ids = sorted(table_ab.column("id").to_pylist()) + ba_ids = sorted(table_ba.column("id").to_pylist()) + assert ab_ids == ba_ids + + def test_empty_result_when_no_matches(self): + """Disjoint tags → empty stream.""" + s1 = _make_stream( + {"id": pa.array([1], type=pa.int64())}, + {"a": pa.array([10], type=pa.int64())}, + ["id"], + ) + s2 = _make_stream( + {"id": pa.array([99], type=pa.int64())}, + {"b": pa.array([20], type=pa.int64())}, + ["id"], + ) + join = Join() + result = join.process(s1, s2) + table = result.as_table() + assert table.num_rows == 0 + + def test_three_or_more_streams(self): + s1 = _make_stream( + {"id": pa.array([1, 2], type=pa.int64())}, + {"a": pa.array([10, 20], type=pa.int64())}, + ["id"], + ) + s2 = _make_stream( + {"id": pa.array([1, 2], type=pa.int64())}, + {"b": pa.array([30, 40], type=pa.int64())}, + ["id"], + ) + s3 = _make_stream( + {"id": pa.array([1, 2], type=pa.int64())}, + {"c": pa.array([50, 60], type=pa.int64())}, + ["id"], + ) + join = Join() + result = join.process(s1, s2, s3) + table = result.as_table() + assert table.num_rows == 2 + assert "a" in table.column_names + assert "b" in table.column_names + assert "c" in table.column_names + + def test_system_tag_name_extending(self): + """Per design, multi-input ops extend system tag column names with + ::pipeline_hash:position. Sources (not raw streams) create system tags.""" + sa = ArrowTableSource( + pa.table({"id": pa.array([2, 3], type=pa.int64()), "a": pa.array([10, 20], type=pa.int64())}), + tag_columns=["id"], + ) + sb = ArrowTableSource( + pa.table({"id": pa.array([2, 3], type=pa.int64()), "b": pa.array([30, 40], type=pa.int64())}), + tag_columns=["id"], + ) + join = Join() + result = join.process(sa, sb) + table = result.as_table(all_info=True) + tag_cols = [ + c for c in table.column_names if c.startswith(constants.SYSTEM_TAG_PREFIX) + ] + # After join, system tag columns should have extended names (at least 2 per input) + assert len(tag_cols) >= 2 + + def test_output_schema_prediction(self): + join = Join() + sa, sb = _stream_a(), _stream_b() + predicted_tag, predicted_packet = join.output_schema(sa, sb) + result = join.process(sa, sb) + actual_tag, actual_packet = result.output_schema() + assert set(predicted_tag.keys()) == set(actual_tag.keys()) + assert set(predicted_packet.keys()) == set(actual_packet.keys()) + + +# =================================================================== +# MergeJoin (binary) +# =================================================================== + + +class TestMergeJoin: + """Per design: binary join where colliding packet columns merge into + sorted list[T]. Requires identical types for colliding columns.""" + + def test_colliding_columns_become_sorted_lists(self): + merge = MergeJoin() + sa = _stream_a() # packet: age + sb = _stream_b_overlapping_packet() # packet: age + result = merge.process(sa, sb) + table = result.as_table() + # age should now be list[int] type + age_type = table.schema.field("age").type + assert pa.types.is_list(age_type) or pa.types.is_large_list(age_type) + + def test_non_colliding_columns_pass_through(self): + merge = MergeJoin() + # Create streams with some overlapping and some non-overlapping + s1 = _make_stream( + {"id": pa.array([1, 2], type=pa.int64())}, + {"shared": pa.array([10, 20], type=pa.int64()), "only_left": pa.array([1, 2], type=pa.int64())}, + ["id"], + ) + s2 = _make_stream( + {"id": pa.array([1, 2], type=pa.int64())}, + {"shared": pa.array([30, 40], type=pa.int64()), "only_right": pa.array([3, 4], type=pa.int64())}, + ["id"], + ) + result = merge.process(s1, s2) + table = result.as_table() + assert "only_left" in table.column_names + assert "only_right" in table.column_names + # Non-overlapping columns should keep their original type + assert table.schema.field("only_left").type == pa.int64() + + def test_output_schema_predicts_list_types(self): + merge = MergeJoin() + sa = _stream_a() + sb = _stream_b_overlapping_packet() + predicted_tag, predicted_packet = merge.output_schema(sa, sb) + # The 'age' column should be predicted as list type + assert "age" in predicted_packet + + +# =================================================================== +# SemiJoin (binary, non-commutative) +# =================================================================== + + +class TestSemiJoin: + """Per design: binary non-commutative join. Keeps left rows matching + right tags. Right packet columns are dropped.""" + + def test_filters_left_by_right_tags(self): + semi = SemiJoin() + result = semi.process(_stream_a(), _stream_b()) + table = result.as_table() + # A has id=[1,2,3], B has id=[2,3,4] + # Semi-join keeps A rows where id in B → id=2, id=3 + assert table.num_rows == 2 + + def test_non_commutative(self): + semi = SemiJoin() + result_ab = semi.process(_stream_a(), _stream_b()) + result_ba = semi.process(_stream_b(), _stream_a()) + # Generally not the same (different left/right roles) + table_ab = result_ab.as_table() + table_ba = result_ba.as_table() + # AB keeps A's packets (age), BA keeps B's packets (score) + assert "age" in table_ab.column_names + assert "score" in table_ba.column_names + + def test_preserves_left_packet_columns(self): + semi = SemiJoin() + result = semi.process(_stream_a(), _stream_b()) + table = result.as_table() + assert "age" in table.column_names + assert "score" not in table.column_names + + +# =================================================================== +# Batch +# =================================================================== + + +class TestBatch: + """Per design: groups rows by tag, aggregates packets. Packet column + types become list[T]. System tag type evolves from str to list[str].""" + + def test_groups_rows(self): + stream = _make_stream( + { + "group": pa.array(["a", "a", "b"], type=pa.large_string()), + }, + {"value": pa.array([1, 2, 3], type=pa.int64())}, + ["group"], + ) + batch = Batch() + result = batch.process(stream) + table = result.as_table() + # Batch aggregates all rows into a single batch row + assert table.num_rows == 1 + # Values should be collected into lists + values = table.column("value").to_pylist() + assert values == [[1, 2, 3]] + + def test_types_become_lists(self): + stream = _make_stream( + {"group": pa.array(["a", "a", "b"], type=pa.large_string())}, + {"value": pa.array([1, 2, 3], type=pa.int64())}, + ["group"], + ) + batch = Batch() + result = batch.process(stream) + table = result.as_table() + value_type = table.schema.field("value").type + assert pa.types.is_list(value_type) or pa.types.is_large_list(value_type) + + def test_batch_output_schema_prediction(self): + stream = _make_stream( + {"group": pa.array(["a", "a", "b"], type=pa.large_string())}, + {"value": pa.array([1, 2, 3], type=pa.int64())}, + ["group"], + ) + batch = Batch() + predicted_tag, predicted_packet = batch.output_schema(stream) + result = batch.process(stream) + actual_tag, actual_packet = result.output_schema() + assert set(predicted_tag.keys()) == set(actual_tag.keys()) + assert set(predicted_packet.keys()) == set(actual_packet.keys()) + + def test_batch_with_batch_size(self): + stream = _make_stream( + {"group": pa.array(["a"] * 5, type=pa.large_string())}, + {"value": pa.array([1, 2, 3, 4, 5], type=pa.int64())}, + ["group"], + ) + batch = Batch(batch_size=2) + result = batch.process(stream) + table = result.as_table() + # 5 items with batch_size=2: groups of [2, 2, 1] + assert table.num_rows >= 2 + + def test_batch_drop_partial(self): + stream = _make_stream( + {"group": pa.array(["a"] * 5, type=pa.large_string())}, + {"value": pa.array([1, 2, 3, 4, 5], type=pa.int64())}, + ["group"], + ) + batch = Batch(batch_size=2, drop_partial_batch=True) + result = batch.process(stream) + table = result.as_table() + # 5 items, batch_size=2, drop_partial → only 2 full batches + assert table.num_rows == 2 + + +# =================================================================== +# Column Selection +# =================================================================== + + +class TestSelectTagColumns: + """Per design: keeps only specified tag columns.""" + + def test_select_tag_columns(self): + stream = _stream_with_two_tags() + select = SelectTagColumns(columns=["id"]) + result = select.process(stream) + tag_keys, _ = result.keys() + assert "id" in tag_keys + assert "group" not in tag_keys + + def test_strict_missing_raises(self): + stream = _stream_with_two_tags() + select = SelectTagColumns(columns=["nonexistent"], strict=True) + with pytest.raises(Exception): + select.process(stream) + + +class TestSelectPacketColumns: + """Per design: keeps only specified packet columns.""" + + def test_select_packet_columns(self): + stream = _make_stream( + {"id": pa.array([1, 2], type=pa.int64())}, + {"a": pa.array([10, 20], type=pa.int64()), "b": pa.array([30, 40], type=pa.int64())}, + ["id"], + ) + select = SelectPacketColumns(columns=["a"]) + result = select.process(stream) + _, packet_keys = result.keys() + assert "a" in packet_keys + assert "b" not in packet_keys + + +class TestDropTagColumns: + """Per design: removes specified tag columns.""" + + def test_drop_tag_columns(self): + stream = _stream_with_two_tags() + drop = DropTagColumns(columns=["group"]) + result = drop.process(stream) + tag_keys, _ = result.keys() + assert "group" not in tag_keys + assert "id" in tag_keys + + +class TestDropPacketColumns: + """Per design: removes specified packet columns.""" + + def test_drop_packet_columns(self): + stream = _make_stream( + {"id": pa.array([1, 2], type=pa.int64())}, + {"a": pa.array([10, 20], type=pa.int64()), "b": pa.array([30, 40], type=pa.int64())}, + ["id"], + ) + drop = DropPacketColumns(columns=["b"]) + result = drop.process(stream) + _, packet_keys = result.keys() + assert "a" in packet_keys + assert "b" not in packet_keys + + +# =================================================================== +# MapTags / MapPackets +# =================================================================== + + +class TestMapTags: + """Per design: renames tag columns. System tags: name-preserving.""" + + def test_renames_tag_columns(self): + stream = _stream_with_two_tags() + mapper = MapTags(name_map={"id": "identifier"}) + result = mapper.process(stream) + tag_keys, _ = result.keys() + assert "identifier" in tag_keys + assert "id" not in tag_keys + + def test_drop_unmapped(self): + stream = _stream_with_two_tags() + mapper = MapTags(name_map={"id": "identifier"}, drop_unmapped=True) + result = mapper.process(stream) + tag_keys, _ = result.keys() + assert "identifier" in tag_keys + assert "group" not in tag_keys + + +class TestMapPackets: + """Per design: renames packet columns.""" + + def test_renames_packet_columns(self): + stream = _make_stream( + {"id": pa.array([1, 2], type=pa.int64())}, + {"value": pa.array([10, 20], type=pa.int64())}, + ["id"], + ) + mapper = MapPackets(name_map={"value": "score"}) + result = mapper.process(stream) + _, packet_keys = result.keys() + assert "score" in packet_keys + assert "value" not in packet_keys + + +# =================================================================== +# PolarsFilter +# =================================================================== + + +class TestPolarsFilter: + """Per design: filters rows by predicate or constraints. Schema preserved. + System tags: name-preserving.""" + + def test_filter_with_constraints(self): + stream = _stream_a() + filt = PolarsFilter(constraints={"id": 2}) + result = filt.process(stream) + table = result.as_table() + assert table.num_rows == 1 + assert table.column("id").to_pylist() == [2] + + def test_filter_preserves_schema(self): + stream = _stream_a() + filt = PolarsFilter(constraints={"id": 2}) + predicted_tag, predicted_packet = filt.output_schema(stream) + result = filt.process(stream) + actual_tag, actual_packet = result.output_schema() + assert set(predicted_tag.keys()) == set(actual_tag.keys()) + assert set(predicted_packet.keys()) == set(actual_packet.keys()) + + +# =================================================================== +# Operator Base Class Validation +# =================================================================== + + +class TestOperatorInputValidation: + """Per design, operators enforce input arity.""" + + def test_unary_rejects_multiple_inputs(self): + batch = Batch() + with pytest.raises(Exception): + batch.validate_inputs(_stream_a(), _stream_b()) + + def test_binary_rejects_wrong_count(self): + join = SemiJoin() + with pytest.raises(Exception): + join.validate_inputs(_stream_a()) # Only 1 for a binary op + + def test_nonzero_input_rejects_zero(self): + join = Join() + with pytest.raises(Exception): + join.validate_inputs() # No inputs diff --git a/test-objective/unit/test_packet.py b/test-objective/unit/test_packet.py new file mode 100644 index 00000000..b70c2ea3 --- /dev/null +++ b/test-objective/unit/test_packet.py @@ -0,0 +1,224 @@ +"""Specification-derived tests for Packet.""" + +import pyarrow as pa +import pytest + +from orcapod.core.datagrams.datagram import Datagram +from orcapod.core.datagrams.tag_packet import Packet +from orcapod.types import ColumnConfig + + +def _make_context(): + """Create a DataContext for tests.""" + from orcapod.contexts import resolve_context + return resolve_context(None) + + +# --------------------------------------------------------------------------- +# Source info stored per data column +# --------------------------------------------------------------------------- + +class TestPacketSourceInfo: + """source_info is stored per data column.""" + + def test_packet_stores_source_info(self): + ctx = _make_context() + pkt = Packet( + {"x": 1, "y": "hello"}, + data_context=ctx, + source_info={"x": "src_x", "y": "src_y"}, + ) + assert pkt["x"] == 1 + + def test_source_info_not_in_keys_by_default(self): + ctx = _make_context() + pkt = Packet( + {"x": 1}, + data_context=ctx, + source_info={"x": "src_x"}, + ) + keys = list(pkt.keys()) + assert "x" in keys + assert not any(k.startswith("_source_") for k in keys) + + def test_source_info_not_in_as_dict_by_default(self): + ctx = _make_context() + pkt = Packet( + {"x": 1}, + data_context=ctx, + source_info={"x": "src_x"}, + ) + d = pkt.as_dict() + assert not any(k.startswith("_source_") for k in d) + + def test_source_info_not_in_as_table_by_default(self): + ctx = _make_context() + pkt = Packet( + {"x": 1}, + data_context=ctx, + source_info={"x": "src_x"}, + ) + table = pkt.as_table() + assert not any(name.startswith("_source_") for name in table.column_names) + + +# --------------------------------------------------------------------------- +# Source info included with ColumnConfig +# --------------------------------------------------------------------------- + +class TestPacketSourceInfoWithConfig: + """With ColumnConfig source=True or all_info=True, source columns included.""" + + def test_keys_with_source_true(self): + ctx = _make_context() + pkt = Packet( + {"x": 1}, + data_context=ctx, + source_info={"x": "src_x"}, + ) + keys = list(pkt.keys(columns=ColumnConfig(source=True))) + assert any(k.startswith("_source_") for k in keys) + + def test_as_dict_with_source_true(self): + ctx = _make_context() + pkt = Packet( + {"x": 1}, + data_context=ctx, + source_info={"x": "src_x"}, + ) + d = pkt.as_dict(columns=ColumnConfig(source=True)) + assert any(k.startswith("_source_") for k in d) + + def test_as_table_with_source_true(self): + ctx = _make_context() + pkt = Packet( + {"x": 1}, + data_context=ctx, + source_info={"x": "src_x"}, + ) + table = pkt.as_table(columns=ColumnConfig(source=True)) + assert any(name.startswith("_source_") for name in table.column_names) + + def test_keys_with_all_info(self): + ctx = _make_context() + pkt = Packet( + {"x": 1}, + data_context=ctx, + source_info={"x": "src_x"}, + ) + keys = list(pkt.keys(columns=ColumnConfig.all())) + assert any(k.startswith("_source_") for k in keys) + + +# --------------------------------------------------------------------------- +# with_source_info() returns new instance (immutable) +# --------------------------------------------------------------------------- + +class TestPacketWithSourceInfo: + """with_source_info() returns new instance (immutable).""" + + def test_with_source_info_returns_new_instance(self): + ctx = _make_context() + pkt = Packet({"x": 1}, data_context=ctx, source_info={"x": "src_x"}) + new_pkt = pkt.with_source_info(x="new_src") + assert new_pkt is not pkt + + def test_with_source_info_does_not_mutate_original(self): + ctx = _make_context() + pkt = Packet({"x": 1}, data_context=ctx, source_info={"x": "src_x"}) + pkt.with_source_info(x="new_src") + # Original should still have old source info + d = pkt.as_dict(columns=ColumnConfig(source=True)) + source_vals = {k: v for k, v in d.items() if k.startswith("_source_")} + assert any(v == "src_x" for v in source_vals.values()) + + +# --------------------------------------------------------------------------- +# rename() also renames source_info keys +# --------------------------------------------------------------------------- + +class TestPacketRename: + """rename() also renames source_info keys.""" + + def test_rename_updates_source_info_keys(self): + ctx = _make_context() + pkt = Packet( + {"x": 1, "y": 2}, + data_context=ctx, + source_info={"x": "src_x", "y": "src_y"}, + ) + renamed = pkt.rename({"x": "alpha"}) + assert "alpha" in renamed + assert "x" not in renamed + # Source info should also be renamed + d = renamed.as_dict(columns=ColumnConfig(source=True)) + assert any("alpha" in k for k in d if k.startswith("_source_")) + assert not any("_source_x" == k for k in d) + + +# --------------------------------------------------------------------------- +# with_columns() adds source_info=None for new columns +# --------------------------------------------------------------------------- + +class TestPacketWithColumns: + """with_columns() adds source_info=None for new columns.""" + + def test_with_columns_new_column_has_none_source(self): + ctx = _make_context() + pkt = Packet({"x": 1}, data_context=ctx, source_info={"x": "src_x"}) + extended = pkt.with_columns(z=99) + assert "z" in extended + # The new column should exist with source_info accessible + d = extended.as_dict(columns=ColumnConfig(source=True)) + # z should have a source column, likely with None value + source_z_keys = [k for k in d if k.startswith("_source_") and "z" in k] + assert len(source_z_keys) > 0 + + +# --------------------------------------------------------------------------- +# as_datagram() returns Datagram, not Packet +# --------------------------------------------------------------------------- + +class TestPacketAsDatagram: + """as_datagram() returns a Datagram (not Packet).""" + + def test_as_datagram_returns_datagram_type(self): + ctx = _make_context() + pkt = Packet({"x": 1}, data_context=ctx, source_info={"x": "src_x"}) + dg = pkt.as_datagram() + assert isinstance(dg, Datagram) + assert not isinstance(dg, Packet) + + def test_as_datagram_preserves_data(self): + ctx = _make_context() + pkt = Packet({"x": 1, "y": "hello"}, data_context=ctx, source_info={"x": "s1", "y": "s2"}) + dg = pkt.as_datagram() + assert dg["x"] == 1 + assert dg["y"] == "hello" + + +# --------------------------------------------------------------------------- +# copy() preserves source_info +# --------------------------------------------------------------------------- + +class TestPacketCopy: + """copy() preserves source_info.""" + + def test_copy_preserves_source_info(self): + ctx = _make_context() + pkt = Packet({"x": 1}, data_context=ctx, source_info={"x": "src_x"}) + copied = pkt.copy() + assert copied is not pkt + # Both should have same source info + orig_d = pkt.as_dict(columns=ColumnConfig(source=True)) + copy_d = copied.as_dict(columns=ColumnConfig(source=True)) + orig_sources = {k: v for k, v in orig_d.items() if k.startswith("_source_")} + copy_sources = {k: v for k, v in copy_d.items() if k.startswith("_source_")} + assert orig_sources == copy_sources + + def test_copy_preserves_data(self): + ctx = _make_context() + pkt = Packet({"x": 1, "y": "hello"}, data_context=ctx, source_info={"x": "s1", "y": "s2"}) + copied = pkt.copy() + assert copied["x"] == 1 + assert copied["y"] == "hello" diff --git a/test-objective/unit/test_packet_function.py b/test-objective/unit/test_packet_function.py new file mode 100644 index 00000000..efc49bd3 --- /dev/null +++ b/test-objective/unit/test_packet_function.py @@ -0,0 +1,231 @@ +"""Specification-derived tests for PythonPacketFunction and CachedPacketFunction. + +Tests based on PacketFunctionProtocol and documented behaviors. +""" + +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.core.datagrams.tag_packet import Packet +from orcapod.core.packet_function import CachedPacketFunction, PythonPacketFunction +from orcapod.databases import InMemoryArrowDatabase +from orcapod.types import Schema + + +# --------------------------------------------------------------------------- +# Helper functions +# --------------------------------------------------------------------------- + + +def double(x: int) -> int: + return x * 2 + + +def add(x: int, y: int) -> int: + return x + y + + +def to_upper(name: str) -> str: + return name.upper() + + +def return_none(x: int) -> int: + return None # type: ignore[return-value] + + +def variadic_func(*args: int) -> int: + return sum(args) + + +def kwargs_func(**kwargs: int) -> int: + return sum(kwargs.values()) + + +def no_annotations(x, y): + return x + y + + +# --------------------------------------------------------------------------- +# PythonPacketFunction construction +# --------------------------------------------------------------------------- + + +class TestPythonPacketFunctionConstruction: + """Per design, PythonPacketFunction wraps a plain Python function.""" + + def test_from_simple_function(self): + pf = PythonPacketFunction(double, output_keys="result") + assert pf.canonical_function_name == "double" + + def test_infers_input_schema_from_signature(self): + pf = PythonPacketFunction(double, output_keys="result") + input_schema = pf.input_packet_schema + assert "x" in input_schema + assert input_schema["x"] is int + + def test_infers_output_schema(self): + pf = PythonPacketFunction(double, output_keys="result") + output_schema = pf.output_packet_schema + assert "result" in output_schema + + def test_multi_input_schema(self): + pf = PythonPacketFunction(add, output_keys="result") + input_schema = pf.input_packet_schema + assert "x" in input_schema + assert "y" in input_schema + + def test_rejects_variadic_args(self): + with pytest.raises((ValueError, TypeError)): + PythonPacketFunction(variadic_func, output_keys="result") + + def test_rejects_variadic_kwargs(self): + with pytest.raises((ValueError, TypeError)): + PythonPacketFunction(kwargs_func, output_keys="result") + + def test_explicit_function_name(self): + pf = PythonPacketFunction( + double, output_keys="result", function_name="my_doubler" + ) + assert pf.canonical_function_name == "my_doubler" + + def test_version_parsing(self): + pf = PythonPacketFunction(double, output_keys="result", version="v1.2") + assert pf.major_version == 1 + assert pf.minor_version_string == "2" + + def test_default_version(self): + pf = PythonPacketFunction(double, output_keys="result") + assert pf.major_version == 0 + + +# --------------------------------------------------------------------------- +# PythonPacketFunction execution +# --------------------------------------------------------------------------- + + +class TestPythonPacketFunctionExecution: + """Per PacketFunctionProtocol, call() applies function to packet data.""" + + def test_call_transforms_packet(self): + pf = PythonPacketFunction(double, output_keys="result") + packet = Packet({"x": 5}) + result = pf.call(packet) + assert result is not None + assert result["result"] == 10 + + def test_call_multi_input(self): + pf = PythonPacketFunction(add, output_keys="result") + packet = Packet({"x": 3, "y": 7}) + result = pf.call(packet) + assert result is not None + assert result["result"] == 10 + + def test_call_returns_none_propagates(self): + pf = PythonPacketFunction(return_none, output_keys="result") + packet = Packet({"x": 5}) + result = pf.call(packet) + # When function returns None, it's wrapped: {"result": None} + assert result["result"] is None + + def test_direct_call_bypasses_executor(self): + pf = PythonPacketFunction(double, output_keys="result") + packet = Packet({"x": 5}) + result = pf.direct_call(packet) + assert result is not None + assert result["result"] == 10 + + +# --------------------------------------------------------------------------- +# PythonPacketFunction hashing +# --------------------------------------------------------------------------- + + +class TestPythonPacketFunctionHashing: + """Per ContentIdentifiableProtocol, hash is deterministic and changes + with function content.""" + + def test_content_hash_deterministic(self): + pf1 = PythonPacketFunction(double, output_keys="result") + pf2 = PythonPacketFunction(double, output_keys="result") + assert pf1.content_hash() == pf2.content_hash() + + def test_content_hash_changes_with_different_function(self): + pf1 = PythonPacketFunction(double, output_keys="result") + pf2 = PythonPacketFunction(to_upper, output_keys="result") + assert pf1.content_hash() != pf2.content_hash() + + def test_pipeline_hash_schema_based(self): + pf = PythonPacketFunction(double, output_keys="result") + ph = pf.pipeline_hash() + assert ph is not None + + +# --------------------------------------------------------------------------- +# CachedPacketFunction +# --------------------------------------------------------------------------- + + +class TestCachedPacketFunction: + """Per design, CachedPacketFunction wraps a PacketFunction and caches + results in an ArrowDatabaseProtocol.""" + + def test_cache_miss_computes_and_stores(self): + db = InMemoryArrowDatabase() + inner_pf = PythonPacketFunction(double, output_keys="result") + cached_pf = CachedPacketFunction(inner_pf, result_database=db) + packet = Packet({"x": 5}) + result = cached_pf.call(packet) + assert result is not None + assert result["result"] == 10 + # After flush, record should be in DB + db.flush() + + def test_cache_hit_returns_stored(self): + db = InMemoryArrowDatabase() + inner_pf = PythonPacketFunction(double, output_keys="result") + cached_pf = CachedPacketFunction(inner_pf, result_database=db) + cached_pf.set_auto_flush(True) + packet = Packet({"x": 5}) + # First call computes + result1 = cached_pf.call(packet) + # Second call should return cached + result2 = cached_pf.call(packet) + assert result1 is not None + assert result2 is not None + assert result1["result"] == result2["result"] + + def test_skip_cache_lookup_always_computes(self): + db = InMemoryArrowDatabase() + inner_pf = PythonPacketFunction(double, output_keys="result") + cached_pf = CachedPacketFunction(inner_pf, result_database=db) + cached_pf.set_auto_flush(True) + packet = Packet({"x": 5}) + cached_pf.call(packet) + # With skip_cache_lookup, should recompute + result = cached_pf.call(packet, skip_cache_lookup=True) + assert result is not None + assert result["result"] == 10 + + def test_skip_cache_insert_doesnt_store(self): + db = InMemoryArrowDatabase() + inner_pf = PythonPacketFunction(double, output_keys="result") + cached_pf = CachedPacketFunction(inner_pf, result_database=db) + packet = Packet({"x": 5}) + cached_pf.call(packet, skip_cache_insert=True) + db.flush() + # Should not be cached + cached_output = cached_pf.get_cached_output_for_packet(packet) + assert cached_output is None + + def test_get_all_cached_outputs(self): + db = InMemoryArrowDatabase() + inner_pf = PythonPacketFunction(double, output_keys="result") + cached_pf = CachedPacketFunction(inner_pf, result_database=db) + cached_pf.set_auto_flush(True) + cached_pf.call(Packet({"x": 1})) + cached_pf.call(Packet({"x": 2})) + all_outputs = cached_pf.get_all_cached_outputs() + assert all_outputs is not None + assert all_outputs.num_rows == 2 diff --git a/test-objective/unit/test_schema_utils.py b/test-objective/unit/test_schema_utils.py new file mode 100644 index 00000000..214f35a5 --- /dev/null +++ b/test-objective/unit/test_schema_utils.py @@ -0,0 +1,268 @@ +"""Tests for schema utility functions. + +Specification-derived tests covering schema extraction from function +signatures, schema verification, compatibility checking, type inference, +union/intersection operations, and type promotion. +""" + +from __future__ import annotations + +import pytest + +from orcapod.types import Schema +from orcapod.utils.schema_utils import ( + check_schema_compatibility, + extract_function_schemas, + get_compatible_type, + infer_schema_from_dict, + intersection_schemas, + union_schemas, + verify_packet_schema, +) + + +# =========================================================================== +# extract_function_schemas +# =========================================================================== + + +class TestExtractFunctionSchemas: + """Infers schemas from type-annotated function signatures.""" + + def test_simple_function(self) -> None: + def add(x: int, y: int) -> int: + return x + y + + input_schema, output_schema = extract_function_schemas(add, ["result"]) + assert dict(input_schema) == {"x": int, "y": int} + assert dict(output_schema) == {"result": int} + + def test_multi_return(self) -> None: + def process(data: str) -> tuple[int, str]: + return len(data), data.upper() + + input_schema, output_schema = extract_function_schemas( + process, ["length", "upper"] + ) + assert dict(input_schema) == {"data": str} + assert dict(output_schema) == {"length": int, "upper": str} + + def test_with_input_typespec_override(self) -> None: + def func(x, y): # noqa: ANN001, ANN201 + return x + y + + input_schema, output_schema = extract_function_schemas( + func, + ["sum"], + input_typespec={"x": int, "y": int}, + output_typespec={"sum": int}, + ) + assert dict(input_schema) == {"x": int, "y": int} + assert dict(output_schema) == {"sum": int} + + def test_output_typespec_as_sequence(self) -> None: + def func(a: int) -> tuple[str, float]: + return str(a), float(a) + + input_schema, output_schema = extract_function_schemas( + func, ["s", "f"], output_typespec=[str, float] + ) + assert dict(output_schema) == {"s": str, "f": float} + + def test_optional_parameters_tracked(self) -> None: + def func(x: int, y: int = 10) -> int: + return x + y + + input_schema, _ = extract_function_schemas(func, ["result"]) + assert "y" in input_schema.optional_fields + assert "x" not in input_schema.optional_fields + + def test_raises_for_unannotated_parameter(self) -> None: + def func(x): # noqa: ANN001, ANN201 + return x + + with pytest.raises(ValueError, match="no type annotation"): + extract_function_schemas(func, ["result"]) + + def test_raises_for_variadic_args(self) -> None: + """Functions with *args raise ValueError because the parameter has no annotation.""" + + def func(*args): # noqa: ANN002, ANN201 + return sum(args) + + with pytest.raises(ValueError, match="no type annotation"): + extract_function_schemas(func, ["result"]) + + def test_raises_for_variadic_kwargs(self) -> None: + """Functions with **kwargs raise ValueError because the parameter has no annotation.""" + + def func(**kwargs): # noqa: ANN003, ANN201 + return kwargs + + with pytest.raises(ValueError, match="no type annotation"): + extract_function_schemas(func, ["result"]) + + +# =========================================================================== +# verify_packet_schema +# =========================================================================== + + +class TestVerifyPacketSchema: + """Returns True when dict matches schema types.""" + + def test_matching_packet(self) -> None: + schema = Schema({"name": str, "age": int}) + packet = {"name": "Alice", "age": 30} + assert verify_packet_schema(packet, schema) is True + + def test_mismatched_type(self) -> None: + schema = Schema({"name": str, "age": int}) + packet = {"name": "Alice", "age": "thirty"} + assert verify_packet_schema(packet, schema) is False + + def test_extra_keys_in_packet(self) -> None: + schema = Schema({"name": str}) + packet = {"name": "Alice", "extra": 42} + assert verify_packet_schema(packet, schema) is False + + +# =========================================================================== +# check_schema_compatibility +# =========================================================================== + + +class TestCheckSchemaCompatibility: + """Compatible types pass.""" + + def test_compatible_schemas(self) -> None: + incoming = Schema({"x": int, "y": str}) + receiving = Schema({"x": int, "y": str}) + assert check_schema_compatibility(incoming, receiving) is True + + def test_incompatible_missing_required_key(self) -> None: + incoming = Schema({"x": int}) + receiving = Schema({"x": int, "y": str}) + assert check_schema_compatibility(incoming, receiving) is False + + def test_optional_key_can_be_missing(self) -> None: + incoming = Schema({"x": int}) + receiving = Schema({"x": int, "y": str}, optional_fields=["y"]) + assert check_schema_compatibility(incoming, receiving) is True + + +# =========================================================================== +# infer_schema_from_dict +# =========================================================================== + + +class TestInferSchemaFromDict: + """Infers types from dict values.""" + + def test_basic_inference(self) -> None: + data = {"name": "Alice", "age": 30, "score": 9.5} + schema = infer_schema_from_dict(data) + assert dict(schema) == {"name": str, "age": int, "score": float} + + def test_none_value_defaults_to_str(self) -> None: + data = {"name": None} + schema = infer_schema_from_dict(data) + assert dict(schema) == {"name": str} + + def test_with_base_schema(self) -> None: + data = {"name": "Alice", "age": 30} + base = {"age": float} + schema = infer_schema_from_dict(data, schema=base) + # "age" should use the base schema type (float), not inferred (int) + assert schema["age"] is float + assert schema["name"] is str + + +# =========================================================================== +# union_schemas +# =========================================================================== + + +class TestUnionSchemas: + """Merges cleanly when no conflicts.""" + + def test_disjoint_merge(self) -> None: + s1 = Schema({"a": int}) + s2 = Schema({"b": str}) + result = union_schemas(s1, s2) + assert dict(result) == {"a": int, "b": str} + + def test_overlapping_same_type(self) -> None: + s1 = Schema({"a": int, "b": str}) + s2 = Schema({"b": str, "c": float}) + result = union_schemas(s1, s2) + assert dict(result) == {"a": int, "b": str, "c": float} + + def test_conflicting_types_raises(self) -> None: + s1 = Schema({"a": int}) + s2 = Schema({"a": str}) + with pytest.raises(TypeError): + union_schemas(s1, s2) + + +# =========================================================================== +# intersection_schemas +# =========================================================================== + + +class TestIntersectionSchemas: + """Returns common fields only.""" + + def test_common_fields_only(self) -> None: + s1 = Schema({"a": int, "b": str, "c": float}) + s2 = Schema({"b": str, "c": float, "d": bool}) + result = intersection_schemas(s1, s2) + assert set(result.keys()) == {"b", "c"} + assert result["b"] is str + assert result["c"] is float + + def test_no_common_fields(self) -> None: + s1 = Schema({"a": int}) + s2 = Schema({"b": str}) + result = intersection_schemas(s1, s2) + assert len(result) == 0 + + def test_conflicting_common_field_raises(self) -> None: + s1 = Schema({"a": int}) + s2 = Schema({"a": str}) + with pytest.raises(TypeError, match="conflict"): + intersection_schemas(s1, s2) + + +# =========================================================================== +# get_compatible_type +# =========================================================================== + + +class TestGetCompatibleType: + """Numeric promotion and incompatibility detection.""" + + def test_identical_types(self) -> None: + assert get_compatible_type(int, int) is int + + def test_numeric_promotion_int_float(self) -> None: + """int is a subclass of float in Python's numeric tower -- should promote.""" + # int is not actually a subclass of float in Python, but bool is a subclass of int. + # get_compatible_type uses issubclass, so int/float may raise. + # Actually: issubclass(int, float) is False in Python. + # The function falls back to raising TypeError for int vs float. + # Let's test bool vs int which is a true subclass relationship. + result = get_compatible_type(bool, int) + assert result is int + + def test_incompatible_types_raises(self) -> None: + with pytest.raises(TypeError, match="not compatible"): + get_compatible_type(int, str) + + def test_none_type_handling(self) -> None: + """NoneType combined with another type returns the other type.""" + result = get_compatible_type(type(None), int) + assert result is int + + result2 = get_compatible_type(str, type(None)) + assert result2 is str diff --git a/test-objective/unit/test_semantic_types.py b/test-objective/unit/test_semantic_types.py new file mode 100644 index 00000000..db098e97 --- /dev/null +++ b/test-objective/unit/test_semantic_types.py @@ -0,0 +1,122 @@ +"""Specification-derived tests for semantic type conversion. + +Tests the UniversalTypeConverter and SemanticTypeRegistry based on +documented behavior in protocols and design specification. +""" + +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.contexts import get_default_type_converter +from orcapod.types import Schema + + +# --------------------------------------------------------------------------- +# UniversalTypeConverter — Python ↔ Arrow type conversion +# --------------------------------------------------------------------------- + + +class TestPythonToArrowType: + """Per the TypeConverterProtocol, python_type_to_arrow_type converts + Python type hints to Arrow types.""" + + @pytest.fixture + def converter(self): + return get_default_type_converter() + + def test_int_to_int64(self, converter): + result = converter.python_type_to_arrow_type(int) + assert result == pa.int64() + + def test_float_to_float64(self, converter): + result = converter.python_type_to_arrow_type(float) + assert result == pa.float64() + + def test_str_to_large_string(self, converter): + result = converter.python_type_to_arrow_type(str) + assert result == pa.large_string() + + def test_bool_to_bool(self, converter): + result = converter.python_type_to_arrow_type(bool) + assert result == pa.bool_() + + def test_bytes_to_binary(self, converter): + result = converter.python_type_to_arrow_type(bytes) + # Could be large_binary or binary + assert pa.types.is_binary(result) or pa.types.is_large_binary(result) + + def test_list_of_int(self, converter): + result = converter.python_type_to_arrow_type(list[int]) + assert pa.types.is_list(result) or pa.types.is_large_list(result) + + +class TestArrowToPythonType: + """Per the TypeConverterProtocol, arrow_type_to_python_type converts + Arrow types back to Python type hints.""" + + @pytest.fixture + def converter(self): + return get_default_type_converter() + + def test_int64_to_int(self, converter): + result = converter.arrow_type_to_python_type(pa.int64()) + assert result is int + + def test_float64_to_float(self, converter): + result = converter.arrow_type_to_python_type(pa.float64()) + assert result is float + + def test_bool_to_bool(self, converter): + result = converter.arrow_type_to_python_type(pa.bool_()) + assert result is bool + + +class TestSchemaConversionRoundtrip: + """Python Schema → Arrow Schema → Python Schema should preserve types.""" + + @pytest.fixture + def converter(self): + return get_default_type_converter() + + def test_simple_schema_roundtrip(self, converter): + python_schema = Schema({"x": int, "y": float, "name": str}) + arrow_schema = converter.python_schema_to_arrow_schema(python_schema) + roundtripped = converter.arrow_schema_to_python_schema(arrow_schema) + assert set(roundtripped.keys()) == set(python_schema.keys()) + for key in python_schema: + assert roundtripped[key] == python_schema[key] + + +class TestPythonDictsToArrowTable: + """Per protocol, python_dicts_to_arrow_table converts list of dicts to pa.Table.""" + + @pytest.fixture + def converter(self): + return get_default_type_converter() + + def test_simple_conversion(self, converter): + data = [{"x": 1, "y": 2.0}, {"x": 3, "y": 4.0}] + schema = Schema({"x": int, "y": float}) + result = converter.python_dicts_to_arrow_table(data, python_schema=schema) + assert isinstance(result, pa.Table) + assert result.num_rows == 2 + assert "x" in result.column_names + assert "y" in result.column_names + + +class TestArrowTableToPythonDicts: + """Per protocol, arrow_table_to_python_dicts converts pa.Table to list of dicts.""" + + @pytest.fixture + def converter(self): + return get_default_type_converter() + + def test_simple_conversion(self, converter): + table = pa.table({"x": [1, 2], "y": [3.0, 4.0]}) + result = converter.arrow_table_to_python_dicts(table) + assert isinstance(result, list) + assert len(result) == 2 + assert result[0]["x"] == 1 + assert result[1]["y"] == 4.0 diff --git a/test-objective/unit/test_source_registry.py b/test-objective/unit/test_source_registry.py new file mode 100644 index 00000000..5481a0e3 --- /dev/null +++ b/test-objective/unit/test_source_registry.py @@ -0,0 +1,261 @@ +"""Specification-derived tests for SourceRegistry. + +Tests documented behaviors of SourceRegistry including registration, +lookup, replacement, unregistration, idempotency, and introspection. +""" + +import pyarrow as pa +import pytest + +from orcapod.core.sources import ArrowTableSource +from orcapod.core.sources.source_registry import SourceRegistry + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_source(tag_val: str = "a", data_val: int = 1) -> ArrowTableSource: + """Create a minimal ArrowTableSource for registry testing.""" + table = pa.table( + { + "tag": pa.array([tag_val], type=pa.large_string()), + "data": pa.array([data_val], type=pa.int64()), + } + ) + return ArrowTableSource(table, tag_columns=["tag"]) + + +# --------------------------------------------------------------------------- +# Registration and Lookup +# --------------------------------------------------------------------------- + + +class TestRegisterAndGet: + """register() + get() roundtrip behaviors.""" + + def test_register_and_get_roundtrip(self): + """A registered source can be retrieved with get().""" + registry = SourceRegistry() + source = _make_source() + registry.register("s1", source) + assert registry.get("s1") is source + + def test_register_multiple_sources(self): + """Multiple sources can be registered under different IDs.""" + registry = SourceRegistry() + s1 = _make_source("a", 1) + s2 = _make_source("b", 2) + registry.register("s1", s1) + registry.register("s2", s2) + assert registry.get("s1") is s1 + assert registry.get("s2") is s2 + + def test_register_empty_id_raises_value_error(self): + """register() with empty string id raises ValueError.""" + registry = SourceRegistry() + source = _make_source() + with pytest.raises(ValueError): + registry.register("", source) + + def test_register_none_source_raises_value_error(self): + """register() with None source raises ValueError.""" + registry = SourceRegistry() + with pytest.raises(ValueError): + registry.register("s1", None) + + def test_register_same_object_idempotent(self): + """Registering the same object under the same id is a no-op.""" + registry = SourceRegistry() + source = _make_source() + registry.register("s1", source) + registry.register("s1", source) # same object, no error + assert registry.get("s1") is source + assert len(registry) == 1 + + def test_register_different_object_same_id_keeps_existing(self): + """Registering a different object under an existing id keeps the original.""" + registry = SourceRegistry() + s1 = _make_source("a", 1) + s2 = _make_source("b", 2) + registry.register("s1", s1) + registry.register("s1", s2) # different object, warns, keeps s1 + assert registry.get("s1") is s1 + assert len(registry) == 1 + + +# --------------------------------------------------------------------------- +# Replace +# --------------------------------------------------------------------------- + + +class TestReplace: + """replace() unconditionally overwrites and returns previous.""" + + def test_replace_overwrites(self): + """replace() overwrites existing entry.""" + registry = SourceRegistry() + s1 = _make_source("a", 1) + s2 = _make_source("b", 2) + registry.register("s1", s1) + registry.replace("s1", s2) + assert registry.get("s1") is s2 + + def test_replace_returns_previous(self): + """replace() returns the previous source object.""" + registry = SourceRegistry() + s1 = _make_source("a", 1) + s2 = _make_source("b", 2) + registry.register("s1", s1) + old = registry.replace("s1", s2) + assert old is s1 + + def test_replace_returns_none_if_no_previous(self): + """replace() returns None when there was no previous entry.""" + registry = SourceRegistry() + source = _make_source() + old = registry.replace("new_id", source) + assert old is None + + def test_replace_empty_id_raises(self): + """replace() with empty id raises ValueError.""" + registry = SourceRegistry() + with pytest.raises(ValueError): + registry.replace("", _make_source()) + + +# --------------------------------------------------------------------------- +# Unregister +# --------------------------------------------------------------------------- + + +class TestUnregister: + """unregister() removes and returns source.""" + + def test_unregister_removes_and_returns(self): + """unregister() removes entry and returns the source.""" + registry = SourceRegistry() + source = _make_source() + registry.register("s1", source) + removed = registry.unregister("s1") + assert removed is source + assert "s1" not in registry + + def test_unregister_missing_raises_key_error(self): + """unregister() on missing id raises KeyError.""" + registry = SourceRegistry() + with pytest.raises(KeyError): + registry.unregister("nonexistent") + + def test_unregister_decrements_length(self): + """unregister() decreases the registry length.""" + registry = SourceRegistry() + source = _make_source() + registry.register("s1", source) + assert len(registry) == 1 + registry.unregister("s1") + assert len(registry) == 0 + + +# --------------------------------------------------------------------------- +# Lookup: get() and get_optional() +# --------------------------------------------------------------------------- + + +class TestLookup: + """get() and get_optional() behaviors.""" + + def test_get_missing_raises_key_error(self): + """get() on missing id raises KeyError.""" + registry = SourceRegistry() + with pytest.raises(KeyError): + registry.get("nonexistent") + + def test_get_optional_missing_returns_none(self): + """get_optional() on missing id returns None.""" + registry = SourceRegistry() + result = registry.get_optional("nonexistent") + assert result is None + + def test_get_optional_existing_returns_source(self): + """get_optional() returns the source when it exists.""" + registry = SourceRegistry() + source = _make_source() + registry.register("s1", source) + result = registry.get_optional("s1") + assert result is source + + +# --------------------------------------------------------------------------- +# Introspection: __contains__, __len__, __iter__, clear(), list_ids() +# --------------------------------------------------------------------------- + + +class TestIntrospection: + """Dunder methods and introspection on SourceRegistry.""" + + def test_contains(self): + """__contains__ returns True for registered ids.""" + registry = SourceRegistry() + source = _make_source() + registry.register("s1", source) + assert "s1" in registry + assert "s2" not in registry + + def test_len_empty(self): + """__len__ returns 0 for empty registry.""" + registry = SourceRegistry() + assert len(registry) == 0 + + def test_len_after_registrations(self): + """__len__ returns correct count.""" + registry = SourceRegistry() + registry.register("s1", _make_source("a", 1)) + registry.register("s2", _make_source("b", 2)) + assert len(registry) == 2 + + def test_iter(self): + """__iter__ yields registered source ids.""" + registry = SourceRegistry() + s1 = _make_source("a", 1) + s2 = _make_source("b", 2) + registry.register("s1", s1) + registry.register("s2", s2) + ids = set(registry) + assert ids == {"s1", "s2"} + + def test_clear_removes_all(self): + """clear() removes all entries.""" + registry = SourceRegistry() + registry.register("s1", _make_source("a", 1)) + registry.register("s2", _make_source("b", 2)) + assert len(registry) == 2 + registry.clear() + assert len(registry) == 0 + assert "s1" not in registry + assert "s2" not in registry + + def test_list_ids_returns_list(self): + """list_ids() returns a list of registered ids.""" + registry = SourceRegistry() + registry.register("s1", _make_source("a", 1)) + registry.register("s2", _make_source("b", 2)) + ids = registry.list_ids() + assert isinstance(ids, list) + assert set(ids) == {"s1", "s2"} + + def test_list_ids_empty(self): + """list_ids() returns empty list for empty registry.""" + registry = SourceRegistry() + assert registry.list_ids() == [] + + def test_clear_then_register(self): + """After clear(), new registrations work normally.""" + registry = SourceRegistry() + s1 = _make_source("a", 1) + registry.register("s1", s1) + registry.clear() + s2 = _make_source("b", 2) + registry.register("s1", s2) + assert registry.get("s1") is s2 diff --git a/test-objective/unit/test_sources.py b/test-objective/unit/test_sources.py new file mode 100644 index 00000000..298eee9b --- /dev/null +++ b/test-objective/unit/test_sources.py @@ -0,0 +1,490 @@ +"""Specification-derived tests for all source types. + +Tests documented behaviors of ArrowTableSource, DictSource, ListSource, +and DerivedSource from orcapod.core.sources. +""" + +from unittest.mock import MagicMock + +import pyarrow as pa +import pytest + +from orcapod.core.datagrams import Packet, Tag +from orcapod.core.sources import ArrowTableSource +from orcapod.core.sources.derived_source import DerivedSource +from orcapod.core.sources.dict_source import DictSource +from orcapod.core.sources.list_source import ListSource +from orcapod.errors import FieldNotResolvableError +from orcapod.types import ColumnConfig, Schema + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _simple_table(n_rows: int = 3) -> pa.Table: + return pa.table( + { + "name": pa.array([f"n{i}" for i in range(n_rows)], type=pa.large_string()), + "age": pa.array([20 + i for i in range(n_rows)], type=pa.int64()), + } + ) + + +# =========================================================================== +# ArrowTableSource +# =========================================================================== + + +class TestArrowTableSourceConstruction: + """ArrowTableSource construction behaviors.""" + + def test_normal_construction(self): + """A valid table with tag columns constructs successfully.""" + source = ArrowTableSource(_simple_table(), tag_columns=["name"]) + assert source is not None + + def test_empty_table_raises(self): + """An empty table raises an error during construction.""" + empty = pa.table( + { + "name": pa.array([], type=pa.large_string()), + "age": pa.array([], type=pa.int64()), + } + ) + with pytest.raises(Exception): + ArrowTableSource(empty, tag_columns=["name"]) + + def test_missing_tag_columns_raises_value_error(self): + """Specifying tag columns not in the table raises ValueError.""" + table = _simple_table() + with pytest.raises(ValueError, match="tag_columns"): + ArrowTableSource(table, tag_columns=["nonexistent"]) + + def test_adds_system_tag_column(self): + """The source auto-adds system tag columns to the underlying table.""" + source = ArrowTableSource(_simple_table(), tag_columns=["name"]) + table = source.as_table(all_info=True) + system_tag_cols = [c for c in table.column_names if c.startswith("_tag_")] + assert len(system_tag_cols) > 0 + + def test_adds_source_info_columns(self): + """The source adds source info columns (prefixed with _source_).""" + source = ArrowTableSource(_simple_table(), tag_columns=["name"]) + table = source.as_table(columns=ColumnConfig(source=True)) + source_cols = [c for c in table.column_names if c.startswith("_source_")] + assert len(source_cols) > 0 + + def test_source_id_populated(self): + """source_id property is populated (defaults to table hash).""" + source = ArrowTableSource(_simple_table(), tag_columns=["name"]) + assert source.source_id is not None + assert len(source.source_id) > 0 + + def test_source_id_explicit(self): + """Explicit source_id is preserved.""" + source = ArrowTableSource( + _simple_table(), + tag_columns=["name"], + source_id="my_source", + ) + assert source.source_id == "my_source" + + def test_producer_is_none(self): + """Root sources have producer == None.""" + source = ArrowTableSource(_simple_table(), tag_columns=["name"]) + assert source.producer is None + + def test_upstreams_is_empty(self): + """Root sources have empty upstreams tuple.""" + source = ArrowTableSource(_simple_table(), tag_columns=["name"]) + assert source.upstreams == () + + def test_no_tag_columns_valid(self): + """Construction with no tag columns is valid (all columns are packets).""" + source = ArrowTableSource(_simple_table(), tag_columns=[]) + tag_keys, packet_keys = source.keys() + assert tag_keys == () + assert "name" in packet_keys + assert "age" in packet_keys + + +class TestArrowTableSourceResolveField: + """ArrowTableSource.resolve_field() behaviors. + + NOTE: resolve_field is currently not implemented on ArrowTableSource + (raises NotImplementedError from RootSource base). These tests are + marked xfail until the implementation is restored. + """ + + NOT_IMPLEMENTED = pytest.mark.xfail( + reason="resolve_field not yet re-implemented after source refactor", + raises=NotImplementedError, + strict=True, + ) + + @NOT_IMPLEMENTED + def test_resolve_field_valid_record_id(self): + """resolve_field works with valid positional record_id.""" + source = ArrowTableSource(_simple_table(3), tag_columns=["name"]) + value = source.resolve_field("row_0", "age") + assert value == 20 + + @NOT_IMPLEMENTED + def test_resolve_field_second_row(self): + """resolve_field returns data from the correct row.""" + source = ArrowTableSource(_simple_table(3), tag_columns=["name"]) + value = source.resolve_field("row_1", "age") + assert value == 21 + + @NOT_IMPLEMENTED + def test_resolve_field_with_record_id_column(self): + """resolve_field works with named record_id column.""" + source = ArrowTableSource( + _simple_table(3), + tag_columns=["name"], + record_id_column="name", + ) + value = source.resolve_field("name=n1", "age") + assert value == 21 + + @NOT_IMPLEMENTED + def test_resolve_field_missing_record_raises(self): + """resolve_field raises FieldNotResolvableError for missing records.""" + source = ArrowTableSource(_simple_table(3), tag_columns=["name"]) + with pytest.raises(FieldNotResolvableError): + source.resolve_field("row_999", "age") + + @NOT_IMPLEMENTED + def test_resolve_field_missing_field_raises(self): + """resolve_field raises FieldNotResolvableError for missing field names.""" + source = ArrowTableSource(_simple_table(3), tag_columns=["name"]) + with pytest.raises(FieldNotResolvableError): + source.resolve_field("row_0", "nonexistent_field") + + @NOT_IMPLEMENTED + def test_resolve_field_invalid_record_id_format(self): + """resolve_field raises FieldNotResolvableError for invalid record_id format.""" + source = ArrowTableSource(_simple_table(3), tag_columns=["name"]) + with pytest.raises(FieldNotResolvableError): + source.resolve_field("invalid_format", "age") + + @NOT_IMPLEMENTED + def test_resolve_field_tag_column(self): + """resolve_field can resolve tag column values too.""" + source = ArrowTableSource(_simple_table(3), tag_columns=["name"]) + value = source.resolve_field("row_0", "name") + assert value == "n0" + + +class TestArrowTableSourceSchema: + """ArrowTableSource schema and identity behaviors.""" + + def test_pipeline_identity_structure_returns_schemas(self): + """pipeline_identity_structure returns (tag_schema, packet_schema).""" + source = ArrowTableSource(_simple_table(), tag_columns=["name"]) + result = source.pipeline_identity_structure() + assert isinstance(result, tuple) + assert len(result) == 2 + tag_schema, packet_schema = result + assert isinstance(tag_schema, Schema) + assert isinstance(packet_schema, Schema) + + def test_output_schema_returns_schemas(self): + source = ArrowTableSource(_simple_table(), tag_columns=["name"]) + tag_schema, packet_schema = source.output_schema() + assert "name" in tag_schema + assert "age" in packet_schema + + def test_output_schema_types(self): + """output_schema types match column data types.""" + source = ArrowTableSource(_simple_table(), tag_columns=["name"]) + tag_schema, packet_schema = source.output_schema() + assert tag_schema["name"] is str + assert packet_schema["age"] is int + + def test_keys_returns_correct_split(self): + """keys() correctly separates tag and packet columns.""" + source = ArrowTableSource(_simple_table(), tag_columns=["name"]) + tag_keys, packet_keys = source.keys() + assert "name" in tag_keys + assert "age" in packet_keys + assert "name" not in packet_keys + + +class TestArrowTableSourceIteration: + """ArrowTableSource iter_packets and as_table behaviors.""" + + def test_iter_packets_yields_tag_packet_pairs(self): + source = ArrowTableSource(_simple_table(3), tag_columns=["name"]) + pairs = list(source.iter_packets()) + assert len(pairs) == 3 + for tag, packet in pairs: + assert isinstance(tag, Tag) + assert isinstance(packet, Packet) + + def test_as_table_has_expected_columns(self): + source = ArrowTableSource(_simple_table(), tag_columns=["name"]) + table = source.as_table() + assert "name" in table.column_names + assert "age" in table.column_names + + def test_as_table_row_count(self): + """as_table row count matches input table row count.""" + source = ArrowTableSource(_simple_table(5), tag_columns=["name"]) + table = source.as_table() + assert table.num_rows == 5 + + def test_as_table_all_info_has_more_columns(self): + """as_table(all_info=True) has more columns than default.""" + source = ArrowTableSource(_simple_table(), tag_columns=["name"]) + table_default = source.as_table() + table_all = source.as_table(all_info=True) + assert table_all.num_columns > table_default.num_columns + + def test_iter_packets_count_matches_as_table_rows(self): + """iter_packets count equals as_table row count.""" + source = ArrowTableSource(_simple_table(4), tag_columns=["name"]) + pairs = list(source.iter_packets()) + table = source.as_table() + assert len(pairs) == table.num_rows + + +# =========================================================================== +# DictSource +# =========================================================================== + + +class TestDictSource: + """DictSource construction and delegation behaviors.""" + + def test_construction_from_list_of_dicts(self): + """DictSource can be constructed from a collection of dicts.""" + data = [{"x": 1, "y": "a"}, {"x": 2, "y": "b"}] + source = DictSource(data=data, tag_columns=["x"]) + assert source is not None + + def test_delegates_to_arrow_table_source(self): + """DictSource produces valid iter_packets output.""" + data = [{"x": 1, "y": "a"}, {"x": 2, "y": "b"}] + source = DictSource(data=data, tag_columns=["x"]) + pairs = list(source.iter_packets()) + assert len(pairs) == 2 + + def test_keys_correct(self): + data = [{"x": 1, "y": "a"}] + source = DictSource(data=data, tag_columns=["x"]) + tag_keys, packet_keys = source.keys() + assert "x" in tag_keys + assert "y" in packet_keys + + def test_source_id_populated(self): + data = [{"x": 1, "y": "a"}] + source = DictSource(data=data, tag_columns=["x"]) + assert source.source_id is not None + assert len(source.source_id) > 0 + + def test_producer_is_none(self): + data = [{"x": 1, "y": "a"}] + source = DictSource(data=data, tag_columns=["x"]) + assert source.producer is None + + def test_upstreams_is_empty(self): + data = [{"x": 1, "y": "a"}] + source = DictSource(data=data, tag_columns=["x"]) + assert source.upstreams == () + + def test_output_schema(self): + """DictSource output_schema delegates correctly.""" + data = [{"x": 1, "y": "a"}] + source = DictSource(data=data, tag_columns=["x"]) + tag_schema, packet_schema = source.output_schema() + assert "x" in tag_schema + assert "y" in packet_schema + + def test_as_table_has_correct_rows(self): + """DictSource as_table returns correct number of rows.""" + data = [{"x": 1, "y": "a"}, {"x": 2, "y": "b"}, {"x": 3, "y": "c"}] + source = DictSource(data=data, tag_columns=["x"]) + table = source.as_table() + assert table.num_rows == 3 + + def test_iter_packets_yields_tag_packet_pairs(self): + """DictSource iter_packets yields proper types.""" + data = [{"x": 1, "y": "a"}] + source = DictSource(data=data, tag_columns=["x"]) + pairs = list(source.iter_packets()) + assert len(pairs) == 1 + tag, packet = pairs[0] + assert isinstance(tag, Tag) + assert isinstance(packet, Packet) + + def test_multiple_packet_columns(self): + """DictSource handles multiple packet columns.""" + data = [{"tag": 1, "a": "x", "b": 10}] + source = DictSource(data=data, tag_columns=["tag"]) + _, packet_keys = source.keys() + assert "a" in packet_keys + assert "b" in packet_keys + + +# =========================================================================== +# ListSource +# =========================================================================== + + +class TestListSource: + """ListSource construction and behaviors.""" + + def test_construction_from_list(self): + """ListSource can be constructed from a list of elements.""" + source = ListSource(name="item", data=["a", "b", "c"]) + assert source is not None + + def test_iter_packets_yields_correct_count(self): + source = ListSource(name="item", data=["a", "b", "c"]) + pairs = list(source.iter_packets()) + assert len(pairs) == 3 + + def test_default_tag_is_element_index(self): + """Default tag function produces element_index tag.""" + source = ListSource(name="item", data=["a", "b"]) + tag_keys, _ = source.keys() + assert "element_index" in tag_keys + + def test_empty_list_raises_value_error(self): + """An empty list raises ValueError (empty table).""" + with pytest.raises(ValueError): + ListSource(name="item", data=[]) + + def test_custom_tag_function(self): + """Custom tag_function is used for tag generation.""" + source = ListSource( + name="item", + data=["a", "b"], + tag_function=lambda el, idx: {"pos": idx * 10}, + expected_tag_keys=["pos"], + ) + tag_keys, _ = source.keys() + assert "pos" in tag_keys + + def test_packet_column_name_matches(self): + """The packet column is named after the 'name' parameter.""" + source = ListSource(name="my_data", data=[1, 2, 3]) + _, packet_keys = source.keys() + assert "my_data" in packet_keys + + def test_source_id_populated(self): + """ListSource has a populated source_id.""" + source = ListSource(name="item", data=["a"]) + assert source.source_id is not None + assert len(source.source_id) > 0 + + def test_as_table_correct_row_count(self): + """ListSource as_table returns correct number of rows.""" + source = ListSource(name="item", data=["a", "b", "c", "d"]) + table = source.as_table() + assert table.num_rows == 4 + + def test_producer_is_none(self): + """ListSource has no producer (root source).""" + source = ListSource(name="item", data=["a"]) + assert source.producer is None + + def test_upstreams_is_empty(self): + """ListSource has empty upstreams.""" + source = ListSource(name="item", data=["a"]) + assert source.upstreams == () + + def test_integer_elements(self): + """ListSource works with integer elements.""" + source = ListSource(name="num", data=[10, 20, 30]) + pairs = list(source.iter_packets()) + assert len(pairs) == 3 + + def test_output_schema(self): + """ListSource output_schema has tag and packet fields.""" + source = ListSource(name="item", data=["a", "b"]) + tag_schema, packet_schema = source.output_schema() + assert "element_index" in tag_schema + assert "item" in packet_schema + + +# =========================================================================== +# DerivedSource +# =========================================================================== + + +class TestDerivedSource: + """DerivedSource behaviors before and after origin run.""" + + def _make_mock_origin(self, records=None): + """Create a mock origin node for DerivedSource testing.""" + mock_origin = MagicMock() + mock_origin.content_hash.return_value = MagicMock( + to_string=MagicMock(return_value="abcdef1234567890") + ) + mock_origin.output_schema.return_value = ( + Schema({"tag_col": str}), + Schema({"data_col": int}), + ) + mock_origin.keys.return_value = (("tag_col",), ("data_col",)) + mock_origin.get_all_records.return_value = records + return mock_origin + + def test_before_run_empty_stream(self): + """Before run(), DerivedSource presents an empty stream (zero rows).""" + mock_origin = self._make_mock_origin(records=None) + source = DerivedSource(origin=mock_origin) + table = source.as_table() + assert table.num_rows == 0 + + def test_before_run_correct_schema(self): + """Before run(), the empty stream has the correct schema columns.""" + mock_origin = self._make_mock_origin(records=None) + source = DerivedSource(origin=mock_origin) + table = source.as_table() + assert "tag_col" in table.column_names + assert "data_col" in table.column_names + + def test_source_id_derived_prefix(self): + """DerivedSource auto-generates a source_id with 'derived:' prefix.""" + mock_origin = self._make_mock_origin(records=None) + source = DerivedSource(origin=mock_origin) + assert source.source_id.startswith("derived:") + + def test_explicit_source_id(self): + """Explicit source_id overrides the auto-generated one.""" + mock_origin = self._make_mock_origin(records=None) + source = DerivedSource(origin=mock_origin, source_id="custom_id") + assert source.source_id == "custom_id" + + def test_output_schema_delegates_to_origin(self): + """output_schema delegates to origin node.""" + mock_origin = self._make_mock_origin(records=None) + source = DerivedSource(origin=mock_origin) + tag_schema, packet_schema = source.output_schema() + assert "tag_col" in tag_schema + assert "data_col" in packet_schema + + def test_keys_delegates_to_origin(self): + """keys() delegates to origin node.""" + mock_origin = self._make_mock_origin(records=None) + source = DerivedSource(origin=mock_origin) + tag_keys, packet_keys = source.keys() + assert "tag_col" in tag_keys + assert "data_col" in packet_keys + + def test_after_run_with_records(self): + """After run(), DerivedSource presents the computed records.""" + records_table = pa.table( + { + "tag_col": pa.array(["a", "b"], type=pa.large_string()), + "data_col": pa.array([1, 2], type=pa.int64()), + } + ) + mock_origin = self._make_mock_origin(records=records_table) + source = DerivedSource(origin=mock_origin) + table = source.as_table() + assert table.num_rows == 2 diff --git a/test-objective/unit/test_stream.py b/test-objective/unit/test_stream.py new file mode 100644 index 00000000..2457ae12 --- /dev/null +++ b/test-objective/unit/test_stream.py @@ -0,0 +1,546 @@ +"""Specification-derived tests for ArrowTableStream. + +Tests documented behaviors of ArrowTableStream construction, immutability, +schema/key introspection, iteration, table output, ColumnConfig filtering, +and format conversions. +""" + +import pyarrow as pa +import pytest + +from orcapod.core.datagrams import Packet, Tag +from orcapod.core.streams import ArrowTableStream +from orcapod.types import ColumnConfig, Schema + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _simple_table(n_rows: int = 3) -> pa.Table: + """A table with one tag-eligible column and one packet column.""" + return pa.table( + { + "id": pa.array(list(range(n_rows)), type=pa.int64()), + "value": pa.array([f"v{i}" for i in range(n_rows)], type=pa.large_string()), + } + ) + + +def _multi_packet_table(n_rows: int = 3) -> pa.Table: + """A table with one tag column and two packet columns.""" + return pa.table( + { + "id": pa.array(list(range(n_rows)), type=pa.int64()), + "x": pa.array([i * 10 for i in range(n_rows)], type=pa.int64()), + "y": pa.array([f"y{i}" for i in range(n_rows)], type=pa.large_string()), + } + ) + + +def _make_stream( + tag_columns: list[str] | None = None, + n_rows: int = 3, + **kwargs, +) -> ArrowTableStream: + tag_columns = tag_columns if tag_columns is not None else ["id"] + return ArrowTableStream(_simple_table(n_rows), tag_columns=tag_columns, **kwargs) + + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + + +class TestConstruction: + """ArrowTableStream construction from a pa.Table.""" + + def test_basic_construction(self): + """Stream can be created from a pa.Table with tag_columns.""" + stream = _make_stream() + assert stream is not None + + def test_construction_with_system_tag_columns(self): + """Stream accepts system_tag_columns parameter.""" + table = pa.table( + { + "id": pa.array([1, 2], type=pa.int64()), + "value": pa.array(["a", "b"], type=pa.large_string()), + "sys": pa.array(["s1", "s2"], type=pa.large_string()), + } + ) + stream = ArrowTableStream( + table, tag_columns=["id"], system_tag_columns=["sys"] + ) + assert stream is not None + + def test_construction_with_source_info(self): + """Stream accepts source_info dict parameter.""" + stream = ArrowTableStream( + _simple_table(), + tag_columns=["id"], + source_info={"value": "test_source::row_0"}, + ) + assert stream is not None + + def test_construction_with_producer_and_upstreams(self): + """Stream accepts producer and upstreams parameters.""" + upstream = _make_stream() + # producer=None is the default; just verify upstreams tuple is stored + stream = ArrowTableStream( + _simple_table(), tag_columns=["id"], upstreams=(upstream,) + ) + assert stream.upstreams == (upstream,) + assert stream.producer is None + + def test_no_packet_columns_raises_value_error(self): + """Stream requires at least one packet column; ValueError if none.""" + table = pa.table({"id": pa.array([1, 2, 3], type=pa.int64())}) + with pytest.raises(ValueError): + ArrowTableStream(table, tag_columns=["id"]) + + def test_no_tag_columns_is_valid(self): + """All columns may be packet columns (no tags).""" + table = pa.table({"value": pa.array(["a", "b"], type=pa.large_string())}) + stream = ArrowTableStream(table, tag_columns=[]) + tag_keys, packet_keys = stream.keys() + assert tag_keys == () + assert "value" in packet_keys + + def test_multiple_tag_columns(self): + """Stream supports multiple tag columns.""" + table = pa.table( + { + "t1": pa.array([1, 2], type=pa.int64()), + "t2": pa.array(["a", "b"], type=pa.large_string()), + "val": pa.array([10.0, 20.0], type=pa.float64()), + } + ) + stream = ArrowTableStream(table, tag_columns=["t1", "t2"]) + tag_keys, packet_keys = stream.keys() + assert set(tag_keys) == {"t1", "t2"} + assert packet_keys == ("val",) + + def test_multiple_packet_columns(self): + """Stream supports multiple packet columns.""" + stream = ArrowTableStream( + _multi_packet_table(), tag_columns=["id"] + ) + _, packet_keys = stream.keys() + assert set(packet_keys) == {"x", "y"} + + +# --------------------------------------------------------------------------- +# keys() +# --------------------------------------------------------------------------- + + +class TestKeys: + """keys() returns (tag_keys, packet_keys) tuples.""" + + def test_keys_returns_tuple_of_tuples(self): + stream = _make_stream() + result = stream.keys() + assert isinstance(result, tuple) + assert len(result) == 2 + tag_keys, packet_keys = result + assert isinstance(tag_keys, tuple) + assert isinstance(packet_keys, tuple) + + def test_keys_correct_split(self): + stream = _make_stream(tag_columns=["id"]) + tag_keys, packet_keys = stream.keys() + assert "id" in tag_keys + assert "value" in packet_keys + assert "id" not in packet_keys + assert "value" not in tag_keys + + def test_keys_with_column_config_system_tags(self): + """When system_tags=True, system tag columns appear in tag_keys.""" + table = pa.table( + { + "id": pa.array([1], type=pa.int64()), + "value": pa.array(["a"], type=pa.large_string()), + "sys_col": pa.array(["s"], type=pa.large_string()), + } + ) + stream = ArrowTableStream( + table, tag_columns=["id"], system_tag_columns=["sys_col"] + ) + tag_keys_default, _ = stream.keys() + tag_keys_all, _ = stream.keys(columns=ColumnConfig(system_tags=True)) + # Default: system tags excluded from keys + assert len(tag_keys_all) > len(tag_keys_default) + + def test_keys_with_all_info(self): + """all_info=True includes system tags in tag_keys.""" + table = pa.table( + { + "id": pa.array([1], type=pa.int64()), + "value": pa.array(["a"], type=pa.large_string()), + "sys_col": pa.array(["s"], type=pa.large_string()), + } + ) + stream = ArrowTableStream( + table, tag_columns=["id"], system_tag_columns=["sys_col"] + ) + tag_keys_all, _ = stream.keys(all_info=True) + assert len(tag_keys_all) > 1 # id + system tag(s) + + def test_keys_no_tag_columns(self): + """With no tag columns, tag_keys is empty.""" + table = pa.table( + {"a": pa.array([1], type=pa.int64()), "b": pa.array([2], type=pa.int64())} + ) + stream = ArrowTableStream(table, tag_columns=[]) + tag_keys, packet_keys = stream.keys() + assert tag_keys == () + assert set(packet_keys) == {"a", "b"} + + +# --------------------------------------------------------------------------- +# output_schema() +# --------------------------------------------------------------------------- + + +class TestOutputSchema: + """output_schema() returns (tag_schema, packet_schema) as Schema objects.""" + + def test_returns_tuple_of_schemas(self): + stream = _make_stream() + tag_schema, packet_schema = stream.output_schema() + assert isinstance(tag_schema, Schema) + assert isinstance(packet_schema, Schema) + + def test_schema_field_names_match_keys(self): + stream = _make_stream(tag_columns=["id"]) + tag_schema, packet_schema = stream.output_schema() + tag_keys, packet_keys = stream.keys() + assert set(tag_schema.keys()) == set(tag_keys) + assert set(packet_schema.keys()) == set(packet_keys) + + def test_schema_types_match_table_column_types(self): + """output_schema types must be consistent with the actual data in as_table.""" + stream = _make_stream(tag_columns=["id"]) + tag_schema, packet_schema = stream.output_schema() + # tag schema type for "id" should be int + assert tag_schema["id"] is int + # packet schema type for "value" should be str + assert packet_schema["value"] is str + + def test_schema_with_multiple_types(self): + """Schema correctly reflects different column types.""" + table = pa.table( + { + "tag": pa.array([1], type=pa.int64()), + "col_int": pa.array([42], type=pa.int64()), + "col_str": pa.array(["hello"], type=pa.large_string()), + "col_float": pa.array([3.14], type=pa.float64()), + } + ) + stream = ArrowTableStream(table, tag_columns=["tag"]) + tag_schema, packet_schema = stream.output_schema() + assert tag_schema["tag"] is int + assert packet_schema["col_int"] is int + assert packet_schema["col_str"] is str + assert packet_schema["col_float"] is float + + def test_schema_with_system_tags_config(self): + """output_schema with system_tags=True includes system tag fields.""" + table = pa.table( + { + "id": pa.array([1], type=pa.int64()), + "value": pa.array(["a"], type=pa.large_string()), + "sys": pa.array(["s"], type=pa.large_string()), + } + ) + stream = ArrowTableStream( + table, tag_columns=["id"], system_tag_columns=["sys"] + ) + tag_schema_default, _ = stream.output_schema() + tag_schema_with_sys, _ = stream.output_schema( + columns=ColumnConfig(system_tags=True) + ) + assert len(tag_schema_with_sys) > len(tag_schema_default) + + +# --------------------------------------------------------------------------- +# iter_packets() +# --------------------------------------------------------------------------- + + +class TestIterPackets: + """iter_packets() yields (Tag, Packet) pairs.""" + + def test_yields_tag_packet_pairs(self): + stream = _make_stream(n_rows=2) + pairs = list(stream.iter_packets()) + assert len(pairs) == 2 + for tag, packet in pairs: + assert isinstance(tag, Tag) + assert isinstance(packet, Packet) + + def test_count_matches_row_count(self): + for n in [1, 5, 10]: + stream = _make_stream(n_rows=n) + pairs = list(stream.iter_packets()) + assert len(pairs) == n + + def test_iter_packets_idempotent(self): + """Iterating twice produces the same number of results (cached).""" + stream = _make_stream(n_rows=3) + first = list(stream.iter_packets()) + second = list(stream.iter_packets()) + assert len(first) == len(second) + + def test_single_row(self): + """iter_packets works with a single-row table.""" + stream = _make_stream(n_rows=1) + pairs = list(stream.iter_packets()) + assert len(pairs) == 1 + tag, packet = pairs[0] + assert isinstance(tag, Tag) + assert isinstance(packet, Packet) + + def test_no_tag_columns_still_yields_packets(self): + """iter_packets works when there are no tag columns.""" + table = pa.table({"value": pa.array(["a", "b"], type=pa.large_string())}) + stream = ArrowTableStream(table, tag_columns=[]) + pairs = list(stream.iter_packets()) + assert len(pairs) == 2 + + +# --------------------------------------------------------------------------- +# as_table() consistency with iter_packets() +# --------------------------------------------------------------------------- + + +class TestAsTable: + """as_table() returns a pa.Table consistent with iter_packets.""" + + def test_as_table_returns_arrow_table(self): + stream = _make_stream() + table = stream.as_table() + assert isinstance(table, pa.Table) + + def test_as_table_row_count_matches_iter_packets(self): + stream = _make_stream(n_rows=4) + table = stream.as_table() + pairs = list(stream.iter_packets()) + assert table.num_rows == len(pairs) + + def test_as_table_contains_tag_and_packet_columns(self): + stream = _make_stream(tag_columns=["id"]) + table = stream.as_table() + assert "id" in table.column_names + assert "value" in table.column_names + + def test_as_table_column_count_matches_keys(self): + """Default as_table columns match keys() tag + packet columns.""" + stream = _make_stream(tag_columns=["id"]) + table = stream.as_table() + tag_keys, packet_keys = stream.keys() + expected_cols = set(tag_keys) | set(packet_keys) + assert set(table.column_names) == expected_cols + + def test_as_table_data_values_consistent(self): + """The data in as_table matches the original input data.""" + table_in = _simple_table(3) + stream = ArrowTableStream(table_in, tag_columns=["id"]) + table_out = stream.as_table() + assert table_out.column("id").to_pylist() == [0, 1, 2] + assert table_out.column("value").to_pylist() == ["v0", "v1", "v2"] + + +# --------------------------------------------------------------------------- +# ColumnConfig filtering +# --------------------------------------------------------------------------- + + +class TestColumnConfigFiltering: + """ColumnConfig controls which columns appear in keys/schema/table.""" + + def test_default_excludes_system_tags(self): + """Default ColumnConfig excludes system tag columns.""" + table = pa.table( + { + "id": pa.array([1], type=pa.int64()), + "val": pa.array(["x"], type=pa.large_string()), + "stag": pa.array(["t"], type=pa.large_string()), + } + ) + stream = ArrowTableStream( + table, tag_columns=["id"], system_tag_columns=["stag"] + ) + tag_keys, _ = stream.keys() + # System tag columns are prefixed with _tag_ internally + assert all(not k.startswith("_tag_") for k in tag_keys) + + def test_all_info_includes_everything(self): + """all_info=True should include source, context, system_tags columns.""" + stream = _make_stream() + table_default = stream.as_table() + table_all = stream.as_table(all_info=True) + assert table_all.num_columns >= table_default.num_columns + + def test_source_column_config(self): + """source=True includes source info columns in as_table.""" + stream = _make_stream() + table_no_source = stream.as_table() + table_with_source = stream.as_table( + columns=ColumnConfig(source=True) + ) + assert table_with_source.num_columns >= table_no_source.num_columns + + def test_context_column_config(self): + """context=True includes context columns in as_table.""" + stream = _make_stream() + table_no_ctx = stream.as_table() + table_with_ctx = stream.as_table(columns=ColumnConfig(context=True)) + assert table_with_ctx.num_columns >= table_no_ctx.num_columns + + def test_system_tags_in_as_table(self): + """system_tags=True includes system tag columns in the output table.""" + table = pa.table( + { + "id": pa.array([1], type=pa.int64()), + "val": pa.array(["x"], type=pa.large_string()), + "stag": pa.array(["t"], type=pa.large_string()), + } + ) + stream = ArrowTableStream( + table, tag_columns=["id"], system_tag_columns=["stag"] + ) + table_default = stream.as_table() + table_with_sys = stream.as_table(columns=ColumnConfig(system_tags=True)) + assert table_with_sys.num_columns > table_default.num_columns + + def test_column_config_as_dict(self): + """ColumnConfig can be passed as a dict.""" + stream = _make_stream() + table = stream.as_table(columns={"source": True}) + assert isinstance(table, pa.Table) + + def test_keys_schema_table_consistency_with_config(self): + """keys(), output_schema(), and as_table() agree under the same ColumnConfig.""" + stream = _make_stream(tag_columns=["id"]) + tag_keys, packet_keys = stream.keys() + tag_schema, packet_schema = stream.output_schema() + table = stream.as_table() + + assert set(tag_schema.keys()) == set(tag_keys) + assert set(packet_schema.keys()) == set(packet_keys) + expected_cols = set(tag_keys) | set(packet_keys) + assert set(table.column_names) == expected_cols + + +# --------------------------------------------------------------------------- +# Format conversions +# --------------------------------------------------------------------------- + + +class TestFormatConversions: + """as_polars_df(), as_pandas_df(), as_lazy_frame() produce expected types.""" + + def test_as_polars_df(self): + import polars as pl + + stream = _make_stream() + df = stream.as_polars_df() + assert isinstance(df, pl.DataFrame) + assert df.shape[0] == 3 + + def test_as_pandas_df(self): + import pandas as pd + + stream = _make_stream() + df = stream.as_pandas_df() + assert isinstance(df, pd.DataFrame) + assert len(df) == 3 + + def test_as_lazy_frame(self): + import polars as pl + + stream = _make_stream() + lf = stream.as_lazy_frame() + assert isinstance(lf, pl.LazyFrame) + + def test_as_polars_df_preserves_columns(self): + """Polars DataFrame has the same columns as as_table.""" + stream = _make_stream(tag_columns=["id"]) + table = stream.as_table() + df = stream.as_polars_df() + assert set(df.columns) == set(table.column_names) + + def test_as_pandas_df_preserves_row_count(self): + """Pandas DataFrame has the same row count.""" + stream = _make_stream(n_rows=5) + df = stream.as_pandas_df() + assert len(df) == 5 + + def test_as_lazy_frame_collects_to_correct_shape(self): + """LazyFrame collects to the correct shape.""" + import polars as pl + + stream = _make_stream(n_rows=4) + lf = stream.as_lazy_frame() + df = lf.collect() + assert isinstance(df, pl.DataFrame) + assert df.shape[0] == 4 + + def test_format_conversions_with_column_config(self): + """Format conversions respect ColumnConfig.""" + import polars as pl + + stream = _make_stream() + df_default = stream.as_polars_df() + df_all = stream.as_polars_df(all_info=True) + assert df_all.shape[1] >= df_default.shape[1] + + +# --------------------------------------------------------------------------- +# Immutability +# --------------------------------------------------------------------------- + + +class TestImmutability: + """ArrowTableStream is immutable -- no public mutation methods.""" + + def test_as_table_returns_consistent_data(self): + """Repeated as_table calls return the same data.""" + stream = _make_stream(n_rows=3) + t1 = stream.as_table() + t2 = stream.as_table() + assert t1.equals(t2) + + def test_producer_is_none_for_standalone_stream(self): + """A stream created without a producer has producer == None.""" + stream = _make_stream() + assert stream.producer is None + + def test_upstreams_is_empty_for_standalone_stream(self): + """A stream created without upstreams has upstreams == ().""" + stream = _make_stream() + assert stream.upstreams == () + + def test_iter_packets_same_on_repeated_calls(self): + """Iterating multiple times yields consistent data.""" + stream = _make_stream(n_rows=3) + first = list(stream.iter_packets()) + second = list(stream.iter_packets()) + assert len(first) == len(second) == 3 + + def test_output_schema_stable(self): + """output_schema() returns the same result on repeated calls.""" + stream = _make_stream() + s1 = stream.output_schema() + s2 = stream.output_schema() + assert s1 == s2 + + def test_keys_stable(self): + """keys() returns the same result on repeated calls.""" + stream = _make_stream() + k1 = stream.keys() + k2 = stream.keys() + assert k1 == k2 diff --git a/test-objective/unit/test_tag.py b/test-objective/unit/test_tag.py new file mode 100644 index 00000000..a7474f58 --- /dev/null +++ b/test-objective/unit/test_tag.py @@ -0,0 +1,157 @@ +"""Specification-derived tests for Tag.""" + +import pyarrow as pa +import pytest + +from orcapod.core.datagrams.datagram import Datagram +from orcapod.core.datagrams.tag_packet import Tag +from orcapod.system_constants import constants +from orcapod.types import ColumnConfig + +# Use the actual system tag prefix from constants +_SYS_TAG_KEY = f"{constants.SYSTEM_TAG_PREFIX}src:abc" + + +def _make_context(): + """Create a DataContext for tests.""" + from orcapod.contexts import resolve_context + return resolve_context(None) + + +# --------------------------------------------------------------------------- +# System tags stored separately from data columns +# --------------------------------------------------------------------------- + +class TestTagSystemTagsSeparation: + """System tags are stored separately from data columns.""" + + def test_system_tags_not_in_keys_by_default(self): + ctx = _make_context() + tag = Tag({"x": 1, "y": "hello"}, data_context=ctx, system_tags={_SYS_TAG_KEY: "val"}) + keys = list(tag.keys()) + assert "x" in keys + assert "y" in keys + assert not any(k.startswith(constants.SYSTEM_TAG_PREFIX) for k in keys) + + def test_system_tags_not_in_as_dict_by_default(self): + ctx = _make_context() + tag = Tag({"x": 1}, data_context=ctx, system_tags={_SYS_TAG_KEY: "val"}) + d = tag.as_dict() + assert not any(k.startswith(constants.SYSTEM_TAG_PREFIX) for k in d) + + def test_system_tags_not_in_as_table_by_default(self): + ctx = _make_context() + tag = Tag({"x": 1}, data_context=ctx, system_tags={_SYS_TAG_KEY: "val"}) + table = tag.as_table() + assert not any(name.startswith(constants.SYSTEM_TAG_PREFIX) for name in table.column_names) + + def test_system_tags_not_in_schema_by_default(self): + ctx = _make_context() + tag = Tag({"x": 1}, data_context=ctx, system_tags={_SYS_TAG_KEY: "val"}) + s = tag.schema() + assert not any(k.startswith(constants.SYSTEM_TAG_PREFIX) for k in s) + + +# --------------------------------------------------------------------------- +# System tags included with ColumnConfig +# --------------------------------------------------------------------------- + +class TestTagSystemTagsWithConfig: + """With ColumnConfig system_tags=True or all_info=True, system tags are included.""" + + def test_keys_with_system_tags_true(self): + ctx = _make_context() + tag = Tag({"x": 1}, data_context=ctx, system_tags={_SYS_TAG_KEY: "val"}) + keys = list(tag.keys(columns=ColumnConfig(system_tags=True))) + assert any(k.startswith(constants.SYSTEM_TAG_PREFIX) for k in keys) + + def test_as_dict_with_system_tags_true(self): + ctx = _make_context() + tag = Tag({"x": 1}, data_context=ctx, system_tags={_SYS_TAG_KEY: "val"}) + d = tag.as_dict(columns=ColumnConfig(system_tags=True)) + assert any(k.startswith(constants.SYSTEM_TAG_PREFIX) for k in d) + + def test_as_table_with_system_tags_true(self): + ctx = _make_context() + tag = Tag({"x": 1}, data_context=ctx, system_tags={_SYS_TAG_KEY: "val"}) + table = tag.as_table(columns=ColumnConfig(system_tags=True)) + assert any(name.startswith(constants.SYSTEM_TAG_PREFIX) for name in table.column_names) + + def test_keys_with_all_info(self): + ctx = _make_context() + tag = Tag({"x": 1}, data_context=ctx, system_tags={_SYS_TAG_KEY: "val"}) + keys = list(tag.keys(columns=ColumnConfig.all())) + assert any(k.startswith(constants.SYSTEM_TAG_PREFIX) for k in keys) + + def test_schema_with_system_tags_true(self): + ctx = _make_context() + tag = Tag({"x": 1}, data_context=ctx, system_tags={_SYS_TAG_KEY: "val"}) + s = tag.schema(columns=ColumnConfig(system_tags=True)) + assert any(k.startswith(constants.SYSTEM_TAG_PREFIX) for k in s) + + +# --------------------------------------------------------------------------- +# system_tags() returns a dict COPY +# --------------------------------------------------------------------------- + +class TestTagSystemTagsCopy: + """system_tags() returns a dict COPY (not a reference).""" + + def test_system_tags_returns_dict(self): + ctx = _make_context() + tag = Tag({"x": 1}, data_context=ctx, system_tags={_SYS_TAG_KEY: "val"}) + st = tag.system_tags() + assert isinstance(st, dict) + assert _SYS_TAG_KEY in st + + def test_system_tags_is_copy(self): + ctx = _make_context() + tag = Tag({"x": 1}, data_context=ctx, system_tags={_SYS_TAG_KEY: "val"}) + st = tag.system_tags() + st[_SYS_TAG_KEY] = "modified" + # Original should be unchanged + assert tag.system_tags()[_SYS_TAG_KEY] == "val" + + +# --------------------------------------------------------------------------- +# copy() preserves system tags +# --------------------------------------------------------------------------- + +class TestTagCopy: + """copy() preserves system tags.""" + + def test_copy_preserves_system_tags(self): + ctx = _make_context() + tag = Tag({"x": 1}, data_context=ctx, system_tags={_SYS_TAG_KEY: "val"}) + copied = tag.copy() + assert copied is not tag + assert copied.system_tags() == tag.system_tags() + + def test_copy_preserves_data(self): + ctx = _make_context() + tag = Tag({"x": 1, "y": "hello"}, data_context=ctx, system_tags={}) + copied = tag.copy() + assert copied["x"] == 1 + assert copied["y"] == "hello" + + +# --------------------------------------------------------------------------- +# as_datagram() returns Datagram, not Tag +# --------------------------------------------------------------------------- + +class TestTagAsDatagram: + """as_datagram() returns a Datagram (not Tag).""" + + def test_as_datagram_returns_datagram_type(self): + ctx = _make_context() + tag = Tag({"x": 1}, data_context=ctx, system_tags={}) + dg = tag.as_datagram() + assert isinstance(dg, Datagram) + assert not isinstance(dg, Tag) + + def test_as_datagram_preserves_data(self): + ctx = _make_context() + tag = Tag({"x": 1, "y": "hello"}, data_context=ctx, system_tags={}) + dg = tag.as_datagram() + assert dg["x"] == 1 + assert dg["y"] == "hello" diff --git a/test-objective/unit/test_tracker.py b/test-objective/unit/test_tracker.py new file mode 100644 index 00000000..2853d124 --- /dev/null +++ b/test-objective/unit/test_tracker.py @@ -0,0 +1,120 @@ +"""Specification-derived tests for tracker and pipeline. + +Tests based on TrackerProtocol, TrackerManagerProtocol, and +Pipeline documented behavior. +""" + +from __future__ import annotations + +import pyarrow as pa +import pytest + +from orcapod.core.function_pod import FunctionPod +from orcapod.core.operators import Join +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.core.streams import ArrowTableStream +from orcapod.core.tracker import BasicTrackerManager +from orcapod.databases import InMemoryArrowDatabase +from orcapod.pipeline import Pipeline + + +def _double(x: int) -> int: + return x * 2 + + +def _make_pipeline( + tracker_manager: BasicTrackerManager | None = None, +) -> Pipeline: + return Pipeline( + name="test", + pipeline_database=InMemoryArrowDatabase(), + tracker_manager=tracker_manager, + auto_compile=False, + ) + + +def _make_stream(n: int = 3) -> ArrowTableStream: + table = pa.table( + { + "id": pa.array(list(range(n)), type=pa.int64()), + "x": pa.array([i * 10 for i in range(n)], type=pa.int64()), + } + ) + return ArrowTableStream(table, tag_columns=["id"]) + + +class TestBasicTrackerManager: + """Per TrackerManagerProtocol: manages tracker registration, broadcasting, + and no_tracking context.""" + + def test_register_and_get_active_trackers(self): + mgr = BasicTrackerManager() + tracker = _make_pipeline(tracker_manager=mgr) + tracker.set_active(True) + active = mgr.get_active_trackers() + assert tracker in active + + def test_deregister_removes_tracker(self): + mgr = BasicTrackerManager() + tracker = _make_pipeline(tracker_manager=mgr) + mgr.deregister_tracker(tracker) + assert tracker not in mgr.get_active_trackers() + + def test_no_tracking_context_suspends_recording(self): + mgr = BasicTrackerManager() + tracker = _make_pipeline(tracker_manager=mgr) + tracker.set_active(True) + with mgr.no_tracking(): + # Invocations inside this block should not be recorded + active = mgr.get_active_trackers() + assert len(active) == 0 + # After exiting, tracker should be active again + active = mgr.get_active_trackers() + assert tracker in active + + +class TestPipelineTracker: + """Per design, Pipeline records pipeline structure as a directed graph.""" + + def test_records_function_pod_invocation(self): + mgr = BasicTrackerManager() + tracker = _make_pipeline(tracker_manager=mgr) + tracker.set_active(True) + + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + stream = _make_stream() + + # Explicitly record the invocation + tracker.record_function_pod_invocation(pod, stream) + + # The tracker should have recorded at least one node + assert len(tracker.nodes) >= 1 + + def test_reset_clears_state(self): + mgr = BasicTrackerManager() + tracker = _make_pipeline(tracker_manager=mgr) + tracker.set_active(True) + + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + stream = _make_stream() + tracker.record_function_pod_invocation(pod, stream) + + tracker.reset() + assert len(tracker.nodes) == 0 + + def test_compile_builds_graph(self): + mgr = BasicTrackerManager() + tracker = _make_pipeline(tracker_manager=mgr) + tracker.set_active(True) + + pf = PythonPacketFunction(_double, output_keys="result") + pod = FunctionPod(packet_function=pf) + stream = _make_stream() + tracker.record_function_pod_invocation(pod, stream) + + tracker.compile() + graph = tracker.graph + assert graph is not None + assert graph.number_of_nodes() >= 1 diff --git a/test-objective/unit/test_types.py b/test-objective/unit/test_types.py new file mode 100644 index 00000000..3b4bf88c --- /dev/null +++ b/test-objective/unit/test_types.py @@ -0,0 +1,275 @@ +"""Specification-derived tests for Schema, ColumnConfig, and ContentHash.""" + +import uuid + +import pytest + +from orcapod.types import ColumnConfig, ContentHash, Schema + + +# --------------------------------------------------------------------------- +# Schema basics +# --------------------------------------------------------------------------- + +class TestSchemaImmutableMapping: + """Schema behaves as an immutable Mapping[str, DataType].""" + + def test_schema_acts_as_mapping(self): + s = Schema({"x": int, "y": str}) + assert "x" in s + assert s["x"] == int + assert len(s) == 2 + assert set(s) == {"x", "y"} + + def test_schema_is_immutable(self): + s = Schema({"x": int}) + with pytest.raises(TypeError): + s["x"] = float + + def test_schema_equality(self): + a = Schema({"x": int, "y": str}) + b = Schema({"x": int, "y": str}) + assert a == b + + def test_schema_inequality_different_types(self): + a = Schema({"x": int}) + b = Schema({"x": float}) + assert a != b + + +class TestSchemaOptionalFields: + """Schema supports optional_fields.""" + + def test_optional_fields_default_empty(self): + s = Schema({"x": int}) + assert s.optional_fields == frozenset() + + def test_optional_fields_set_at_construction(self): + s = Schema({"x": int, "y": str}, optional_fields={"y"}) + assert "y" in s.optional_fields + assert "x" not in s.optional_fields + + def test_optional_fields_can_include_unknown_fields(self): + # Schema doesn't validate optional_fields against actual fields + s = Schema({"x": int}, optional_fields={"z"}) + assert "z" in s.optional_fields + + +class TestSchemaEmpty: + """Schema.empty() returns a zero-field schema.""" + + def test_empty_schema_has_no_fields(self): + s = Schema.empty() + assert len(s) == 0 + assert list(s) == [] + + +class TestSchemaMerge: + """Schema.merge() raises ValueError on type conflicts.""" + + def test_merge_disjoint_schemas(self): + a = Schema({"x": int}) + b = Schema({"y": str}) + merged = a.merge(b) + assert "x" in merged + assert "y" in merged + + def test_merge_overlapping_same_type(self): + a = Schema({"x": int, "y": str}) + b = Schema({"x": int, "z": float}) + merged = a.merge(b) + assert merged["x"] == int + assert "z" in merged + + def test_merge_raises_on_type_conflict(self): + a = Schema({"x": int}) + b = Schema({"x": str}) + with pytest.raises(ValueError): + a.merge(b) + + +class TestSchemaSelect: + """Schema.select() raises KeyError on missing fields.""" + + def test_select_existing_fields(self): + s = Schema({"x": int, "y": str, "z": float}) + selected = s.select("x", "z") + assert set(selected) == {"x", "z"} + + def test_select_raises_on_missing_field(self): + s = Schema({"x": int}) + with pytest.raises(KeyError): + s.select("x", "missing") + + +class TestSchemaDrop: + """Schema.drop() silently ignores missing fields.""" + + def test_drop_existing_fields(self): + s = Schema({"x": int, "y": str, "z": float}) + dropped = s.drop("y") + assert set(dropped) == {"x", "z"} + + def test_drop_missing_field_silently_ignored(self): + s = Schema({"x": int, "y": str}) + dropped = s.drop("nonexistent") + assert set(dropped) == {"x", "y"} + + def test_drop_mix_of_existing_and_missing(self): + s = Schema({"x": int, "y": str}) + dropped = s.drop("x", "nonexistent") + assert set(dropped) == {"y"} + + +class TestSchemaCompatibility: + """Schema.is_compatible_with() returns True when other is superset.""" + + def test_compatible_when_other_is_superset(self): + small = Schema({"x": int}) + big = Schema({"x": int, "y": str}) + assert small.is_compatible_with(big) + + def test_compatible_with_itself(self): + s = Schema({"x": int}) + assert s.is_compatible_with(s) + + def test_not_compatible_when_field_missing(self): + a = Schema({"x": int, "y": str}) + b = Schema({"x": int}) + assert not a.is_compatible_with(b) + + def test_not_compatible_when_type_differs(self): + a = Schema({"x": int}) + b = Schema({"x": str}) + assert not a.is_compatible_with(b) + + +class TestSchemaWithValues: + """Schema.with_values() overrides silently (no errors).""" + + def test_with_values_adds_new_field(self): + s = Schema({"x": int}) + updated = s.with_values({"y": str}) + assert "y" in updated + assert "x" in updated + + def test_with_values_overrides_existing_type(self): + s = Schema({"x": int}) + updated = s.with_values({"x": float}) + assert updated["x"] == float + + def test_with_values_does_not_mutate_original(self): + s = Schema({"x": int}) + s.with_values({"x": float}) + assert s["x"] == int + + +# --------------------------------------------------------------------------- +# ContentHash +# --------------------------------------------------------------------------- + +class TestContentHash: + """ContentHash is a frozen dataclass with method+digest.""" + + def test_content_hash_is_frozen(self): + h = ContentHash(method="sha256", digest=b"\x00" * 32) + with pytest.raises(AttributeError): + h.method = "md5" + + def test_content_hash_has_method_and_digest(self): + h = ContentHash(method="sha256", digest=b"\xab\xcd") + assert h.method == "sha256" + assert h.digest == b"\xab\xcd" + + +class TestContentHashConversions: + """ContentHash conversions: to_hex, to_int, to_uuid, to_base64, to_string.""" + + def _make_hash(self): + return ContentHash(method="sha256", digest=b"\x01\x02\x03\x04" * 4) + + def test_to_hex_returns_string(self): + h = self._make_hash() + hex_str = h.to_hex() + assert isinstance(hex_str, str) + assert all(c in "0123456789abcdef" for c in hex_str) + + def test_to_int_returns_integer(self): + h = self._make_hash() + assert isinstance(h.to_int(), int) + + def test_to_uuid_returns_uuid(self): + h = self._make_hash() + result = h.to_uuid() + assert isinstance(result, uuid.UUID) + + def test_to_base64_returns_string(self): + h = self._make_hash() + b64 = h.to_base64() + assert isinstance(b64, str) + + def test_to_string_returns_string(self): + h = self._make_hash() + s = h.to_string() + assert isinstance(s, str) + + def test_from_string_roundtrip(self): + h = self._make_hash() + s = h.to_string() + restored = ContentHash.from_string(s) + assert restored.method == h.method + assert restored.digest == h.digest + + +# --------------------------------------------------------------------------- +# ColumnConfig +# --------------------------------------------------------------------------- + +class TestColumnConfig: + """ColumnConfig is frozen, has .all() and .data_only() convenience methods.""" + + def test_column_config_is_frozen(self): + cc = ColumnConfig() + with pytest.raises(AttributeError): + cc.meta = True + + def test_all_sets_everything_true(self): + cc = ColumnConfig.all() + assert cc.meta is True + assert cc.source is True + assert cc.system_tags is True + assert cc.context is True + + def test_data_only_excludes_extras(self): + cc = ColumnConfig.data_only() + assert cc.meta is False + assert cc.source is False + assert cc.system_tags is False + + def test_default_construction(self): + cc = ColumnConfig() + assert isinstance(cc, ColumnConfig) + + +class TestColumnConfigHandleConfig: + """ColumnConfig.handle_config() normalizes dict/None/instance inputs.""" + + def test_handle_config_none_returns_default(self): + result = ColumnConfig.handle_config(None) + assert isinstance(result, ColumnConfig) + + def test_handle_config_instance_passes_through(self): + cc = ColumnConfig.all() + result = ColumnConfig.handle_config(cc) + assert result is cc + + def test_handle_config_dict_creates_config(self): + result = ColumnConfig.handle_config({"meta": True}) + assert isinstance(result, ColumnConfig) + assert result.meta is True + + def test_handle_config_all_info_flag(self): + result = ColumnConfig.handle_config(None, all_info=True) + assert result.meta is True + assert result.source is True + assert result.system_tags is True diff --git a/tests/test_channels/__init__.py b/tests/test_channels/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_channels/test_async_execute.py b/tests/test_channels/test_async_execute.py new file mode 100644 index 00000000..c1c73036 --- /dev/null +++ b/tests/test_channels/test_async_execute.py @@ -0,0 +1,838 @@ +""" +Comprehensive tests for async_execute on operators and FunctionPod. + +Covers: +- StaticOutputPod._materialize_to_stream round-trip +- UnaryOperator barrier-mode async_execute (Select, Drop, Map, Filter, Batch) +- BinaryOperator barrier-mode async_execute (MergeJoin, SemiJoin) +- NonZeroInputOperator barrier-mode async_execute (Join) +- FunctionPod streaming async_execute +- FunctionPod concurrency control (max_concurrency) +- PythonPacketFunction.direct_async_call via run_in_executor +- End-to-end multi-stage async pipeline wiring +- Error propagation through channels +- NodeConfig / PipelineConfig integration with FunctionPod +""" + +from __future__ import annotations + +import asyncio +from typing import cast + +import pyarrow as pa +import pytest + +from orcapod.channels import Channel +from orcapod.core.datagrams import Packet +from orcapod.core.function_pod import FunctionPod +from orcapod.core.operators import ( + Batch, + DropPacketColumns, + DropTagColumns, + Join, + MapPackets, + MapTags, + MergeJoin, + PolarsFilter, + SelectPacketColumns, + SelectTagColumns, + SemiJoin, +) +from orcapod.core.operators.static_output_pod import StaticOutputOperatorPod +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.core.streams.arrow_table_stream import ArrowTableStream +from orcapod.types import NodeConfig, PipelineConfig + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_stream(n: int = 3) -> ArrowTableStream: + """Stream with tag=id, packet=x (ints).""" + table = pa.table( + { + "id": pa.array(list(range(n)), type=pa.int64()), + "x": pa.array(list(range(n)), type=pa.int64()), + } + ) + return ArrowTableStream(table, tag_columns=["id"]) + + +def make_two_col_stream(n: int = 3) -> ArrowTableStream: + """Stream with tag=id, packet={x, y}.""" + table = pa.table( + { + "id": pa.array(list(range(n)), type=pa.int64()), + "x": pa.array(list(range(n)), type=pa.int64()), + "y": pa.array([i * 10 for i in range(n)], type=pa.int64()), + } + ) + return ArrowTableStream(table, tag_columns=["id"]) + + +def make_name_stream() -> ArrowTableStream: + """Stream with tag=id, packet=name (str).""" + table = pa.table( + { + "id": pa.array([0, 1, 2], type=pa.int64()), + "name": pa.array(["alice", "bob", "carol"], type=pa.large_string()), + } + ) + return ArrowTableStream(table, tag_columns=["id"]) + + +async def feed_stream_to_channel(stream: ArrowTableStream, ch: Channel) -> None: + """Push all (tag, packet) pairs from a stream into a channel, then close.""" + for tag, packet in stream.iter_packets(): + await ch.writer.send((tag, packet)) + await ch.writer.close() + + +async def collect_output(ch: Channel) -> list[tuple]: + """Collect all (tag, packet) pairs from a channel's reader.""" + return await ch.reader.collect() + + +# --------------------------------------------------------------------------- +# 1. _materialize_to_stream round-trip +# --------------------------------------------------------------------------- + + +class TestMaterializeToStream: + def test_round_trip_preserves_data(self): + stream = make_stream(5) + rows = list(stream.iter_packets()) + rebuilt = StaticOutputOperatorPod._materialize_to_stream(rows) + + original_table = stream.as_table() + rebuilt_table = rebuilt.as_table() + + assert ( + original_table.column("id").to_pylist() + == rebuilt_table.column("id").to_pylist() + ) + assert ( + original_table.column("x").to_pylist() + == rebuilt_table.column("x").to_pylist() + ) + + def test_round_trip_preserves_schema(self): + stream = make_stream(3) + rows = list(stream.iter_packets()) + rebuilt = StaticOutputOperatorPod._materialize_to_stream(rows) + + orig_tag, orig_pkt = stream.output_schema() + rebuilt_tag, rebuilt_pkt = rebuilt.output_schema() + assert dict(orig_tag) == dict(rebuilt_tag) + assert dict(orig_pkt) == dict(rebuilt_pkt) + + def test_empty_rows_raises(self): + with pytest.raises(ValueError, match="empty"): + StaticOutputOperatorPod._materialize_to_stream([]) + + def test_round_trip_two_col_stream(self): + stream = make_two_col_stream(4) + rows = list(stream.iter_packets()) + rebuilt = StaticOutputOperatorPod._materialize_to_stream(rows) + + original = stream.as_table() + rebuilt_t = rebuilt.as_table() + assert original.column("x").to_pylist() == rebuilt_t.column("x").to_pylist() + assert original.column("y").to_pylist() == rebuilt_t.column("y").to_pylist() + + +# --------------------------------------------------------------------------- +# 3. PythonPacketFunction.direct_async_call +# --------------------------------------------------------------------------- + + +class TestDirectAsyncCall: + @pytest.mark.asyncio + async def test_direct_async_call_returns_correct_result(self): + def add(x: int, y: int) -> int: + return x + y + + pf = PythonPacketFunction(add, output_keys="result") + packet = Packet({"x": 3, "y": 5}) + result, captured = await pf.direct_async_call(packet) + assert captured.success is True + assert result is not None + assert result.as_dict()["result"] == 8 + + @pytest.mark.asyncio + async def test_async_call_multiple_packets(self): + def double(x: int) -> int: + return x * 2 + + pf = PythonPacketFunction(double, output_keys="result") + raw_results = await asyncio.gather( + pf.async_call(Packet({"x": 1})), + pf.async_call(Packet({"x": 2})), + pf.async_call(Packet({"x": 3})), + ) + results = [r for r, _captured in raw_results] + assert all(r is not None for r in results) + values = [r.as_dict()["result"] for r in results if r is not None] + assert values == [2, 4, 6] + + @pytest.mark.asyncio + async def test_async_call_runs_in_thread(self): + """Verify the function actually runs (proves run_in_executor works).""" + import threading + + call_threads = [] + + def record_thread(x: int) -> int: + call_threads.append(threading.current_thread().name) + return x + + pf = PythonPacketFunction(record_thread, output_keys="result") + result, captured = await pf.direct_async_call(Packet({"x": 42})) + assert captured.success is True + assert result is not None + assert len(call_threads) == 1 + + +# --------------------------------------------------------------------------- +# 4. UnaryOperator barrier-mode async_execute +# --------------------------------------------------------------------------- + + +class TestUnaryOperatorAsyncExecute: + @pytest.mark.asyncio + async def test_select_tag_columns(self): + stream = make_two_col_stream(3) + op = SelectTagColumns(["id"]) + + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(stream, input_ch) + await op.async_execute([input_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + assert len(results) == 3 + for tag, packet in results: + assert "id" in tag.keys() + + @pytest.mark.asyncio + async def test_select_packet_columns(self): + stream = make_two_col_stream(3) + op = SelectPacketColumns(["x"]) + + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(stream, input_ch) + await op.async_execute([input_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + assert len(results) == 3 + for _, packet in results: + pkt_dict = packet.as_dict() + assert "x" in pkt_dict + assert "y" not in pkt_dict + + @pytest.mark.asyncio + async def test_drop_packet_columns(self): + stream = make_two_col_stream(3) + op = DropPacketColumns(["y"]) + + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(stream, input_ch) + await op.async_execute([input_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + assert len(results) == 3 + for _, packet in results: + pkt_dict = packet.as_dict() + assert "x" in pkt_dict + assert "y" not in pkt_dict + + @pytest.mark.asyncio + async def test_drop_tag_columns(self): + # Need multi-tag stream + table = pa.table( + { + "a": pa.array([1, 2], type=pa.int64()), + "b": pa.array([10, 20], type=pa.int64()), + "x": pa.array([100, 200], type=pa.int64()), + } + ) + stream = ArrowTableStream(table, tag_columns=["a", "b"]) + op = DropTagColumns(["b"]) + + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(stream, input_ch) + await op.async_execute([input_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + assert len(results) == 2 + for tag, _ in results: + tag_keys = tag.keys() + assert "a" in tag_keys + assert "b" not in tag_keys + + @pytest.mark.asyncio + async def test_map_tags(self): + stream = make_stream(3) + op = MapTags({"id": "row_id"}, drop_unmapped=True) + + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(stream, input_ch) + await op.async_execute([input_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + assert len(results) == 3 + for tag, _ in results: + assert "row_id" in tag.keys() + assert "id" not in tag.keys() + + @pytest.mark.asyncio + async def test_map_packets(self): + stream = make_stream(3) + op = MapPackets({"x": "value"}, drop_unmapped=True) + + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(stream, input_ch) + await op.async_execute([input_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + assert len(results) == 3 + for _, packet in results: + pkt_dict = packet.as_dict() + assert "value" in pkt_dict + assert "x" not in pkt_dict + + @pytest.mark.asyncio + async def test_polars_filter(self): + stream = make_stream(5) + op = PolarsFilter(constraints={"id": 2}) + + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(stream, input_ch) + await op.async_execute([input_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + assert len(results) == 1 + tag, packet = results[0] + assert tag.as_dict()["id"] == 2 + assert packet.as_dict()["x"] == 2 + + @pytest.mark.asyncio + async def test_batch_operator(self): + stream = make_stream(6) + op = Batch(batch_size=2) + + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(stream, input_ch) + await op.async_execute([input_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + assert len(results) == 3 # 6 rows / batch_size=2 + + +# --------------------------------------------------------------------------- +# 5. BinaryOperator barrier-mode async_execute +# --------------------------------------------------------------------------- + + +class TestBinaryOperatorAsyncExecute: + @pytest.mark.asyncio + async def test_semi_join(self): + left = make_stream(5) + right_table = pa.table( + { + "id": pa.array([1, 3], type=pa.int64()), + "z": pa.array([100, 300], type=pa.int64()), + } + ) + right = ArrowTableStream(right_table, tag_columns=["id"]) + + op = SemiJoin() + + left_ch = Channel(buffer_size=16) + right_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(left, left_ch) + await feed_stream_to_channel(right, right_ch) + + await op.async_execute([left_ch.reader, right_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + ids = sorted(tag.as_dict()["id"] for tag, _ in results) + assert ids == [1, 3] + + @pytest.mark.asyncio + async def test_merge_join(self): + left_table = pa.table( + { + "id": pa.array([0, 1], type=pa.int64()), + "val": pa.array([10, 20], type=pa.int64()), + } + ) + right_table = pa.table( + { + "id": pa.array([0, 1], type=pa.int64()), + "val": pa.array([100, 200], type=pa.int64()), + } + ) + left = ArrowTableStream(left_table, tag_columns=["id"]) + right = ArrowTableStream(right_table, tag_columns=["id"]) + + op = MergeJoin() + + left_ch = Channel(buffer_size=16) + right_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(left, left_ch) + await feed_stream_to_channel(right, right_ch) + + await op.async_execute([left_ch.reader, right_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + assert len(results) == 2 + + +# --------------------------------------------------------------------------- +# 6. NonZeroInputOperator barrier-mode async_execute (Join) +# --------------------------------------------------------------------------- + + +class TestJoinAsyncExecute: + @pytest.mark.asyncio + async def test_two_way_join(self): + left_table = pa.table( + { + "id": pa.array([0, 1, 2], type=pa.int64()), + "x": pa.array([10, 20, 30], type=pa.int64()), + } + ) + right_table = pa.table( + { + "id": pa.array([0, 1, 2], type=pa.int64()), + "y": pa.array([100, 200, 300], type=pa.int64()), + } + ) + left = ArrowTableStream(left_table, tag_columns=["id"]) + right = ArrowTableStream(right_table, tag_columns=["id"]) + + op = Join() + + left_ch = Channel(buffer_size=16) + right_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(left, left_ch) + await feed_stream_to_channel(right, right_ch) + + await op.async_execute([left_ch.reader, right_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + assert len(results) == 3 + + # Verify all tag values present + ids = sorted(tag.as_dict()["id"] for tag, _ in results) + assert ids == [0, 1, 2] + + # Verify both packet columns present + for _, packet in results: + pkt = packet.as_dict() + assert "x" in pkt + assert "y" in pkt + + +# --------------------------------------------------------------------------- +# 7. FunctionPod streaming async_execute +# --------------------------------------------------------------------------- + + +class TestFunctionPodAsyncExecute: + @pytest.mark.asyncio + async def test_basic_streaming(self): + def double(x: int) -> int: + return x * 2 + + pf = PythonPacketFunction(double, output_keys="result") + pod = FunctionPod(pf) + + stream = make_stream(5) + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(stream, input_ch) + await pod.async_execute([input_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + assert len(results) == 5 + + values = sorted(pkt.as_dict()["result"] for _, pkt in results) + assert values == [0, 2, 4, 6, 8] + + @pytest.mark.asyncio + async def test_two_input_keys(self): + def add(x: int, y: int) -> int: + return x + y + + pf = PythonPacketFunction(add, output_keys="result") + pod = FunctionPod(pf) + + stream = make_two_col_stream(3) + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(stream, input_ch) + await pod.async_execute([input_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + assert len(results) == 3 + values = sorted(pkt.as_dict()["result"] for _, pkt in results) + assert values == [0, 11, 22] + + @pytest.mark.asyncio + async def test_tags_pass_through(self): + """FunctionPod should preserve the input tag for each output.""" + + def noop(x: int) -> int: + return x + + pf = PythonPacketFunction(noop, output_keys="result") + pod = FunctionPod(pf) + + stream = make_stream(3) + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(stream, input_ch) + await pod.async_execute([input_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + ids = sorted(tag.as_dict()["id"] for tag, _ in results) + assert ids == [0, 1, 2] + + @pytest.mark.asyncio + async def test_empty_input(self): + """No items in → no items out, channel closed cleanly.""" + + def double(x: int) -> int: + return x * 2 + + pf = PythonPacketFunction(double, output_keys="result") + pod = FunctionPod(pf) + + input_ch = Channel(buffer_size=4) + output_ch = Channel(buffer_size=4) + + await input_ch.writer.close() + await pod.async_execute([input_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + assert results == [] + + +# --------------------------------------------------------------------------- +# 8. FunctionPod concurrency control +# --------------------------------------------------------------------------- + + +class TestFunctionPodConcurrency: + @pytest.mark.asyncio + async def test_max_concurrency_limits_in_flight(self): + """With max_concurrency=1, packets should be processed sequentially.""" + processing_order = [] + + def record_order(x: int) -> int: + processing_order.append(x) + return x + + pf = PythonPacketFunction(record_order, output_keys="result") + pod = FunctionPod(pf, node_config=NodeConfig(max_concurrency=1)) + + stream = make_stream(5) + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(stream, input_ch) + await pod.async_execute([input_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + assert len(results) == 5 + + @pytest.mark.asyncio + async def test_unlimited_concurrency(self): + """With max_concurrency=None, all packets run concurrently.""" + + def double(x: int) -> int: + return x * 2 + + pf = PythonPacketFunction(double, output_keys="result") + pod = FunctionPod(pf, node_config=NodeConfig(max_concurrency=None)) + pipeline_cfg = PipelineConfig(default_max_concurrency=None) + + stream = make_stream(10) + input_ch = Channel(buffer_size=32) + output_ch = Channel(buffer_size=32) + + await feed_stream_to_channel(stream, input_ch) + await pod.async_execute( + [input_ch.reader], output_ch.writer, pipeline_config=pipeline_cfg + ) + + results = await output_ch.reader.collect() + assert len(results) == 10 + values = sorted(pkt.as_dict()["result"] for _, pkt in results) + assert values == [i * 2 for i in range(10)] + + @pytest.mark.asyncio + async def test_pipeline_config_concurrency_fallback(self): + """NodeConfig inherits from PipelineConfig when not overridden.""" + + def double(x: int) -> int: + return x * 2 + + pf = PythonPacketFunction(double, output_keys="result") + pod = FunctionPod(pf) # NodeConfig default (None) + pipeline_cfg = PipelineConfig(default_max_concurrency=2) + + stream = make_stream(4) + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(stream, input_ch) + await pod.async_execute( + [input_ch.reader], output_ch.writer, pipeline_config=pipeline_cfg + ) + + results = await output_ch.reader.collect() + assert len(results) == 4 + + +# --------------------------------------------------------------------------- +# 9. End-to-end multi-stage async pipeline +# --------------------------------------------------------------------------- + + +class TestEndToEndPipeline: + @pytest.mark.asyncio + async def test_source_filter_function_chain(self): + """Source → Filter → FunctionPod, wired with channels.""" + import polars as pl + + # Setup + stream = make_stream(10) + filter_op = PolarsFilter(predicates=(pl.col("id").is_in([1, 3, 5, 7]),)) + + def triple(x: int) -> int: + return x * 3 + + func_pod = FunctionPod(PythonPacketFunction(triple, output_keys="result")) + + # Channels + ch1 = Channel(buffer_size=16) + ch2 = Channel(buffer_size=16) + ch3 = Channel(buffer_size=16) + + # Wire + async def source(): + for tag, packet in stream.iter_packets(): + await ch1.writer.send((tag, packet)) + await ch1.writer.close() + + async with asyncio.TaskGroup() as tg: + tg.create_task(source()) + tg.create_task(filter_op.async_execute([ch1.reader], ch2.writer)) + tg.create_task(func_pod.async_execute([ch2.reader], ch3.writer)) + + results = await ch3.reader.collect() + assert len(results) == 4 + + result_map = { + tag.as_dict()["id"]: pkt.as_dict()["result"] for tag, pkt in results + } + assert result_map[1] == 3 + assert result_map[3] == 9 + assert result_map[5] == 15 + assert result_map[7] == 21 + + @pytest.mark.asyncio + async def test_source_join_function_chain(self): + """Two sources → Join → FunctionPod, wired with channels.""" + left_table = pa.table( + { + "id": pa.array([0, 1, 2], type=pa.int64()), + "x": pa.array([10, 20, 30], type=pa.int64()), + } + ) + right_table = pa.table( + { + "id": pa.array([0, 1, 2], type=pa.int64()), + "y": pa.array([1, 2, 3], type=pa.int64()), + } + ) + left_stream = ArrowTableStream(left_table, tag_columns=["id"]) + right_stream = ArrowTableStream(right_table, tag_columns=["id"]) + + def add(x: int, y: int) -> int: + return x + y + + join_op = Join() + func_pod = FunctionPod(PythonPacketFunction(add, output_keys="result")) + + ch_left = Channel(buffer_size=16) + ch_right = Channel(buffer_size=16) + ch_joined = Channel(buffer_size=16) + ch_out = Channel(buffer_size=16) + + async def push(stream, ch): + for tag, packet in stream.iter_packets(): + await ch.writer.send((tag, packet)) + await ch.writer.close() + + async with asyncio.TaskGroup() as tg: + tg.create_task(push(left_stream, ch_left)) + tg.create_task(push(right_stream, ch_right)) + tg.create_task( + join_op.async_execute( + [ch_left.reader, ch_right.reader], ch_joined.writer + ) + ) + tg.create_task(func_pod.async_execute([ch_joined.reader], ch_out.writer)) + + results = await ch_out.reader.collect() + assert len(results) == 3 + + result_map = { + tag.as_dict()["id"]: pkt.as_dict()["result"] for tag, pkt in results + } + assert result_map[0] == 11 # 10 + 1 + assert result_map[1] == 22 # 20 + 2 + assert result_map[2] == 33 # 30 + 3 + + +# --------------------------------------------------------------------------- +# 10. Error propagation +# --------------------------------------------------------------------------- + + +class TestErrorPropagation: + @pytest.mark.asyncio + async def test_function_exception_returns_none(self): + """An exception in the packet function returns (None, captured) — no raise.""" + + def failing(x: int) -> int: + if x == 2: + raise ValueError("boom") + return x + + pf = PythonPacketFunction(failing, output_keys="result") + pod = FunctionPod(pf) + + stream = make_stream(5) + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(stream, input_ch) + await pod.async_execute([input_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + # The failing packet (x=2) is silently dropped; 4 of 5 succeed + assert len(results) == 4 + values = sorted(pkt.as_dict()["result"] for _, pkt in results) + assert values == [0, 1, 3, 4] + + @pytest.mark.asyncio + async def test_direct_async_call_captures_failure(self): + """direct_async_call returns (None, captured) with success=False on error.""" + + def failing(x: int) -> int: + raise ValueError("boom") + + pf = PythonPacketFunction(failing, output_keys="result") + result, captured = await pf.direct_async_call(Packet({"x": 1})) + assert result is None + assert captured.success is False + + +# --------------------------------------------------------------------------- +# 11. Sync behavior unchanged +# --------------------------------------------------------------------------- + + +class TestSyncBehaviorUnchanged: + """Verify that adding async_execute doesn't break the existing sync path.""" + + def test_function_pod_sync_process_still_works(self): + def double(x: int) -> int: + return x * 2 + + pf = PythonPacketFunction(double, output_keys="result") + pod = FunctionPod(pf) + + stream = make_stream(3) + output = pod.process(stream) + results = list(output.iter_packets()) + assert len(results) == 3 + values = [pkt.as_dict()["result"] for _, pkt in results] + assert values == [0, 2, 4] + + def test_operator_sync_process_still_works(self): + import polars as pl + + stream = make_stream(5) + op = PolarsFilter(predicates=(pl.col("id").is_in([1, 3]),)) + output = op.process(stream) + results = list(output.iter_packets()) + ids = sorted(cast(int, tag.as_dict()["id"]) for tag, _ in results) + assert ids == [1, 3] + + def test_join_sync_process_still_works(self): + left_table = pa.table( + { + "id": pa.array([0, 1], type=pa.int64()), + "x": pa.array([10, 20], type=pa.int64()), + } + ) + right_table = pa.table( + { + "id": pa.array([0, 1], type=pa.int64()), + "y": pa.array([100, 200], type=pa.int64()), + } + ) + left = ArrowTableStream(left_table, tag_columns=["id"]) + right = ArrowTableStream(right_table, tag_columns=["id"]) + + join = Join() + output = join.process(left, right) + results = list(output.iter_packets()) + assert len(results) == 2 + + def test_function_pod_with_node_config_sync_still_works(self): + """NodeConfig should be ignored in sync mode.""" + + def add(x: int, y: int) -> int: + return x + y + + pf = PythonPacketFunction(add, output_keys="result") + pod = FunctionPod(pf, node_config=NodeConfig(max_concurrency=2)) + + stream = make_two_col_stream(3) + output = pod.process(stream) + results = list(output.iter_packets()) + assert len(results) == 3 + values = sorted(cast(int, pkt.as_dict()["result"]) for _, pkt in results) + assert values == [0, 11, 22] diff --git a/tests/test_channels/test_channels.py b/tests/test_channels/test_channels.py new file mode 100644 index 00000000..8c043fe9 --- /dev/null +++ b/tests/test_channels/test_channels.py @@ -0,0 +1,592 @@ +""" +Comprehensive tests for the async channel primitives. + +Covers: +- Channel basic send/receive +- Channel close semantics and ChannelClosed exception +- Backpressure (bounded buffer) +- Async iteration (__aiter__ / __anext__) +- collect() draining +- Multiple readers seeing sentinel +- Writer send-after-close +- BroadcastChannel fan-out to multiple readers +- BroadcastChannel close semantics +- Protocol conformance (ReadableChannel / WritableChannel) +- Empty channel collect +- Concurrent producer/consumer patterns +- Edge cases (zero-buffer, single item, large burst) +""" + +from __future__ import annotations + +import asyncio + +import pytest + +from orcapod.channels import ( + BroadcastChannel, + Channel, + ChannelClosed, + ReadableChannel, + WritableChannel, +) + + +# --------------------------------------------------------------------------- +# 1. Basic send/receive +# --------------------------------------------------------------------------- + + +class TestBasicSendReceive: + @pytest.mark.asyncio + async def test_send_and_receive_single_item(self): + ch = Channel[int](buffer_size=8) + await ch.writer.send(42) + result = await ch.reader.receive() + assert result == 42 + + @pytest.mark.asyncio + async def test_send_and_receive_multiple_items(self): + ch = Channel[str](buffer_size=8) + items = ["a", "b", "c"] + for item in items: + await ch.writer.send(item) + + received = [] + for _ in range(3): + received.append(await ch.reader.receive()) + assert received == items + + @pytest.mark.asyncio + async def test_fifo_ordering(self): + ch = Channel[int](buffer_size=16) + for i in range(10): + await ch.writer.send(i) + await ch.writer.close() + + result = await ch.reader.collect() + assert result == list(range(10)) + + @pytest.mark.asyncio + async def test_send_receive_complex_types(self): + ch = Channel[tuple[str, int]](buffer_size=4) + await ch.writer.send(("hello", 1)) + await ch.writer.send(("world", 2)) + assert await ch.reader.receive() == ("hello", 1) + assert await ch.reader.receive() == ("world", 2) + + +# --------------------------------------------------------------------------- +# 2. Close semantics +# --------------------------------------------------------------------------- + + +class TestCloseSemantics: + @pytest.mark.asyncio + async def test_receive_after_close_raises_channel_closed(self): + ch = Channel[int](buffer_size=4) + await ch.writer.close() + with pytest.raises(ChannelClosed): + await ch.reader.receive() + + @pytest.mark.asyncio + async def test_receive_drains_then_raises(self): + ch = Channel[int](buffer_size=4) + await ch.writer.send(1) + await ch.writer.send(2) + await ch.writer.close() + + assert await ch.reader.receive() == 1 + assert await ch.reader.receive() == 2 + with pytest.raises(ChannelClosed): + await ch.reader.receive() + + @pytest.mark.asyncio + async def test_send_after_close_raises(self): + ch = Channel[int](buffer_size=4) + await ch.writer.close() + with pytest.raises(ChannelClosed, match="Cannot send to a closed channel"): + await ch.writer.send(99) + + @pytest.mark.asyncio + async def test_double_close_is_idempotent(self): + ch = Channel[int](buffer_size=4) + await ch.writer.close() + await ch.writer.close() # Should not raise + + @pytest.mark.asyncio + async def test_reader_sentinel_re_enqueued_for_repeated_receive(self): + """After close, repeated receive() calls all raise ChannelClosed.""" + ch = Channel[int](buffer_size=4) + await ch.writer.close() + for _ in range(3): + with pytest.raises(ChannelClosed): + await ch.reader.receive() + + +# --------------------------------------------------------------------------- +# 3. Backpressure +# --------------------------------------------------------------------------- + + +class TestBackpressure: + @pytest.mark.asyncio + async def test_send_blocks_when_buffer_full(self): + ch = Channel[int](buffer_size=2) + await ch.writer.send(1) + await ch.writer.send(2) + + # Buffer is full — a third send should not complete immediately + send_completed = False + + async def try_send(): + nonlocal send_completed + await ch.writer.send(3) + send_completed = True + + task = asyncio.create_task(try_send()) + await asyncio.sleep(0.05) # Give event loop a tick + assert not send_completed, "Send should block when buffer is full" + + # Drain one item to unblock + await ch.reader.receive() + await asyncio.sleep(0.05) + assert send_completed, "Send should complete after buffer has space" + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + @pytest.mark.asyncio + async def test_receive_blocks_when_buffer_empty(self): + ch = Channel[int](buffer_size=4) + received = None + + async def try_receive(): + nonlocal received + received = await ch.reader.receive() + + task = asyncio.create_task(try_receive()) + await asyncio.sleep(0.05) + assert received is None, "Receive should block when buffer is empty" + + await ch.writer.send(42) + await asyncio.sleep(0.05) + assert received == 42 + await task + + +# --------------------------------------------------------------------------- +# 4. Async iteration +# --------------------------------------------------------------------------- + + +class TestAsyncIteration: + @pytest.mark.asyncio + async def test_async_for_yields_all_items(self): + ch = Channel[int](buffer_size=8) + expected = [10, 20, 30] + for item in expected: + await ch.writer.send(item) + await ch.writer.close() + + result = [] + async for item in ch.reader: + result.append(item) + assert result == expected + + @pytest.mark.asyncio + async def test_async_for_on_empty_closed_channel(self): + ch = Channel[int](buffer_size=4) + await ch.writer.close() + result = [] + async for item in ch.reader: + result.append(item) + assert result == [] + + @pytest.mark.asyncio + async def test_async_iteration_with_concurrent_producer(self): + ch = Channel[int](buffer_size=4) + + async def producer(): + for i in range(5): + await ch.writer.send(i) + await ch.writer.close() + + async def consumer(): + items = [] + async for item in ch.reader: + items.append(item) + return items + + _, result = await asyncio.gather(producer(), consumer()) + assert result == [0, 1, 2, 3, 4] + + +# --------------------------------------------------------------------------- +# 5. collect() +# --------------------------------------------------------------------------- + + +class TestCollect: + @pytest.mark.asyncio + async def test_collect_returns_all_items(self): + ch = Channel[int](buffer_size=16) + for i in range(5): + await ch.writer.send(i) + await ch.writer.close() + + result = await ch.reader.collect() + assert result == [0, 1, 2, 3, 4] + + @pytest.mark.asyncio + async def test_collect_on_empty_closed_channel(self): + ch = Channel[int](buffer_size=4) + await ch.writer.close() + result = await ch.reader.collect() + assert result == [] + + @pytest.mark.asyncio + async def test_collect_with_concurrent_producer(self): + ch = Channel[int](buffer_size=2) + + async def producer(): + for i in range(10): + await ch.writer.send(i) + await ch.writer.close() + + task = asyncio.create_task(producer()) + result = await ch.reader.collect() + await task + assert result == list(range(10)) + + +# --------------------------------------------------------------------------- +# 6. BroadcastChannel +# --------------------------------------------------------------------------- + + +class TestBroadcastChannel: + @pytest.mark.asyncio + async def test_broadcast_sends_to_all_readers(self): + bc = BroadcastChannel[int](buffer_size=8) + r1 = bc.add_reader() + r2 = bc.add_reader() + + await bc.writer.send(1) + await bc.writer.send(2) + await bc.writer.close() + + result1 = await r1.collect() + result2 = await r2.collect() + assert result1 == [1, 2] + assert result2 == [1, 2] + + @pytest.mark.asyncio + async def test_broadcast_close_signals_all_readers(self): + bc = BroadcastChannel[str](buffer_size=4) + r1 = bc.add_reader() + r2 = bc.add_reader() + r3 = bc.add_reader() + await bc.writer.close() + + for reader in [r1, r2, r3]: + with pytest.raises(ChannelClosed): + await reader.receive() + + @pytest.mark.asyncio + async def test_broadcast_readers_independent_pace(self): + bc = BroadcastChannel[int](buffer_size=8) + r1 = bc.add_reader() + r2 = bc.add_reader() + + await bc.writer.send(10) + await bc.writer.send(20) + await bc.writer.close() + + # Reader 1 drains all + result1 = await r1.collect() + + # Reader 2 also gets everything + result2 = await r2.collect() + + assert result1 == [10, 20] + assert result2 == [10, 20] + + @pytest.mark.asyncio + async def test_broadcast_send_after_close_raises(self): + bc = BroadcastChannel[int](buffer_size=4) + bc.add_reader() + await bc.writer.close() + with pytest.raises(ChannelClosed): + await bc.writer.send(1) + + @pytest.mark.asyncio + async def test_broadcast_double_close_idempotent(self): + bc = BroadcastChannel[int](buffer_size=4) + bc.add_reader() + await bc.writer.close() + await bc.writer.close() # Should not raise + + @pytest.mark.asyncio + async def test_broadcast_no_readers(self): + """Broadcast with no readers should still work (items are dropped).""" + bc = BroadcastChannel[int](buffer_size=4) + await bc.writer.send(1) + await bc.writer.close() + + @pytest.mark.asyncio + async def test_broadcast_repeated_receive_after_close(self): + bc = BroadcastChannel[int](buffer_size=4) + r = bc.add_reader() + await bc.writer.close() + for _ in range(3): + with pytest.raises(ChannelClosed): + await r.receive() + + +# --------------------------------------------------------------------------- +# 7. Protocol conformance +# --------------------------------------------------------------------------- + + +class TestProtocolConformance: + def test_channel_reader_is_readable(self): + ch = Channel[int](buffer_size=4) + assert isinstance(ch.reader, ReadableChannel) + + def test_channel_writer_is_writable(self): + ch = Channel[int](buffer_size=4) + assert isinstance(ch.writer, WritableChannel) + + def test_broadcast_reader_is_readable(self): + bc = BroadcastChannel[int](buffer_size=4) + r = bc.add_reader() + assert isinstance(r, ReadableChannel) + + def test_broadcast_writer_is_writable(self): + bc = BroadcastChannel[int](buffer_size=4) + assert isinstance(bc.writer, WritableChannel) + + +# --------------------------------------------------------------------------- +# 8. Concurrent producer/consumer patterns +# --------------------------------------------------------------------------- + + +class TestConcurrentPatterns: + @pytest.mark.asyncio + async def test_multiple_producers_single_consumer(self): + """Multiple tasks sending to the same channel.""" + ch = Channel[int](buffer_size=8) + + async def produce(start: int, count: int): + for i in range(start, start + count): + await ch.writer.send(i) + + async def run(): + async with asyncio.TaskGroup() as tg: + tg.create_task(produce(0, 5)) + tg.create_task(produce(100, 5)) + + await ch.writer.close() + + task = asyncio.create_task(run()) + result = await ch.reader.collect() + await task + + assert sorted(result) == [0, 1, 2, 3, 4, 100, 101, 102, 103, 104] + + @pytest.mark.asyncio + async def test_pipeline_two_stages(self): + """Simple two-stage pipeline: producer -> transformer -> consumer.""" + ch1 = Channel[int](buffer_size=4) + ch2 = Channel[int](buffer_size=4) + + async def producer(): + for i in range(5): + await ch1.writer.send(i) + await ch1.writer.close() + + async def transformer(): + async for item in ch1.reader: + await ch2.writer.send(item * 2) + await ch2.writer.close() + + async def consumer(): + return await ch2.reader.collect() + + _, _, result = await asyncio.gather(producer(), transformer(), consumer()) + assert result == [0, 2, 4, 6, 8] + + @pytest.mark.asyncio + async def test_pipeline_three_stages(self): + """Three-stage pipeline: source -> add1 -> double -> sink.""" + ch1 = Channel[int](buffer_size=4) + ch2 = Channel[int](buffer_size=4) + ch3 = Channel[int](buffer_size=4) + + async def source(): + for i in range(3): + await ch1.writer.send(i) + await ch1.writer.close() + + async def add_one(): + async for item in ch1.reader: + await ch2.writer.send(item + 1) + await ch2.writer.close() + + async def double(): + async for item in ch2.reader: + await ch3.writer.send(item * 2) + await ch3.writer.close() + + _, _, _, result = await asyncio.gather( + source(), add_one(), double(), ch3.reader.collect() + ) + assert result == [2, 4, 6] + + @pytest.mark.asyncio + async def test_fan_out_fan_in(self): + """Broadcast to two consumers, each processing independently.""" + bc = BroadcastChannel[int](buffer_size=8) + r1 = bc.add_reader() + r2 = bc.add_reader() + + out = Channel[int](buffer_size=16) + + async def producer(): + for i in range(3): + await bc.writer.send(i) + await bc.writer.close() + + async def worker(reader, multiplier): + async for item in reader: + await out.writer.send(item * multiplier) + + async def run(): + async with asyncio.TaskGroup() as tg: + tg.create_task(producer()) + tg.create_task(worker(r1, 10)) + tg.create_task(worker(r2, 100)) + await out.writer.close() + + task = asyncio.create_task(run()) + result = await out.reader.collect() + await task + + assert sorted(result) == [0, 0, 10, 20, 100, 200] + + +# --------------------------------------------------------------------------- +# 9. Edge cases +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + @pytest.mark.asyncio + async def test_buffer_size_one(self): + ch = Channel[int](buffer_size=1) + + async def producer(): + for i in range(5): + await ch.writer.send(i) + await ch.writer.close() + + task = asyncio.create_task(producer()) + result = await ch.reader.collect() + await task + assert result == [0, 1, 2, 3, 4] + + @pytest.mark.asyncio + async def test_large_burst(self): + ch = Channel[int](buffer_size=4) + n = 100 + + async def producer(): + for i in range(n): + await ch.writer.send(i) + await ch.writer.close() + + task = asyncio.create_task(producer()) + result = await ch.reader.collect() + await task + assert result == list(range(n)) + + @pytest.mark.asyncio + async def test_none_as_item(self): + """None is a valid item — it should not be confused with sentinel.""" + ch = Channel[int | None](buffer_size=4) + await ch.writer.send(None) + await ch.writer.send(1) + await ch.writer.send(None) + await ch.writer.close() + + result = await ch.reader.collect() + assert result == [None, 1, None] + + @pytest.mark.asyncio + async def test_channel_default_buffer_size(self): + ch = Channel[int]() + assert ch.buffer_size == 64 + + +# --------------------------------------------------------------------------- +# 10. Config types +# --------------------------------------------------------------------------- + + +class TestConfigTypes: + def test_executor_type_enum(self): + from orcapod.types import ExecutorType + + assert ExecutorType.SYNCHRONOUS.value == "synchronous" + assert ExecutorType.ASYNC_CHANNELS.value == "async_channels" + + def test_pipeline_config_defaults(self): + from orcapod.types import ExecutorType, PipelineConfig + + cfg = PipelineConfig() + assert cfg.executor == ExecutorType.SYNCHRONOUS + assert cfg.channel_buffer_size == 64 + assert cfg.default_max_concurrency is None + + def test_pipeline_config_custom(self): + from orcapod.types import ExecutorType, PipelineConfig + + cfg = PipelineConfig( + executor=ExecutorType.ASYNC_CHANNELS, + channel_buffer_size=128, + default_max_concurrency=4, + ) + assert cfg.executor == ExecutorType.ASYNC_CHANNELS + assert cfg.channel_buffer_size == 128 + assert cfg.default_max_concurrency == 4 + + def test_node_config_defaults(self): + from orcapod.types import NodeConfig + + cfg = NodeConfig() + assert cfg.max_concurrency is None + + def test_resolve_concurrency_node_overrides_pipeline(self): + from orcapod.types import NodeConfig, PipelineConfig, resolve_concurrency + + node = NodeConfig(max_concurrency=2) + pipeline = PipelineConfig(default_max_concurrency=8) + assert resolve_concurrency(node, pipeline) == 2 + + def test_resolve_concurrency_falls_back_to_pipeline(self): + from orcapod.types import NodeConfig, PipelineConfig, resolve_concurrency + + node = NodeConfig() + pipeline = PipelineConfig(default_max_concurrency=8) + assert resolve_concurrency(node, pipeline) == 8 + + def test_resolve_concurrency_both_none(self): + from orcapod.types import NodeConfig, PipelineConfig, resolve_concurrency + + node = NodeConfig() + pipeline = PipelineConfig() + assert resolve_concurrency(node, pipeline) is None diff --git a/tests/test_channels/test_copilot_review_issues.py b/tests/test_channels/test_copilot_review_issues.py new file mode 100644 index 00000000..f9891945 --- /dev/null +++ b/tests/test_channels/test_copilot_review_issues.py @@ -0,0 +1,326 @@ +""" +Tests that expose the issues identified by GitHub Copilot in PR #72. + +Each test class targets a specific review comment and is expected to FAIL +before the corresponding fix is applied. + +Issues: +1. Timing-based test flakiness — deterministic concurrency tracking +2. Inaccurate GIL documentation — (not testable, documentation-only) +3. ThreadPoolExecutor created per-call — resource waste +4. Unawaited coroutine risk — coroutine created in wrong thread +5. Semaphore(0) deadlock — resolve_concurrency can return 0 +""" + +from __future__ import annotations + +import asyncio +from unittest.mock import patch + +import pyarrow as pa +import pytest + +from orcapod.channels import Channel +from orcapod.core.datagrams import Packet +from orcapod.core.function_pod import FunctionPod +from orcapod.core.nodes import FunctionNode +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.core.streams import ArrowTableStream +from orcapod.databases import InMemoryArrowDatabase +from orcapod.types import NodeConfig, PipelineConfig, resolve_concurrency + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_stream(n: int = 5) -> ArrowTableStream: + table = pa.table( + { + "id": pa.array(list(range(n)), type=pa.int64()), + "x": pa.array(list(range(n)), type=pa.int64()), + } + ) + return ArrowTableStream(table, tag_columns=["id"]) + + +async def feed_stream_to_channel(stream: ArrowTableStream, ch: Channel) -> None: + for tag, packet in stream.iter_packets(): + await ch.writer.send((tag, packet)) + await ch.writer.close() + + +# --------------------------------------------------------------------------- +# Issue 1: Deterministic concurrency tracking instead of timing assertions +# --------------------------------------------------------------------------- + + +class TestDeterministicConcurrencyTracking: + """Copilot comment: timing assertions (elapsed < 0.6s) are unreliable. + + This test uses a deterministic concurrency counter to prove that tasks + actually ran concurrently, without relying on wall-clock time. + """ + + @pytest.mark.asyncio + async def test_peak_concurrency_matches_max_concurrency(self): + """Track peak concurrent tasks to verify concurrent execution.""" + peak = 0 + current = 0 + lock = asyncio.Lock() + + async def tracked_double(x: int) -> int: + nonlocal peak, current + async with lock: + current += 1 + peak = max(peak, current) + await asyncio.sleep(0.05) # Small sleep to allow overlap + async with lock: + current -= 1 + return x * 2 + + pf = PythonPacketFunction(tracked_double, output_keys="result") + pod = FunctionPod(pf, node_config=NodeConfig(max_concurrency=5)) + db = InMemoryArrowDatabase() + stream = make_stream(5) + node = FunctionNode(pod, stream, pipeline_database=db) + + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(make_stream(5), input_ch) + await node.async_execute(input_ch.reader, output_ch.writer) + + results = await output_ch.reader.collect() + assert len(results) == 5 + values = sorted(pkt.as_dict()["result"] for _, pkt in results) + assert values == [0, 2, 4, 6, 8] + + # Concurrency limiting was removed in PLT-922 (deferred to PLT-930). + # Packets are now processed sequentially, so peak should be 1. + assert peak == 1, f"Expected sequential execution (peak=1) but peak was {peak}" + + +# --------------------------------------------------------------------------- +# Issue 3: ThreadPoolExecutor created per-call (resource waste) +# --------------------------------------------------------------------------- + + +class TestThreadPoolExecutorReuse: + """Copilot comment: creating new ThreadPoolExecutor on every call wastes + thread resources. + + This test verifies that multiple sync calls to an async function reuse + a shared executor rather than creating a new one each time. + """ + + def test_executor_not_created_per_call(self): + """Calling an async function sync multiple times should not create + a new ThreadPoolExecutor each time.""" + + async def simple_add(x: int, y: int) -> int: + return x + y + + pf = PythonPacketFunction(simple_add, output_keys="result") + + creation_count = 0 + original_init = None + + # We need to patch ThreadPoolExecutor.__init__ to count instantiations + from concurrent.futures import ThreadPoolExecutor + + original_init = ThreadPoolExecutor.__init__ + + def counting_init(self, *args, **kwargs): + nonlocal creation_count + creation_count += 1 + return original_init(self, *args, **kwargs) + + # Run inside an event loop context to trigger the ThreadPoolExecutor path + async def run_in_loop(): + nonlocal creation_count + with patch.object(ThreadPoolExecutor, "__init__", counting_init): + for _ in range(3): + _result, _captured = pf.direct_call(Packet({"x": 1, "y": 2})) + + asyncio.run(run_in_loop()) + # Current code creates a new executor per call, so creation_count == 3. + # After fix, it should be <= 1 (reused executor). + assert creation_count <= 1, ( + f"ThreadPoolExecutor created {creation_count} times; " + f"expected at most 1 (should reuse)" + ) + + +# --------------------------------------------------------------------------- +# Issue 4: Coroutine constructed in wrong thread +# --------------------------------------------------------------------------- + + +class TestCoroutineConstructedInExecutorThread: + """Copilot comment: coroutine is created in the caller thread but passed + to another thread — risks unawaited coroutine warnings if submission fails. + + The current code does ``coro = self._function(**packet.as_dict())`` in the + caller thread, then passes the already-created coroutine to a + ThreadPoolExecutor. If submission fails, the coroutine is never awaited, + triggering a RuntimeWarning. + + The fix should construct the coroutine inside the executor thread by + passing a lambda that both creates and runs the coroutine. + """ + + def test_coroutine_created_inside_executor_thread(self): + """Verify the coroutine is NOT created before being submitted to + the executor. + + We instrument the async function to record the thread where it is + *called* (i.e. where the coroutine starts executing). In the fixed + code, this should happen inside the executor thread. We also check + that no coroutine object is created in the caller thread by + inspecting the implementation pattern. + """ + import threading + + execution_threads: list[str] = [] + + async def tracking_func(x: int) -> int: + execution_threads.append(threading.current_thread().name) + return x * 2 + + pf = PythonPacketFunction(tracking_func, output_keys="result") + + # Intercept _call_async_function_sync to verify it does NOT call + # self._function before submitting to the executor. + # The buggy pattern: coro = self._function(...); pool.submit(asyncio.run, coro) + # The fixed pattern: pool.submit(lambda: asyncio.run(self._function(...))) + coroutine_created_in_caller = False + original_method = pf._call_async_function_sync + + def instrumented_call(packet): + nonlocal coroutine_created_in_caller + # Read the source to check if coro is created before submit + import inspect as _inspect + + source = _inspect.getsource(original_method) + # The buggy pattern assigns coro before the try block + if "coro = self._function(" in source: + coroutine_created_in_caller = True + return original_method(packet) + + pf._call_async_function_sync = instrumented_call + + # Run inside an event loop to trigger the ThreadPoolExecutor path + async def run_in_loop(): + _result, _captured = pf.direct_call(Packet({"x": 5})) + + asyncio.run(run_in_loop()) + + assert not coroutine_created_in_caller, ( + "Coroutine is created in the caller thread before submission " + "to the executor. It should be created inside the executor thread " + "to avoid unawaited coroutine warnings on submission failure." + ) + + +# --------------------------------------------------------------------------- +# Issue 5: Semaphore(0) deadlock +# --------------------------------------------------------------------------- + + +class TestSemaphoreZeroDeadlock: + """Copilot comment: resolve_concurrency can return 0, causing + asyncio.Semaphore(0) to deadlock on first acquire. + + This test verifies that max_concurrency=0 is rejected with a clear error. + """ + + def test_resolve_concurrency_rejects_zero(self): + """resolve_concurrency should raise ValueError when result is 0.""" + node_config = NodeConfig(max_concurrency=0) + pipeline_config = PipelineConfig() + + with pytest.raises(ValueError, match="max_concurrency"): + resolve_concurrency(node_config, pipeline_config) + + def test_resolve_concurrency_rejects_negative(self): + """resolve_concurrency should raise ValueError when result is negative.""" + node_config = NodeConfig(max_concurrency=-1) + pipeline_config = PipelineConfig() + + with pytest.raises(ValueError, match="max_concurrency"): + resolve_concurrency(node_config, pipeline_config) + + def test_resolve_concurrency_accepts_one(self): + """max_concurrency=1 is valid (sequential execution).""" + node_config = NodeConfig(max_concurrency=1) + pipeline_config = PipelineConfig() + assert resolve_concurrency(node_config, pipeline_config) == 1 + + def test_resolve_concurrency_accepts_none(self): + """max_concurrency=None means unlimited (no semaphore).""" + node_config = NodeConfig() + pipeline_config = PipelineConfig() + assert resolve_concurrency(node_config, pipeline_config) is None + + @pytest.mark.asyncio + async def test_max_concurrency_zero_no_deadlock(self): + """max_concurrency=0 no longer causes deadlock after PLT-922. + + Semaphore/concurrency limiting was removed from async_execute + (deferred to PLT-930). Packets are processed sequentially regardless + of max_concurrency settings. + """ + + async def double(x: int) -> int: + return x * 2 + + pf = PythonPacketFunction(double, output_keys="result") + pod = FunctionPod(pf, node_config=NodeConfig(max_concurrency=0)) + stream = make_stream(1) + node = FunctionNode(pod, stream) + + input_ch = Channel(buffer_size=4) + output_ch = Channel(buffer_size=4) + + await feed_stream_to_channel(make_stream(1), input_ch) + + # With concurrency removed, this completes without deadlock. + await asyncio.wait_for( + node.async_execute(input_ch.reader, output_ch.writer), + timeout=2.0, + ) + results = await output_ch.reader.collect() + assert len(results) == 1 + + +# --------------------------------------------------------------------------- +# Issue 2: GIL comment accuracy (not directly testable, but we verify the +# example file contains the inaccurate text) +# --------------------------------------------------------------------------- + + +class TestGILCommentAccuracy: + """Copilot comment: the example claims native coroutines 'bypass the GIL + entirely', which is misleading. + + This test checks that the inaccurate phrase does NOT exist in the example. + After the fix, this test should pass. + """ + + def test_example_does_not_claim_gil_bypass(self): + """The async example should not claim coroutines bypass the GIL.""" + import pathlib + + example_path = ( + pathlib.Path(__file__).resolve().parents[2] + / "examples" + / "async_vs_sync_pipeline.py" + ) + content = example_path.read_text() + + assert "bypass the GIL entirely" not in content, ( + "Example still contains misleading claim that coroutines " + "'bypass the GIL entirely'" + ) diff --git a/tests/test_channels/test_native_async_operators.py b/tests/test_channels/test_native_async_operators.py new file mode 100644 index 00000000..de38bf71 --- /dev/null +++ b/tests/test_channels/test_native_async_operators.py @@ -0,0 +1,1468 @@ +""" +Comprehensive tests for native streaming async_execute overrides. + +Each operator's new streaming async_execute is tested to produce the same +results as the synchronous static_process path. Tests mirror the sync +operator tests in ``tests/test_core/operators/test_operators.py``. + +Covers: +- SelectTagColumns streaming: per-row tag column selection +- SelectPacketColumns streaming: per-row packet column selection +- DropTagColumns streaming: per-row tag column dropping +- DropPacketColumns streaming: per-row packet column dropping +- MapTags streaming: per-row tag column renaming +- MapPackets streaming: per-row packet column renaming +- Batch streaming: accumulate-and-emit full batches, partial batch handling +- SemiJoin build-probe: collect right, stream left through hash lookup +- Join: single-input passthrough, concurrent binary/N-ary collection +- Sync / async equivalence for every operator +- Empty input handling +- Multi-stage pipeline integration +""" + +from __future__ import annotations + +import asyncio + +import pyarrow as pa +import pytest + +from orcapod.channels import Channel +from orcapod.core.operators import ( + Batch, + DropPacketColumns, + DropTagColumns, + Join, + MapPackets, + MapTags, + SelectPacketColumns, + SelectTagColumns, + SemiJoin, +) +from orcapod.core.streams.arrow_table_stream import ArrowTableStream +from orcapod.system_constants import constants + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_simple_stream() -> ArrowTableStream: + """Stream with 1 tag (animal) and 2 packet columns (weight, legs).""" + table = pa.table( + { + "animal": ["cat", "dog", "bird"], + "weight": [4.0, 12.0, 0.5], + "legs": [4, 4, 2], + } + ) + return ArrowTableStream(table, tag_columns=["animal"]) + + +def make_two_tag_stream() -> ArrowTableStream: + """Stream with 2 tags (region, animal) and 1 packet column (count).""" + table = pa.table( + { + "region": ["east", "east", "west"], + "animal": ["cat", "dog", "cat"], + "count": [10, 5, 8], + } + ) + return ArrowTableStream(table, tag_columns=["region", "animal"]) + + +def make_int_stream(n: int = 3) -> ArrowTableStream: + """Stream with tag=id, packet=x (ints).""" + table = pa.table( + { + "id": pa.array(list(range(n)), type=pa.int64()), + "x": pa.array(list(range(n)), type=pa.int64()), + } + ) + return ArrowTableStream(table, tag_columns=["id"]) + + +def make_two_col_stream(n: int = 3) -> ArrowTableStream: + """Stream with tag=id, packet={x, y}.""" + table = pa.table( + { + "id": pa.array(list(range(n)), type=pa.int64()), + "x": pa.array(list(range(n)), type=pa.int64()), + "y": pa.array([i * 10 for i in range(n)], type=pa.int64()), + } + ) + return ArrowTableStream(table, tag_columns=["id"]) + + +def make_left_stream() -> ArrowTableStream: + table = pa.table( + { + "id": pa.array([1, 2, 3], type=pa.int64()), + "value_a": pa.array([10, 20, 30], type=pa.int64()), + } + ) + return ArrowTableStream(table, tag_columns=["id"]) + + +def make_right_stream() -> ArrowTableStream: + table = pa.table( + { + "id": pa.array([2, 3, 4], type=pa.int64()), + "value_b": pa.array([200, 300, 400], type=pa.int64()), + } + ) + return ArrowTableStream(table, tag_columns=["id"]) + + +def make_disjoint_stream() -> ArrowTableStream: + """Stream with same tags as simple_stream but different packet columns.""" + table = pa.table( + { + "animal": ["cat", "dog", "bird"], + "speed": [30.0, 45.0, 80.0], + } + ) + return ArrowTableStream(table, tag_columns=["animal"]) + + +async def feed(stream: ArrowTableStream, ch: Channel) -> None: + """Push all (tag, packet) from a stream into a channel, then close.""" + for tag, packet in stream.iter_packets(): + await ch.writer.send((tag, packet)) + await ch.writer.close() + + +async def run_unary(op, stream: ArrowTableStream) -> list[tuple]: + """Run a unary operator async and collect results.""" + input_ch = Channel(buffer_size=1024) + output_ch = Channel(buffer_size=1024) + await feed(stream, input_ch) + await op.async_execute([input_ch.reader], output_ch.writer) + return await output_ch.reader.collect() + + +async def run_binary( + op, left: ArrowTableStream, right: ArrowTableStream +) -> list[tuple]: + """Run a binary operator async and collect results.""" + left_ch = Channel(buffer_size=1024) + right_ch = Channel(buffer_size=1024) + output_ch = Channel(buffer_size=1024) + await feed(left, left_ch) + await feed(right, right_ch) + await op.async_execute([left_ch.reader, right_ch.reader], output_ch.writer) + return await output_ch.reader.collect() + + +def sync_process_to_rows(op, *streams): + """Run sync static_process and return list of (tag, packet) pairs.""" + result = op.static_process(*streams) + return list(result.iter_packets()) + + +# =================================================================== +# SelectTagColumns — streaming per-row +# =================================================================== + + +class TestSelectTagColumnsStreaming: + @pytest.mark.asyncio + async def test_keeps_only_selected_tags(self): + stream = make_two_tag_stream() + op = SelectTagColumns(columns=["region"]) + results = await run_unary(op, stream) + + assert len(results) == 3 + for tag, packet in results: + tag_keys = tag.keys() + assert "region" in tag_keys + assert "animal" not in tag_keys + # packet columns unchanged + assert "count" in packet.keys() + + @pytest.mark.asyncio + async def test_all_columns_selected_passthrough(self): + """When all tag columns are already selected, rows pass through unaltered.""" + stream = make_two_tag_stream() + op = SelectTagColumns(columns=["region", "animal"]) + results = await run_unary(op, stream) + + assert len(results) == 3 + for tag, packet in results: + assert set(tag.keys()) == {"region", "animal"} + assert "count" in packet.keys() + + @pytest.mark.asyncio + async def test_data_values_preserved(self): + stream = make_two_tag_stream() + op = SelectTagColumns(columns=["region"]) + results = await run_unary(op, stream) + + regions = sorted(tag.as_dict()["region"] for tag, _ in results) + assert regions == ["east", "east", "west"] + + @pytest.mark.asyncio + async def test_empty_input(self): + input_ch = Channel(buffer_size=4) + output_ch = Channel(buffer_size=4) + await input_ch.writer.close() + op = SelectTagColumns(columns=["region"]) + await op.async_execute([input_ch.reader], output_ch.writer) + results = await output_ch.reader.collect() + assert results == [] + + @pytest.mark.asyncio + async def test_matches_sync_output(self): + stream = make_two_tag_stream() + op = SelectTagColumns(columns=["region"]) + + async_results = await run_unary(op, stream) + sync_results = sync_process_to_rows(op, stream) + + assert len(async_results) == len(sync_results) + async_tags = sorted(t.as_dict()["region"] for t, _ in async_results) + sync_tags = sorted(t.as_dict()["region"] for t, _ in sync_results) + assert async_tags == sync_tags + + @pytest.mark.asyncio + async def test_system_tags_preserved(self): + """System tags on Tag objects should survive per-row selection.""" + from orcapod.core.sources.arrow_table_source import ArrowTableSource + + src = ArrowTableSource( + pa.table( + { + "region": ["east", "west"], + "animal": ["cat", "dog"], + "count": pa.array([10, 5], type=pa.int64()), + } + ), + tag_columns=["region", "animal"], + ) + op = SelectTagColumns(columns=["region"]) + results = await run_unary(op, src) + + assert len(results) == 2 + for tag, _ in results: + sys_tags = tag.system_tags() + # Source-backed streams have system tags + assert len(sys_tags) > 0 + + +# =================================================================== +# SelectPacketColumns — streaming per-row +# =================================================================== + + +class TestSelectPacketColumnsStreaming: + @pytest.mark.asyncio + async def test_keeps_only_selected_packets(self): + stream = make_simple_stream() + op = SelectPacketColumns(columns=["weight"]) + results = await run_unary(op, stream) + + assert len(results) == 3 + for _, packet in results: + pkt_keys = packet.keys() + assert "weight" in pkt_keys + assert "legs" not in pkt_keys + # tag columns unchanged + for tag, _ in results: + assert "animal" in tag.keys() + + @pytest.mark.asyncio + async def test_all_columns_selected_passthrough(self): + stream = make_simple_stream() + op = SelectPacketColumns(columns=["weight", "legs"]) + results = await run_unary(op, stream) + + assert len(results) == 3 + for _, packet in results: + assert set(packet.keys()) == {"weight", "legs"} + + @pytest.mark.asyncio + async def test_data_values_preserved(self): + stream = make_simple_stream() + op = SelectPacketColumns(columns=["weight"]) + results = await run_unary(op, stream) + + weights = sorted(pkt.as_dict()["weight"] for _, pkt in results) + assert weights == [0.5, 4.0, 12.0] + + @pytest.mark.asyncio + async def test_empty_input(self): + input_ch = Channel(buffer_size=4) + output_ch = Channel(buffer_size=4) + await input_ch.writer.close() + op = SelectPacketColumns(columns=["weight"]) + await op.async_execute([input_ch.reader], output_ch.writer) + results = await output_ch.reader.collect() + assert results == [] + + @pytest.mark.asyncio + async def test_matches_sync_output(self): + stream = make_simple_stream() + op = SelectPacketColumns(columns=["weight"]) + + async_results = await run_unary(op, stream) + sync_results = sync_process_to_rows(op, stream) + + assert len(async_results) == len(sync_results) + async_vals = sorted(p.as_dict()["weight"] for _, p in async_results) + sync_vals = sorted(p.as_dict()["weight"] for _, p in sync_results) + assert async_vals == sync_vals + + @pytest.mark.asyncio + async def test_source_info_for_dropped_columns_not_surfaced(self): + """Source info for dropped packet columns should not appear in output.""" + from orcapod.core.sources.arrow_table_source import ArrowTableSource + + src = ArrowTableSource( + pa.table( + { + "animal": ["cat", "dog"], + "weight": [4.0, 12.0], + "legs": pa.array([4, 4], type=pa.int64()), + } + ), + tag_columns=["animal"], + ) + op = SelectPacketColumns(columns=["weight"]) + results = await run_unary(op, src) + + for _, packet in results: + si = packet.source_info() + assert "legs" not in si + assert "weight" in si + + +# =================================================================== +# DropTagColumns — streaming per-row +# =================================================================== + + +class TestDropTagColumnsStreaming: + @pytest.mark.asyncio + async def test_drops_specified_tags(self): + stream = make_two_tag_stream() + op = DropTagColumns(columns=["region"]) + results = await run_unary(op, stream) + + assert len(results) == 3 + for tag, packet in results: + assert "region" not in tag.keys() + assert "animal" in tag.keys() + assert "count" in packet.keys() + + @pytest.mark.asyncio + async def test_no_columns_to_drop_passthrough(self): + stream = make_two_tag_stream() + op = DropTagColumns(columns=["nonexistent"], strict=False) + results = await run_unary(op, stream) + + assert len(results) == 3 + for tag, _ in results: + assert set(tag.keys()) == {"region", "animal"} + + @pytest.mark.asyncio + async def test_data_values_preserved(self): + stream = make_two_tag_stream() + op = DropTagColumns(columns=["region"]) + results = await run_unary(op, stream) + + animals = sorted(tag.as_dict()["animal"] for tag, _ in results) + assert animals == ["cat", "cat", "dog"] + + @pytest.mark.asyncio + async def test_empty_input(self): + input_ch = Channel(buffer_size=4) + output_ch = Channel(buffer_size=4) + await input_ch.writer.close() + op = DropTagColumns(columns=["region"]) + await op.async_execute([input_ch.reader], output_ch.writer) + results = await output_ch.reader.collect() + assert results == [] + + @pytest.mark.asyncio + async def test_matches_sync_output(self): + stream = make_two_tag_stream() + op = DropTagColumns(columns=["region"]) + + async_results = await run_unary(op, stream) + sync_results = sync_process_to_rows(op, stream) + + assert len(async_results) == len(sync_results) + async_animals = sorted(t.as_dict()["animal"] for t, _ in async_results) + sync_animals = sorted(t.as_dict()["animal"] for t, _ in sync_results) + assert async_animals == sync_animals + + +# =================================================================== +# DropPacketColumns — streaming per-row +# =================================================================== + + +class TestDropPacketColumnsStreaming: + @pytest.mark.asyncio + async def test_drops_specified_packets(self): + stream = make_simple_stream() + op = DropPacketColumns(columns=["legs"]) + results = await run_unary(op, stream) + + assert len(results) == 3 + for _, packet in results: + assert "legs" not in packet.keys() + assert "weight" in packet.keys() + for tag, _ in results: + assert "animal" in tag.keys() + + @pytest.mark.asyncio + async def test_no_columns_to_drop_passthrough(self): + stream = make_simple_stream() + op = DropPacketColumns(columns=["nonexistent"], strict=False) + results = await run_unary(op, stream) + + assert len(results) == 3 + for _, packet in results: + assert set(packet.keys()) == {"weight", "legs"} + + @pytest.mark.asyncio + async def test_data_values_preserved(self): + stream = make_simple_stream() + op = DropPacketColumns(columns=["legs"]) + results = await run_unary(op, stream) + + weights = sorted(pkt.as_dict()["weight"] for _, pkt in results) + assert weights == [0.5, 4.0, 12.0] + + @pytest.mark.asyncio + async def test_empty_input(self): + input_ch = Channel(buffer_size=4) + output_ch = Channel(buffer_size=4) + await input_ch.writer.close() + op = DropPacketColumns(columns=["legs"]) + await op.async_execute([input_ch.reader], output_ch.writer) + results = await output_ch.reader.collect() + assert results == [] + + @pytest.mark.asyncio + async def test_matches_sync_output(self): + stream = make_simple_stream() + op = DropPacketColumns(columns=["legs"]) + + async_results = await run_unary(op, stream) + sync_results = sync_process_to_rows(op, stream) + + assert len(async_results) == len(sync_results) + async_vals = sorted(p.as_dict()["weight"] for _, p in async_results) + sync_vals = sorted(p.as_dict()["weight"] for _, p in sync_results) + assert async_vals == sync_vals + + @pytest.mark.asyncio + async def test_source_info_for_dropped_columns_not_surfaced(self): + from orcapod.core.sources.arrow_table_source import ArrowTableSource + + src = ArrowTableSource( + pa.table( + { + "animal": ["cat", "dog"], + "weight": [4.0, 12.0], + "legs": pa.array([4, 4], type=pa.int64()), + } + ), + tag_columns=["animal"], + ) + op = DropPacketColumns(columns=["legs"]) + results = await run_unary(op, src) + + for _, packet in results: + si = packet.source_info() + assert "legs" not in si + assert "weight" in si + + +# =================================================================== +# MapTags — streaming per-row +# =================================================================== + + +class TestMapTagsStreaming: + @pytest.mark.asyncio + async def test_renames_tag_column(self): + stream = make_two_tag_stream() + op = MapTags(name_map={"region": "area"}) + results = await run_unary(op, stream) + + assert len(results) == 3 + for tag, _ in results: + tag_keys = tag.keys() + assert "area" in tag_keys + assert "region" not in tag_keys + + @pytest.mark.asyncio + async def test_data_values_preserved(self): + stream = make_two_tag_stream() + op = MapTags(name_map={"region": "area"}) + results = await run_unary(op, stream) + + areas = sorted(tag.as_dict()["area"] for tag, _ in results) + assert areas == ["east", "east", "west"] + + @pytest.mark.asyncio + async def test_drop_unmapped(self): + stream = make_two_tag_stream() + op = MapTags(name_map={"region": "area"}, drop_unmapped=True) + results = await run_unary(op, stream) + + assert len(results) == 3 + for tag, _ in results: + tag_keys = tag.keys() + assert "area" in tag_keys + assert "animal" not in tag_keys # dropped because unmapped + + @pytest.mark.asyncio + async def test_no_matching_rename_passthrough(self): + stream = make_two_tag_stream() + op = MapTags(name_map={"nonexistent": "nope"}) + results = await run_unary(op, stream) + + assert len(results) == 3 + for tag, _ in results: + assert set(tag.keys()) == {"region", "animal"} + + @pytest.mark.asyncio + async def test_empty_input(self): + input_ch = Channel(buffer_size=4) + output_ch = Channel(buffer_size=4) + await input_ch.writer.close() + op = MapTags(name_map={"region": "area"}) + await op.async_execute([input_ch.reader], output_ch.writer) + results = await output_ch.reader.collect() + assert results == [] + + @pytest.mark.asyncio + async def test_matches_sync_output(self): + stream = make_two_tag_stream() + op = MapTags(name_map={"region": "area"}) + + async_results = await run_unary(op, stream) + sync_results = sync_process_to_rows(op, stream) + + assert len(async_results) == len(sync_results) + async_areas = sorted(t.as_dict()["area"] for t, _ in async_results) + sync_areas = sorted(t.as_dict()["area"] for t, _ in sync_results) + assert async_areas == sync_areas + + @pytest.mark.asyncio + async def test_matches_sync_output_with_drop_unmapped(self): + stream = make_two_tag_stream() + op = MapTags(name_map={"region": "area"}, drop_unmapped=True) + + async_results = await run_unary(op, stream) + sync_results = sync_process_to_rows(op, stream) + + assert len(async_results) == len(sync_results) + for (at, ap), (st, sp) in zip( + sorted(async_results, key=lambda x: x[0].as_dict()["area"]), + sorted(sync_results, key=lambda x: x[0].as_dict()["area"]), + ): + assert at.as_dict() == st.as_dict() + assert ap.as_dict() == sp.as_dict() + + +# =================================================================== +# MapPackets — streaming per-row +# =================================================================== + + +class TestMapPacketsStreaming: + @pytest.mark.asyncio + async def test_renames_packet_column(self): + stream = make_simple_stream() + op = MapPackets(name_map={"weight": "mass"}) + results = await run_unary(op, stream) + + assert len(results) == 3 + for _, packet in results: + pkt_keys = packet.keys() + assert "mass" in pkt_keys + assert "weight" not in pkt_keys + + @pytest.mark.asyncio + async def test_data_values_preserved(self): + stream = make_simple_stream() + op = MapPackets(name_map={"weight": "mass"}) + results = await run_unary(op, stream) + + masses = sorted(pkt.as_dict()["mass"] for _, pkt in results) + assert masses == [0.5, 4.0, 12.0] + + @pytest.mark.asyncio + async def test_drop_unmapped(self): + stream = make_simple_stream() + op = MapPackets(name_map={"weight": "mass"}, drop_unmapped=True) + results = await run_unary(op, stream) + + assert len(results) == 3 + for _, packet in results: + pkt_keys = packet.keys() + assert "mass" in pkt_keys + assert "legs" not in pkt_keys # dropped because unmapped + + @pytest.mark.asyncio + async def test_source_info_renamed(self): + """Packet.rename() should update source_info keys.""" + from orcapod.core.sources.arrow_table_source import ArrowTableSource + + src = ArrowTableSource( + pa.table( + { + "animal": ["cat", "dog"], + "weight": [4.0, 12.0], + "legs": pa.array([4, 4], type=pa.int64()), + } + ), + tag_columns=["animal"], + ) + op = MapPackets(name_map={"weight": "mass"}) + results = await run_unary(op, src) + + for _, packet in results: + si = packet.source_info() + assert "mass" in si + assert "weight" not in si + + @pytest.mark.asyncio + async def test_no_matching_rename_passthrough(self): + stream = make_simple_stream() + op = MapPackets(name_map={"nonexistent": "nope"}) + results = await run_unary(op, stream) + + assert len(results) == 3 + for _, packet in results: + assert set(packet.keys()) == {"weight", "legs"} + + @pytest.mark.asyncio + async def test_empty_input(self): + input_ch = Channel(buffer_size=4) + output_ch = Channel(buffer_size=4) + await input_ch.writer.close() + op = MapPackets(name_map={"weight": "mass"}) + await op.async_execute([input_ch.reader], output_ch.writer) + results = await output_ch.reader.collect() + assert results == [] + + @pytest.mark.asyncio + async def test_matches_sync_output(self): + stream = make_simple_stream() + op = MapPackets(name_map={"weight": "mass"}) + + async_results = await run_unary(op, stream) + sync_results = sync_process_to_rows(op, stream) + + assert len(async_results) == len(sync_results) + async_masses = sorted(p.as_dict()["mass"] for _, p in async_results) + sync_masses = sorted(p.as_dict()["mass"] for _, p in sync_results) + assert async_masses == sync_masses + + +# =================================================================== +# Batch — streaming accumulate-and-emit +# =================================================================== + + +class TestBatchStreaming: + @pytest.mark.asyncio + async def test_batch_groups_rows(self): + stream = make_simple_stream() # 3 rows + op = Batch(batch_size=2) + results = await run_unary(op, stream) + + # 3 rows / batch_size=2 → 2 batches (full + partial) + assert len(results) == 2 + + @pytest.mark.asyncio + async def test_batch_drop_partial(self): + stream = make_simple_stream() # 3 rows + op = Batch(batch_size=2, drop_partial_batch=True) + results = await run_unary(op, stream) + + # 3 rows / batch_size=2 with drop → 1 batch + assert len(results) == 1 + + @pytest.mark.asyncio + async def test_batch_size_zero_single_batch(self): + stream = make_simple_stream() # 3 rows + op = Batch(batch_size=0) + results = await run_unary(op, stream) + + # batch_size=0 → all in one batch + assert len(results) == 1 + + @pytest.mark.asyncio + async def test_batch_values_are_lists(self): + stream = make_int_stream(4) + op = Batch(batch_size=2) + results = await run_unary(op, stream) + + assert len(results) == 2 + for tag, packet in results: + # Each value should be a list + tag_d = tag.as_dict() + pkt_d = packet.as_dict() + assert isinstance(tag_d["id"], list) + assert isinstance(pkt_d["x"], list) + assert len(tag_d["id"]) == 2 + assert len(pkt_d["x"]) == 2 + + @pytest.mark.asyncio + async def test_batch_exact_multiple(self): + stream = make_int_stream(6) + op = Batch(batch_size=2) + results = await run_unary(op, stream) + + # 6 / 2 = 3 full batches, no partial + assert len(results) == 3 + + @pytest.mark.asyncio + async def test_batch_exact_multiple_drop_partial(self): + stream = make_int_stream(6) + op = Batch(batch_size=2, drop_partial_batch=True) + results = await run_unary(op, stream) + + # Same as without drop since there's no partial batch + assert len(results) == 3 + + @pytest.mark.asyncio + async def test_empty_input(self): + input_ch = Channel(buffer_size=4) + output_ch = Channel(buffer_size=4) + await input_ch.writer.close() + op = Batch(batch_size=2) + await op.async_execute([input_ch.reader], output_ch.writer) + results = await output_ch.reader.collect() + assert results == [] + + @pytest.mark.asyncio + async def test_matches_sync_output(self): + stream = make_int_stream(7) + op = Batch(batch_size=3) + + async_results = await run_unary(op, stream) + sync_results = sync_process_to_rows(op, stream) + + assert len(async_results) == len(sync_results) + # Each batch should have the same data + for (at, ap), (st, sp) in zip(async_results, sync_results): + assert at.as_dict() == st.as_dict() + assert ap.as_dict() == sp.as_dict() + + @pytest.mark.asyncio + async def test_matches_sync_output_batch_zero(self): + stream = make_int_stream(5) + op = Batch(batch_size=0) + + async_results = await run_unary(op, stream) + sync_results = sync_process_to_rows(op, stream) + + assert len(async_results) == len(sync_results) == 1 + assert async_results[0][0].as_dict() == sync_results[0][0].as_dict() + assert async_results[0][1].as_dict() == sync_results[0][1].as_dict() + + @pytest.mark.asyncio + async def test_matches_sync_output_drop_partial(self): + stream = make_int_stream(5) + op = Batch(batch_size=3, drop_partial_batch=True) + + async_results = await run_unary(op, stream) + sync_results = sync_process_to_rows(op, stream) + + assert len(async_results) == len(sync_results) + for (at, ap), (st, sp) in zip(async_results, sync_results): + assert at.as_dict() == st.as_dict() + assert ap.as_dict() == sp.as_dict() + + +# =================================================================== +# SemiJoin — build-probe +# =================================================================== + + +class TestSemiJoinBuildProbe: + @pytest.mark.asyncio + async def test_filters_left_by_right(self): + left = make_left_stream() # id=[1,2,3] + right = make_right_stream() # id=[2,3,4] + op = SemiJoin() + results = await run_binary(op, left, right) + + ids = sorted(tag.as_dict()["id"] for tag, _ in results) + assert ids == [2, 3] + + @pytest.mark.asyncio + async def test_preserves_left_schema(self): + left = make_left_stream() + right = make_right_stream() + op = SemiJoin() + results = await run_binary(op, left, right) + + for tag, packet in results: + assert "id" in tag.keys() + assert "value_a" in packet.keys() + assert "value_b" not in packet.keys() + + @pytest.mark.asyncio + async def test_preserves_left_data(self): + left = make_left_stream() + right = make_right_stream() + op = SemiJoin() + results = await run_binary(op, left, right) + + result_map = { + tag.as_dict()["id"]: pkt.as_dict()["value_a"] for tag, pkt in results + } + assert result_map[2] == 20 + assert result_map[3] == 30 + + @pytest.mark.asyncio + async def test_no_common_keys_returns_all_left(self): + left_table = pa.table( + { + "a": pa.array([1, 2, 3], type=pa.int64()), + "x": pa.array([10, 20, 30], type=pa.int64()), + } + ) + right_table = pa.table( + { + "b": pa.array([1, 2], type=pa.int64()), + "y": pa.array([100, 200], type=pa.int64()), + } + ) + left = ArrowTableStream(left_table, tag_columns=["a"]) + right = ArrowTableStream(right_table, tag_columns=["b"]) + op = SemiJoin() + results = await run_binary(op, left, right) + + assert len(results) == 3 # all left rows pass through + + @pytest.mark.asyncio + async def test_no_matching_rows_empty_result(self): + left_table = pa.table( + { + "id": pa.array([1, 2], type=pa.int64()), + "x": pa.array([10, 20], type=pa.int64()), + } + ) + right_table = pa.table( + { + "id": pa.array([3, 4], type=pa.int64()), + "y": pa.array([30, 40], type=pa.int64()), + } + ) + left = ArrowTableStream(left_table, tag_columns=["id"]) + right = ArrowTableStream(right_table, tag_columns=["id"]) + op = SemiJoin() + results = await run_binary(op, left, right) + + assert len(results) == 0 + + @pytest.mark.asyncio + async def test_empty_left_returns_empty(self): + """Empty left input produces empty output regardless of right.""" + right_table = pa.table( + { + "id": pa.array([1, 2], type=pa.int64()), + "y": pa.array([100, 200], type=pa.int64()), + } + ) + right = ArrowTableStream(right_table, tag_columns=["id"]) + + left_ch = Channel(buffer_size=4) + right_ch = Channel(buffer_size=64) + output_ch = Channel(buffer_size=64) + + await left_ch.writer.close() + await feed(right, right_ch) + + op = SemiJoin() + await op.async_execute([left_ch.reader, right_ch.reader], output_ch.writer) + results = await output_ch.reader.collect() + assert results == [] + + @pytest.mark.asyncio + async def test_empty_right_returns_empty_or_all(self): + """Empty right: if common keys, result is empty; if no common keys, left passes through. + Since both sides are empty-right, we rely on the barrier fallback.""" + left = make_left_stream() + + left_ch = Channel(buffer_size=64) + right_ch = Channel(buffer_size=4) + output_ch = Channel(buffer_size=64) + + await feed(left, left_ch) + await right_ch.writer.close() + + op = SemiJoin() + await op.async_execute([left_ch.reader, right_ch.reader], output_ch.writer) + results = await output_ch.reader.collect() + # With empty right and no schema information available, + # the implementation falls back to passing left through + assert len(results) == 3 + + @pytest.mark.asyncio + async def test_matches_sync_output(self): + left = make_left_stream() + right = make_right_stream() + op = SemiJoin() + + async_results = await run_binary(op, left, right) + sync_results = sync_process_to_rows(op, left, right) + + assert len(async_results) == len(sync_results) + async_ids = sorted(t.as_dict()["id"] for t, _ in async_results) + sync_ids = sorted(t.as_dict()["id"] for t, _ in sync_results) + assert async_ids == sync_ids + + @pytest.mark.asyncio + async def test_large_input_streaming(self): + """SemiJoin should handle larger inputs correctly with build-probe.""" + left_table = pa.table( + { + "id": pa.array(list(range(100)), type=pa.int64()), + "x": pa.array(list(range(100)), type=pa.int64()), + } + ) + right_table = pa.table( + { + "id": pa.array(list(range(0, 100, 3)), type=pa.int64()), # every 3rd + "y": pa.array(list(range(0, 100, 3)), type=pa.int64()), + } + ) + left = ArrowTableStream(left_table, tag_columns=["id"]) + right = ArrowTableStream(right_table, tag_columns=["id"]) + op = SemiJoin() + results = await run_binary(op, left, right) + + expected_ids = list(range(0, 100, 3)) + result_ids = sorted(t.as_dict()["id"] for t, _ in results) + assert result_ids == expected_ids + + +# =================================================================== +# Join — native async +# =================================================================== + + +class TestJoinNativeAsync: + """Tests for Join.async_execute (symmetric hash join + N>2 barrier).""" + + @pytest.mark.asyncio + async def test_single_input_passthrough(self): + stream = make_int_stream(3) + op = Join() + + input_ch = Channel(buffer_size=64) + output_ch = Channel(buffer_size=64) + await feed(stream, input_ch) + await op.async_execute([input_ch.reader], output_ch.writer) + results = await output_ch.reader.collect() + + assert len(results) == 3 + ids = sorted(t.as_dict()["id"] for t, _ in results) + assert ids == [0, 1, 2] + + @pytest.mark.asyncio + async def test_two_way_join(self): + left = make_simple_stream() + right = make_disjoint_stream() + op = Join() + results = await run_binary(op, left, right) + + assert len(results) == 3 + for tag, packet in results: + assert "animal" in tag.keys() + pkt_d = packet.as_dict() + assert "weight" in pkt_d + assert "speed" in pkt_d + + @pytest.mark.asyncio + async def test_two_way_join_data_correct(self): + left_table = pa.table( + { + "id": pa.array([0, 1, 2], type=pa.int64()), + "x": pa.array([10, 20, 30], type=pa.int64()), + } + ) + right_table = pa.table( + { + "id": pa.array([0, 1, 2], type=pa.int64()), + "y": pa.array([100, 200, 300], type=pa.int64()), + } + ) + left = ArrowTableStream(left_table, tag_columns=["id"]) + right = ArrowTableStream(right_table, tag_columns=["id"]) + op = Join() + results = await run_binary(op, left, right) + + assert len(results) == 3 + result_map = {tag.as_dict()["id"]: pkt.as_dict() for tag, pkt in results} + assert result_map[0] == {"x": 10, "y": 100} + assert result_map[1] == {"x": 20, "y": 200} + assert result_map[2] == {"x": 30, "y": 300} + + @pytest.mark.asyncio + async def test_three_way_join(self): + t1 = pa.table( + { + "id": pa.array([1, 2], type=pa.int64()), + "a": pa.array([10, 20], type=pa.int64()), + } + ) + t2 = pa.table( + { + "id": pa.array([1, 2], type=pa.int64()), + "b": pa.array([100, 200], type=pa.int64()), + } + ) + t3 = pa.table( + { + "id": pa.array([1, 2], type=pa.int64()), + "c": pa.array([1000, 2000], type=pa.int64()), + } + ) + s1 = ArrowTableStream(t1, tag_columns=["id"]) + s2 = ArrowTableStream(t2, tag_columns=["id"]) + s3 = ArrowTableStream(t3, tag_columns=["id"]) + + op = Join() + ch1 = Channel(buffer_size=64) + ch2 = Channel(buffer_size=64) + ch3 = Channel(buffer_size=64) + out = Channel(buffer_size=64) + + await feed(s1, ch1) + await feed(s2, ch2) + await feed(s3, ch3) + await op.async_execute([ch1.reader, ch2.reader, ch3.reader], out.writer) + results = await out.reader.collect() + + assert len(results) == 2 + result_map = {tag.as_dict()["id"]: pkt.as_dict() for tag, pkt in results} + assert result_map[1] == {"a": 10, "b": 100, "c": 1000} + assert result_map[2] == {"a": 20, "b": 200, "c": 2000} + + @pytest.mark.asyncio + async def test_join_no_shared_tags_cartesian(self): + """When no shared tag keys, join produces a cartesian product.""" + left_table = pa.table( + { + "a": pa.array([1, 2], type=pa.int64()), + "x": pa.array([10, 20], type=pa.int64()), + } + ) + right_table = pa.table( + { + "b": pa.array([3, 4], type=pa.int64()), + "y": pa.array([30, 40], type=pa.int64()), + } + ) + left = ArrowTableStream(left_table, tag_columns=["a"]) + right = ArrowTableStream(right_table, tag_columns=["b"]) + op = Join() + results = await run_binary(op, left, right) + + # 2 × 2 = 4 cartesian product + assert len(results) == 4 + + @pytest.mark.asyncio + async def test_empty_input_single(self): + input_ch = Channel(buffer_size=4) + output_ch = Channel(buffer_size=4) + await input_ch.writer.close() + op = Join() + await op.async_execute([input_ch.reader], output_ch.writer) + results = await output_ch.reader.collect() + assert results == [] + + @pytest.mark.asyncio + async def test_matches_sync_two_way(self): + left = make_simple_stream() + right = make_disjoint_stream() + op = Join() + + async_results = await run_binary(op, left, right) + sync_results = sync_process_to_rows(op, left, right) + + assert len(async_results) == len(sync_results) + async_data = sorted( + (t.as_dict()["animal"], p.as_dict()) for t, p in async_results + ) + sync_data = sorted( + (t.as_dict()["animal"], p.as_dict()) for t, p in sync_results + ) + assert async_data == sync_data + + +# =================================================================== +# Multi-stage pipeline integration +# =================================================================== + + +class TestStreamingPipelineIntegration: + @pytest.mark.asyncio + async def test_select_then_map_chain(self): + """SelectTagColumns → MapTags in a streaming pipeline.""" + stream = make_two_tag_stream() + + select_op = SelectTagColumns(columns=["region"]) + map_op = MapTags(name_map={"region": "area"}) + + ch1 = Channel(buffer_size=16) + ch2 = Channel(buffer_size=16) + ch3 = Channel(buffer_size=16) + + async def source(): + for tag, packet in stream.iter_packets(): + await ch1.writer.send((tag, packet)) + await ch1.writer.close() + + async with asyncio.TaskGroup() as tg: + tg.create_task(source()) + tg.create_task(select_op.async_execute([ch1.reader], ch2.writer)) + tg.create_task(map_op.async_execute([ch2.reader], ch3.writer)) + + results = await ch3.reader.collect() + assert len(results) == 3 + for tag, _ in results: + assert "area" in tag.keys() + assert "region" not in tag.keys() + assert "animal" not in tag.keys() + + @pytest.mark.asyncio + async def test_join_then_select_chain(self): + """Join → SelectPacketColumns in a streaming pipeline.""" + left_table = pa.table( + { + "id": pa.array([0, 1, 2], type=pa.int64()), + "x": pa.array([10, 20, 30], type=pa.int64()), + } + ) + right_table = pa.table( + { + "id": pa.array([0, 1, 2], type=pa.int64()), + "y": pa.array([100, 200, 300], type=pa.int64()), + } + ) + left = ArrowTableStream(left_table, tag_columns=["id"]) + right = ArrowTableStream(right_table, tag_columns=["id"]) + + join_op = Join() + select_op = SelectPacketColumns(columns=["x"]) + + ch_l = Channel(buffer_size=16) + ch_r = Channel(buffer_size=16) + ch_joined = Channel(buffer_size=16) + ch_out = Channel(buffer_size=16) + + async def push(stream, ch): + for tag, packet in stream.iter_packets(): + await ch.writer.send((tag, packet)) + await ch.writer.close() + + async with asyncio.TaskGroup() as tg: + tg.create_task(push(left, ch_l)) + tg.create_task(push(right, ch_r)) + tg.create_task( + join_op.async_execute([ch_l.reader, ch_r.reader], ch_joined.writer) + ) + tg.create_task(select_op.async_execute([ch_joined.reader], ch_out.writer)) + + results = await ch_out.reader.collect() + assert len(results) == 3 + for _, packet in results: + assert "x" in packet.keys() + assert "y" not in packet.keys() + + @pytest.mark.asyncio + async def test_semijoin_then_batch_chain(self): + """SemiJoin → Batch in a streaming pipeline.""" + left = make_left_stream() # id=[1,2,3] + right = make_right_stream() # id=[2,3,4] + + semi_op = SemiJoin() + batch_op = Batch(batch_size=2) + + ch_l = Channel(buffer_size=16) + ch_r = Channel(buffer_size=16) + ch_semi = Channel(buffer_size=16) + ch_out = Channel(buffer_size=16) + + async def push(stream, ch): + for tag, packet in stream.iter_packets(): + await ch.writer.send((tag, packet)) + await ch.writer.close() + + async with asyncio.TaskGroup() as tg: + tg.create_task(push(left, ch_l)) + tg.create_task(push(right, ch_r)) + tg.create_task( + semi_op.async_execute([ch_l.reader, ch_r.reader], ch_semi.writer) + ) + tg.create_task(batch_op.async_execute([ch_semi.reader], ch_out.writer)) + + results = await ch_out.reader.collect() + # SemiJoin produces 2 rows (id=[2,3]), Batch(2) → 1 batch + assert len(results) == 1 + tag_d = results[0][0].as_dict() + assert isinstance(tag_d["id"], list) + assert sorted(tag_d["id"]) == [2, 3] + + @pytest.mark.asyncio + async def test_drop_map_select_three_stage(self): + """DropPacketColumns → MapPackets → SelectPacketColumns chain.""" + stream = make_simple_stream() # animal | weight, legs + + drop_op = DropPacketColumns(columns=["legs"]) + map_op = MapPackets(name_map={"weight": "mass"}) + # After map: mass (only packet column) + select_op = SelectPacketColumns(columns=["mass"]) + + ch1 = Channel(buffer_size=16) + ch2 = Channel(buffer_size=16) + ch3 = Channel(buffer_size=16) + ch4 = Channel(buffer_size=16) + + async def source(): + for tag, packet in stream.iter_packets(): + await ch1.writer.send((tag, packet)) + await ch1.writer.close() + + async with asyncio.TaskGroup() as tg: + tg.create_task(source()) + tg.create_task(drop_op.async_execute([ch1.reader], ch2.writer)) + tg.create_task(map_op.async_execute([ch2.reader], ch3.writer)) + tg.create_task(select_op.async_execute([ch3.reader], ch4.writer)) + + results = await ch4.reader.collect() + assert len(results) == 3 + for _, packet in results: + assert packet.keys() == ("mass",) + masses = sorted(pkt.as_dict()["mass"] for _, pkt in results) + assert masses == [0.5, 4.0, 12.0] + + +# =================================================================== +# Sync vs Async system-tag equivalence +# =================================================================== + + +def _make_source(tag_col: str, packet_col: str, data: dict) -> ArrowTableStream: + """Build an ArrowTableSource (which generates system tags) and return its stream.""" + from orcapod.core.sources.arrow_table_source import ArrowTableSource + + table = pa.table( + { + tag_col: pa.array(data[tag_col], type=pa.large_string()), + packet_col: pa.array(data[packet_col], type=pa.int64()), + } + ) + return ArrowTableSource(table, tag_columns=[tag_col]) + + +async def run_binary_validated( + op, + left: ArrowTableStream, + right: ArrowTableStream, +) -> list[tuple]: + """Run a binary operator async with validation and pipeline hashes. + + Calls ``validate_inputs`` for schema validation, then passes + ``input_pipeline_hashes`` so operators like ``Join`` can compute + canonical system-tag column names. + """ + op.validate_inputs(left, right) + left_ch = Channel(buffer_size=1024) + right_ch = Channel(buffer_size=1024) + output_ch = Channel(buffer_size=1024) + await feed(left, left_ch) + await feed(right, right_ch) + hashes = [left.pipeline_hash(), right.pipeline_hash()] + await op.async_execute( + [left_ch.reader, right_ch.reader], + output_ch.writer, + input_pipeline_hashes=hashes, + ) + return await output_ch.reader.collect() + + +def _extract_system_tags( + rows: list[tuple], +) -> list[dict[str, str]]: + """Extract sorted system-tag dicts from (tag, packet) pairs.""" + return sorted( + [tag.system_tags() for tag, _ in rows], + key=lambda d: sorted(d.items()), + ) + + +def _extract_system_tag_keys(rows: list[tuple]) -> set[str]: + """Collect all unique system-tag keys across rows.""" + keys: set[str] = set() + for tag, _ in rows: + keys.update(tag.system_tags().keys()) + return keys + + +class TestJoinSystemTagEquivalence: + """Verify that Join.async_execute produces the same system-tag column + names and values as the sync static_process path. + + Uses ``ArrowTableSource`` (which adds system-tag columns) rather than + bare ``ArrowTableStream`` to ensure system tags are present. + """ + + @pytest.mark.asyncio + async def test_two_way_system_tag_column_names_match(self): + """System-tag column names must be identical between sync and async.""" + left = _make_source("key", "value", {"key": ["a", "b"], "value": [10, 20]}) + right = _make_source("key", "score", {"key": ["a", "b"], "score": [100, 200]}) + op = Join() + + # Sync + sync_result = op.static_process(left, right) + sync_rows = list(sync_result.iter_packets()) + + # Async + async_rows = await run_binary_validated(op, left, right) + + sync_sys_keys = _extract_system_tag_keys(sync_rows) + async_sys_keys = _extract_system_tag_keys(async_rows) + + assert sync_sys_keys, "Expected system tags to be present" + assert sync_sys_keys == async_sys_keys + + @pytest.mark.asyncio + async def test_two_way_system_tag_values_match(self): + """System-tag values for each row must match between sync and async.""" + left = _make_source( + "key", "value", {"key": ["a", "b", "c"], "value": [1, 2, 3]} + ) + right = _make_source( + "key", "score", {"key": ["a", "b", "c"], "score": [10, 20, 30]} + ) + op = Join() + + sync_result = op.static_process(left, right) + sync_rows = list(sync_result.iter_packets()) + async_rows = await run_binary_validated(op, left, right) + + assert len(sync_rows) == len(async_rows) + + sync_sys = _extract_system_tags(sync_rows) + async_sys = _extract_system_tags(async_rows) + assert sync_sys == async_sys + + @pytest.mark.asyncio + async def test_two_way_system_tag_suffixes_use_pipeline_hash(self): + """System-tag column names should contain the pipeline_hash and + canonical position, matching the name-extending convention.""" + left = _make_source("key", "val", {"key": ["x"], "val": [1]}) + right = _make_source("key", "score", {"key": ["x"], "score": [2]}) + op = Join() + + async_rows = await run_binary_validated(op, left, right) + sys_keys = _extract_system_tag_keys(async_rows) + + # Each system-tag key should end with :{canonical_position} + for key in sys_keys: + assert key.startswith(constants.SYSTEM_TAG_PREFIX) + assert key[-2:] in (":0", ":1"), ( + f"System tag key {key!r} does not end with :0 or :1" + ) + + @pytest.mark.asyncio + async def test_commutativity_system_tags_identical(self): + """Join(A, B) and Join(B, A) should produce identical system tags + (Join is commutative — canonical ordering by pipeline_hash).""" + src_a = _make_source("id", "x", {"id": ["p", "q"], "x": [1, 2]}) + src_b = _make_source("id", "y", {"id": ["p", "q"], "y": [10, 20]}) + op = Join() + + rows_ab = await run_binary_validated(op, src_a, src_b) + rows_ba = await run_binary_validated(op, src_b, src_a) + + assert len(rows_ab) == len(rows_ba) + + sys_ab = _extract_system_tags(rows_ab) + sys_ba = _extract_system_tags(rows_ba) + assert sys_ab == sys_ba + + @pytest.mark.asyncio + async def test_three_way_system_tags_match_sync(self): + """N>2 barrier fallback should produce the same system tags as sync.""" + s1 = _make_source("id", "a", {"id": ["m", "n"], "a": [1, 2]}) + s2 = _make_source("id", "b", {"id": ["m", "n"], "b": [10, 20]}) + s3 = _make_source("id", "c", {"id": ["m", "n"], "c": [100, 200]}) + op = Join() + + # Sync + sync_result = op.static_process(s1, s2, s3) + sync_rows = list(sync_result.iter_packets()) + + # Async (N>2 barrier path) + op.validate_inputs(s1, s2, s3) + ch1 = Channel(buffer_size=64) + ch2 = Channel(buffer_size=64) + ch3 = Channel(buffer_size=64) + out = Channel(buffer_size=64) + await feed(s1, ch1) + await feed(s2, ch2) + await feed(s3, ch3) + hashes = [s1.pipeline_hash(), s2.pipeline_hash(), s3.pipeline_hash()] + await op.async_execute( + [ch1.reader, ch2.reader, ch3.reader], + out.writer, + input_pipeline_hashes=hashes, + ) + async_rows = await out.reader.collect() + + assert len(sync_rows) == len(async_rows) + + sync_sys_keys = _extract_system_tag_keys(sync_rows) + async_sys_keys = _extract_system_tag_keys(async_rows) + assert sync_sys_keys == async_sys_keys + + sync_sys = _extract_system_tags(sync_rows) + async_sys = _extract_system_tags(async_rows) + assert sync_sys == async_sys + + +class TestSemiJoinSystemTagEquivalence: + """Verify SemiJoin system-tag handling matches between sync and async.""" + + @pytest.mark.asyncio + async def test_system_tags_preserved_through_semijoin(self): + """SemiJoin should preserve left-side system tags in both paths.""" + left = _make_source("id", "val", {"id": ["a", "b", "c"], "val": [1, 2, 3]}) + right = _make_source( + "id", "score", {"id": ["b", "c", "d"], "score": [20, 30, 40]} + ) + op = SemiJoin() + + # Sync + sync_result = op.static_process(left, right) + sync_rows = list(sync_result.iter_packets()) + + # Async + async_rows = await run_binary_validated(op, left, right) + + assert len(sync_rows) == len(async_rows) == 2 + + sync_sys_keys = _extract_system_tag_keys(sync_rows) + async_sys_keys = _extract_system_tag_keys(async_rows) + assert sync_sys_keys == async_sys_keys + + sync_sys = _extract_system_tags(sync_rows) + async_sys = _extract_system_tags(async_rows) + assert sync_sys == async_sys diff --git a/tests/test_channels/test_node_async_execute.py b/tests/test_channels/test_node_async_execute.py new file mode 100644 index 00000000..7d29c1a7 --- /dev/null +++ b/tests/test_channels/test_node_async_execute.py @@ -0,0 +1,871 @@ +""" +Tests for async_execute on Node classes. + +Covers: +- CachedPacketFunction.async_call with cache support +- FunctionNode.async_execute basic streaming +- FunctionNode.async_execute two-phase logic +- OperatorNode.async_execute delegation +- OperatorNode.async_execute with cache modes +- execute_packet / async_process_packet routing +""" + +from __future__ import annotations + +import asyncio + +import pyarrow as pa +import pytest + +from orcapod.channels import Channel +from orcapod.core.datagrams import Packet +from orcapod.core.function_pod import FunctionPod +from orcapod.core.nodes import ( + FunctionNode, + OperatorNode, +) +from orcapod.core.operators import SelectPacketColumns +from orcapod.core.operators.join import Join +from orcapod.core.operators.semijoin import SemiJoin +from orcapod.core.packet_function import CachedPacketFunction, PythonPacketFunction +from orcapod.core.streams import ArrowTableStream +from orcapod.databases import InMemoryArrowDatabase +from orcapod.types import CacheMode, NodeConfig + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_stream(n: int = 5) -> ArrowTableStream: + table = pa.table( + { + "id": pa.array(list(range(n)), type=pa.int64()), + "x": pa.array(list(range(n)), type=pa.int64()), + } + ) + return ArrowTableStream(table, tag_columns=["id"]) + + +def make_two_col_stream(n: int = 3) -> ArrowTableStream: + table = pa.table( + { + "id": pa.array(list(range(n)), type=pa.int64()), + "x": pa.array(list(range(n)), type=pa.int64()), + "y": pa.array([i * 10 + i for i in range(n)], type=pa.int64()), + } + ) + return ArrowTableStream(table, tag_columns=["id"]) + + +async def feed_stream_to_channel(stream: ArrowTableStream, ch: Channel) -> None: + """Push all (tag, packet) pairs from a stream into a channel, then close.""" + for tag, packet in stream.iter_packets(): + await ch.writer.send((tag, packet)) + await ch.writer.close() + + +def make_double_pod() -> tuple[PythonPacketFunction, FunctionPod]: + def double(x: int) -> int: + return x * 2 + + pf = PythonPacketFunction(double, output_keys="result") + pod = FunctionPod(pf) + return pf, pod + + +# --------------------------------------------------------------------------- +# 1. CachedPacketFunction.async_call +# --------------------------------------------------------------------------- + + +class TestCachedPacketFunctionAsync: + @pytest.mark.asyncio + async def test_async_call_cache_miss_computes_and_records(self): + def double(x: int) -> int: + return x * 2 + + pf = PythonPacketFunction(double, output_keys="result") + db = InMemoryArrowDatabase() + cpf = CachedPacketFunction(pf, result_database=db) + + packet = Packet({"x": 5}) + result, _captured = await cpf.async_call(packet) + + assert result is not None + assert result.as_dict()["result"] == 10 + # Check that result was recorded in DB + cached = cpf.get_cached_output_for_packet(packet) + assert cached is not None + assert cached.as_dict()["result"] == 10 + + @pytest.mark.asyncio + async def test_async_call_cache_hit_returns_cached(self): + def double(x: int) -> int: + return x * 2 + + pf = PythonPacketFunction(double, output_keys="result") + db = InMemoryArrowDatabase() + cpf = CachedPacketFunction(pf, result_database=db) + + packet = Packet({"x": 5}) + # First call — computes + result1, _captured1 = await cpf.async_call(packet) + assert result1 is not None + # Has RESULT_COMPUTED_FLAG + assert result1.get_meta_value(cpf.RESULT_COMPUTED_FLAG, False) is True + + # Second call — should hit cache (no RESULT_COMPUTED_FLAG set to True) + result2, _captured2 = await cpf.async_call(packet) + assert result2 is not None + assert result2.as_dict()["result"] == 10 + # Cache hit should NOT have RESULT_COMPUTED_FLAG=True + # (the flag is only set on freshly computed results) + assert result2.get_meta_value(cpf.RESULT_COMPUTED_FLAG, None) is not True + + @pytest.mark.asyncio + async def test_async_call_skip_cache_lookup(self): + call_count = 0 + + def counting_double(x: int) -> int: + nonlocal call_count + call_count += 1 + return x * 2 + + pf = PythonPacketFunction(counting_double, output_keys="result") + db = InMemoryArrowDatabase() + cpf = CachedPacketFunction(pf, result_database=db) + + packet = Packet({"x": 5}) + _result1, _captured1 = await cpf.async_call(packet) + assert call_count == 1 + + # With skip_cache_lookup, should recompute + _result2, _captured2 = await cpf.async_call(packet, skip_cache_lookup=True) + assert call_count == 2 + + @pytest.mark.asyncio + async def test_async_call_skip_cache_insert(self): + def double(x: int) -> int: + return x * 2 + + pf = PythonPacketFunction(double, output_keys="result") + db = InMemoryArrowDatabase() + cpf = CachedPacketFunction(pf, result_database=db) + + packet = Packet({"x": 5}) + result, _captured = await cpf.async_call(packet, skip_cache_insert=True) + assert result is not None + assert result.as_dict()["result"] == 10 + + # Should NOT be cached + cached = cpf.get_cached_output_for_packet(packet) + assert cached is None + + +# --------------------------------------------------------------------------- +# 3. FunctionNode.async_execute +# --------------------------------------------------------------------------- + + +class TestFunctionNodeAsyncExecute: + @pytest.mark.asyncio + async def test_basic_streaming_matches_sync(self): + _, pod = make_double_pod() + stream = make_stream(5) + + # Sync results + node_sync = FunctionNode(pod, stream) + sync_results = list(node_sync.iter_packets()) + sync_values = sorted(pkt.as_dict()["result"] for _, pkt in sync_results) + + # Async results + node_async = FunctionNode(pod, make_stream(5)) + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(make_stream(5), input_ch) + await node_async.async_execute(input_ch.reader, output_ch.writer) + + async_results = await output_ch.reader.collect() + async_values = sorted(pkt.as_dict()["result"] for _, pkt in async_results) + assert async_values == sync_values + + @pytest.mark.asyncio + async def test_empty_input_closes_cleanly(self): + _, pod = make_double_pod() + node = FunctionNode(pod, make_stream(1)) + + input_ch = Channel(buffer_size=4) + output_ch = Channel(buffer_size=4) + + await input_ch.writer.close() + await node.async_execute(input_ch.reader, output_ch.writer) + + results = await output_ch.reader.collect() + assert results == [] + + @pytest.mark.asyncio + async def test_tags_preserved(self): + """Tags should pass through unchanged.""" + _, pod = make_double_pod() + node = FunctionNode(pod, make_stream(3)) + + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(make_stream(3), input_ch) + await node.async_execute(input_ch.reader, output_ch.writer) + + results = await output_ch.reader.collect() + ids = sorted(tag.as_dict()["id"] for tag, _ in results) + assert ids == [0, 1, 2] + + +# --------------------------------------------------------------------------- +# 4. FunctionNode.async_execute +# --------------------------------------------------------------------------- + + +class TestFunctionNodeAsyncExecute: + @pytest.mark.asyncio + async def test_no_cache_processes_all_inputs(self): + """With an empty DB, all inputs should be computed.""" + pf, pod = make_double_pod() + db = InMemoryArrowDatabase() + stream = make_stream(3) + node = FunctionNode(pod, stream, pipeline_database=db) + + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(make_stream(3), input_ch) + await node.async_execute(input_ch.reader, output_ch.writer) + + results = await output_ch.reader.collect() + assert len(results) == 3 + values = sorted(pkt.as_dict()["result"] for _, pkt in results) + assert values == [0, 2, 4] + + @pytest.mark.asyncio + async def test_sync_run_then_async_emits_from_cache(self): + """After sync run() populates DB, async should emit cached results.""" + pf, pod = make_double_pod() + db = InMemoryArrowDatabase() + stream = make_stream(3) + + # Sync run to populate DB + node1 = FunctionNode(pod, stream, pipeline_database=db) + node1.run() + + # New node with same DB — send same packets, expect cached hits + input_stream = make_stream(3) + node2 = FunctionNode(pod, input_stream, pipeline_database=db) + + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + # Send the same packets that were already cached + for tag, packet in input_stream.iter_packets(): + await input_ch.writer.send((tag, packet)) + await input_ch.writer.close() + + await node2.async_execute(input_ch.reader, output_ch.writer) + + results = await output_ch.reader.collect() + assert len(results) == 3 + values = sorted(pkt.as_dict()["result"] for _, pkt in results) + assert values == [0, 2, 4] + + @pytest.mark.asyncio + async def test_two_phase_cached_and_new(self): + """Phase 1 emits cached; Phase 2 computes new.""" + pf, pod = make_double_pod() + db = InMemoryArrowDatabase() + + # Sync run with 3 items to populate DB + stream = make_stream(3) + node1 = FunctionNode(pod, stream, pipeline_database=db) + node1.run() + + # Now run async with 5 items (3 cached + 2 new) + node2 = FunctionNode(pod, make_stream(5), pipeline_database=db) + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(make_stream(5), input_ch) + await node2.async_execute(input_ch.reader, output_ch.writer) + + results = await output_ch.reader.collect() + values = sorted(pkt.as_dict()["result"] for _, pkt in results) + # 3 from cache + 2 new = 5 total + assert values == [0, 2, 4, 6, 8] + + @pytest.mark.asyncio + async def test_concurrent_execution_with_async_function(self): + """Async packets should run concurrently when max_concurrency > 1.""" + import time + + async def slow_double(x: int) -> int: + await asyncio.sleep(0.2) + return x * 2 + + pf = PythonPacketFunction(slow_double, output_keys="result") + pod = FunctionPod(pf, node_config=NodeConfig(max_concurrency=5)) + db = InMemoryArrowDatabase() + stream = make_stream(5) + node = FunctionNode(pod, stream, pipeline_database=db) + + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(make_stream(5), input_ch) + + t0 = time.perf_counter() + await node.async_execute(input_ch.reader, output_ch.writer) + elapsed = time.perf_counter() - t0 + + results = await output_ch.reader.collect() + assert len(results) == 5 + values = sorted(pkt.as_dict()["result"] for _, pkt in results) + assert values == [0, 2, 4, 6, 8] + + # Concurrency limiting removed in PLT-922 (deferred to PLT-930). + # Packets are now processed sequentially. + assert elapsed >= 0.9, f"Expected sequential execution but took {elapsed:.2f}s" + + @pytest.mark.asyncio + async def test_db_records_created(self): + """Async execute should create pipeline records in the DB.""" + pf, pod = make_double_pod() + db = InMemoryArrowDatabase() + stream = make_stream(3) + node = FunctionNode(pod, stream, pipeline_database=db) + + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(make_stream(3), input_ch) + await node.async_execute(input_ch.reader, output_ch.writer) + await output_ch.reader.collect() + + # Verify records in DB + records = node.get_all_records() + assert records is not None + assert records.num_rows == 3 + + +# --------------------------------------------------------------------------- +# 5. OperatorNode.async_execute +# --------------------------------------------------------------------------- + + +class TestOperatorNodeAsyncExecute: + @pytest.mark.asyncio + async def test_unary_op_delegation(self): + stream = make_two_col_stream(3) + op = SelectPacketColumns(["x"]) + node = OperatorNode(op, [stream]) + + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(make_two_col_stream(3), input_ch) + await node.async_execute([input_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + assert len(results) == 3 + for _, packet in results: + pkt_dict = packet.as_dict() + assert "x" in pkt_dict + assert "y" not in pkt_dict + + @pytest.mark.asyncio + async def test_binary_op_delegation(self): + left = make_stream(5) + right_table = pa.table( + { + "id": pa.array([1, 3], type=pa.int64()), + "z": pa.array([100, 300], type=pa.int64()), + } + ) + right = ArrowTableStream(right_table, tag_columns=["id"]) + + op = SemiJoin() + node = OperatorNode(op, [left, right]) + + left_ch = Channel(buffer_size=16) + right_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(make_stream(5), left_ch) + await feed_stream_to_channel( + ArrowTableStream(right_table, tag_columns=["id"]), right_ch + ) + await node.async_execute([left_ch.reader, right_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + ids = sorted(tag.as_dict()["id"] for tag, _ in results) + assert ids == [1, 3] + + @pytest.mark.asyncio + async def test_nary_op_delegation(self): + left_table = pa.table( + { + "id": pa.array([0, 1, 2], type=pa.int64()), + "x": pa.array([10, 20, 30], type=pa.int64()), + } + ) + right_table = pa.table( + { + "id": pa.array([0, 1, 2], type=pa.int64()), + "y": pa.array([100, 200, 300], type=pa.int64()), + } + ) + left = ArrowTableStream(left_table, tag_columns=["id"]) + right = ArrowTableStream(right_table, tag_columns=["id"]) + op = Join() + node = OperatorNode(op, [left, right]) + + left_ch = Channel(buffer_size=16) + right_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel( + ArrowTableStream(left_table, tag_columns=["id"]), left_ch + ) + await feed_stream_to_channel( + ArrowTableStream(right_table, tag_columns=["id"]), right_ch + ) + await node.async_execute([left_ch.reader, right_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + assert len(results) == 3 + ids = sorted(tag.as_dict()["id"] for tag, _ in results) + assert ids == [0, 1, 2] + + @pytest.mark.asyncio + async def test_results_match_sync(self): + stream = make_two_col_stream(4) + op = SelectPacketColumns(["x"]) + + # Sync + node_sync = OperatorNode(op, [stream]) + node_sync.run() + sync_table = node_sync.as_table() + sync_x = sorted(sync_table.column("x").to_pylist()) + + # Async + node_async = OperatorNode(op, [make_two_col_stream(4)]) + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(make_two_col_stream(4), input_ch) + await node_async.async_execute([input_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + async_x = sorted(pkt.as_dict()["x"] for _, pkt in results) + assert async_x == sync_x + + +# --------------------------------------------------------------------------- +# 6. OperatorNode.async_execute +# --------------------------------------------------------------------------- + + +class TestOperatorNodeAsyncExecute: + @pytest.mark.asyncio + async def test_off_mode_no_db_write(self): + stream = make_two_col_stream(3) + op = SelectPacketColumns(["x"]) + db = InMemoryArrowDatabase() + node = OperatorNode( + op, [stream], pipeline_database=db, cache_mode=CacheMode.OFF + ) + + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(make_two_col_stream(3), input_ch) + await node.async_execute([input_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + assert len(results) == 3 + + # DB should be empty (OFF mode) + records = node.get_all_records() + assert records is None + + @pytest.mark.asyncio + async def test_log_mode_stores_results(self): + stream = make_two_col_stream(3) + op = SelectPacketColumns(["x"]) + db = InMemoryArrowDatabase() + node = OperatorNode( + op, [stream], pipeline_database=db, cache_mode=CacheMode.LOG + ) + + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(make_two_col_stream(3), input_ch) + await node.async_execute([input_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + assert len(results) == 3 + + # DB should have records (LOG mode) + records = node.get_all_records() + assert records is not None + assert records.num_rows == 3 + + @pytest.mark.asyncio + async def test_replay_mode_emits_from_db(self): + stream = make_two_col_stream(3) + op = SelectPacketColumns(["x"]) + db = InMemoryArrowDatabase() + + # First: sync LOG to populate DB + node1 = OperatorNode( + op, [stream], pipeline_database=db, cache_mode=CacheMode.LOG + ) + node1.run() + + # Second: async REPLAY from DB + node2 = OperatorNode( + op, + [make_two_col_stream(3)], + pipeline_database=db, + cache_mode=CacheMode.REPLAY, + ) + + # No input needed for REPLAY — close input immediately + input_ch = Channel(buffer_size=4) + output_ch = Channel(buffer_size=16) + + await input_ch.writer.close() + await node2.async_execute([input_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + assert len(results) == 3 + values = sorted(pkt.as_dict()["x"] for _, pkt in results) + assert values == [0, 1, 2] + + @pytest.mark.asyncio + async def test_replay_empty_db_returns_empty(self): + stream = make_two_col_stream(3) + op = SelectPacketColumns(["x"]) + db = InMemoryArrowDatabase() + + node = OperatorNode( + op, + [stream], + pipeline_database=db, + cache_mode=CacheMode.REPLAY, + ) + + input_ch = Channel(buffer_size=4) + output_ch = Channel(buffer_size=16) + + await input_ch.writer.close() + await node.async_execute([input_ch.reader], output_ch.writer) + + results = await output_ch.reader.collect() + assert len(results) == 0 + + +# --------------------------------------------------------------------------- +# 7. execute_packet routing verification +# --------------------------------------------------------------------------- + + +class TestExecutePacketRouting: + def test_function_node_sequential_uses_execute_packet(self): + """Verify FunctionNode routes through execute_packet (not raw pf.call).""" + call_log = [] + + _, pod = make_double_pod() + stream = make_stream(3) + node = FunctionNode(pod, stream) + + # Monkey-patch to verify routing through internal path + original = node._process_packet_internal + + def patched(tag, packet): + call_log.append("_process_packet_internal") + return original(tag, packet) + + node._process_packet_internal = patched + + results = list(node.iter_packets()) + assert len(results) == 3 + assert len(call_log) == 3 + + @pytest.mark.asyncio + async def test_function_node_async_uses_async_process_packet_internal(self): + """Verify FunctionNode.async_execute routes through _async_process_packet_internal.""" + call_log = [] + + _, pod = make_double_pod() + stream = make_stream(3) + node = FunctionNode(pod, stream) + + original = node._async_process_packet_internal + + async def patched(tag, packet): + call_log.append("_async_process_packet_internal") + return await original(tag, packet) + + node._async_process_packet_internal = patched + + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + await feed_stream_to_channel(make_stream(3), input_ch) + await node.async_execute(input_ch.reader, output_ch.writer) + await output_ch.reader.collect() + + assert len(call_log) == 3 + + +# --------------------------------------------------------------------------- +# 8. End-to-end async pipeline with nodes +# --------------------------------------------------------------------------- + + +class TestEndToEnd: + @pytest.mark.asyncio + async def test_source_to_function_node_pipeline(self): + """Source → FunctionNode async pipeline.""" + + def triple(x: int) -> int: + return x * 3 + + pf = PythonPacketFunction(triple, output_keys="result") + pod = FunctionPod(pf) + stream = make_stream(4) + node = FunctionNode(pod, stream) + + ch1 = Channel(buffer_size=16) + ch2 = Channel(buffer_size=16) + + async def source(): + for tag, packet in make_stream(4).iter_packets(): + await ch1.writer.send((tag, packet)) + await ch1.writer.close() + + async with asyncio.TaskGroup() as tg: + tg.create_task(source()) + tg.create_task(node.async_execute(ch1.reader, ch2.writer)) + + results = await ch2.reader.collect() + assert len(results) == 4 + values = sorted(pkt.as_dict()["result"] for _, pkt in results) + assert values == [0, 3, 6, 9] + + @pytest.mark.asyncio + async def test_source_to_operator_node_pipeline(self): + """Source → OperatorNode (SelectPacketColumns) async pipeline.""" + stream = make_two_col_stream(3) + op = SelectPacketColumns(["x"]) + node = OperatorNode(op, [stream]) + + ch1 = Channel(buffer_size=16) + ch2 = Channel(buffer_size=16) + + async def source(): + for tag, packet in make_two_col_stream(3).iter_packets(): + await ch1.writer.send((tag, packet)) + await ch1.writer.close() + + async with asyncio.TaskGroup() as tg: + tg.create_task(source()) + tg.create_task(node.async_execute([ch1.reader], ch2.writer)) + + results = await ch2.reader.collect() + assert len(results) == 3 + for _, packet in results: + pkt_dict = packet.as_dict() + assert "x" in pkt_dict + assert "y" not in pkt_dict + + +# --------------------------------------------------------------------------- +# 9. Async pipeline → synchronous DB retrieval (concrete example) +# --------------------------------------------------------------------------- + + +class TestAsyncPipelineThenSyncRetrieval: + """Demonstrates the full workflow: run an async pipeline, then retrieve + results synchronously from the database. + + This is the primary use-case for persistent nodes: async streaming + execution populates the DB, and later callers can retrieve results + without re-running the pipeline. + """ + + @pytest.mark.asyncio + async def test_persistent_function_node_async_then_sync_db_retrieval(self): + """FunctionNode: async execute → sync get_all_records.""" + + # --- Setup --- + def double(x: int) -> int: + return x * 2 + + pf = PythonPacketFunction(double, output_keys="result") + pod = FunctionPod(pf) + db = InMemoryArrowDatabase() + stream = make_stream(5) # ids 0..4, x values 0..4 + + node = FunctionNode(pod, stream, pipeline_database=db) + + # --- Async pipeline execution --- + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + async def source_producer(): + for tag, packet in make_stream(5).iter_packets(): + await input_ch.writer.send((tag, packet)) + await input_ch.writer.close() + + async with asyncio.TaskGroup() as tg: + tg.create_task(source_producer()) + tg.create_task(node.async_execute(input_ch.reader, output_ch.writer)) + + async_results = await output_ch.reader.collect() + async_values = sorted(pkt.as_dict()["result"] for _, pkt in async_results) + assert async_values == [0, 2, 4, 6, 8] + + # --- Synchronous DB retrieval (no re-computation) --- + records = node.get_all_records() + assert records is not None + assert records.num_rows == 5 + + # The DB contains the same result values that were streamed async + result_col = records.column("result").to_pylist() + assert sorted(result_col) == [0, 2, 4, 6, 8] + + # A *new* node sharing the same DB can also read these records + node2 = FunctionNode(pod, make_stream(5), pipeline_database=db) + records2 = node2.get_all_records() + assert records2 is not None + assert records2.num_rows == 5 + assert sorted(records2.column("result").to_pylist()) == [0, 2, 4, 6, 8] + + @pytest.mark.asyncio + async def test_persistent_operator_node_log_then_sync_db_retrieval(self): + """OperatorNode (LOG): async execute → sync get_all_records.""" + # --- Setup --- + stream = make_two_col_stream(4) # ids 0..3, x 0..3, y 0,11,22,33 + op = SelectPacketColumns(["x"]) + db = InMemoryArrowDatabase() + + node = OperatorNode( + op, [stream], pipeline_database=db, cache_mode=CacheMode.LOG + ) + + # --- Async pipeline execution --- + input_ch = Channel(buffer_size=16) + output_ch = Channel(buffer_size=16) + + async def source_producer(): + for tag, packet in make_two_col_stream(4).iter_packets(): + await input_ch.writer.send((tag, packet)) + await input_ch.writer.close() + + async with asyncio.TaskGroup() as tg: + tg.create_task(source_producer()) + tg.create_task(node.async_execute([input_ch.reader], output_ch.writer)) + + async_results = await output_ch.reader.collect() + assert len(async_results) == 4 + async_x = sorted(pkt.as_dict()["x"] for _, pkt in async_results) + assert async_x == [0, 1, 2, 3] + + # --- Synchronous DB retrieval --- + records = node.get_all_records() + assert records is not None + assert records.num_rows == 4 + assert sorted(records.column("x").to_pylist()) == [0, 1, 2, 3] + # 'y' column should NOT be present (was dropped by SelectPacketColumns) + assert "y" not in records.column_names + + # --- REPLAY from DB via a new node (no computation) --- + replay_node = OperatorNode( + op, + [make_two_col_stream(4)], + pipeline_database=db, + cache_mode=CacheMode.REPLAY, + ) + replay_node.run() + replay_table = replay_node.as_table() + assert replay_table.num_rows == 4 + assert sorted(replay_table.column("x").to_pylist()) == [0, 1, 2, 3] + + @pytest.mark.asyncio + async def test_multi_stage_async_pipeline_with_db_retrieval(self): + """Two-stage async pipeline: Source → FunctionNode → OperatorNode. + + Both nodes are persistent. After async execution, results from each + stage can be retrieved synchronously from the database. + """ + + # --- Setup stage 1: double(x) --- + def double(x: int) -> int: + return x * 2 + + pf = PythonPacketFunction(double, output_keys="result") + pod = FunctionPod(pf) + fn_db = InMemoryArrowDatabase() + stream = make_stream(3) # ids 0..2, x 0..2 + + fn_node = FunctionNode(pod, stream, pipeline_database=fn_db) + + # --- Setup stage 2: select only "result" column --- + # Build a placeholder stream for schema purposes (OperatorNode needs + # to validate inputs at construction time) + stage1_table = pa.table( + { + "id": pa.array([0, 1, 2], type=pa.int64()), + "result": pa.array([0, 2, 4], type=pa.int64()), + } + ) + stage1_stream = ArrowTableStream(stage1_table, tag_columns=["id"]) + op = SelectPacketColumns(["result"]) + op_db = InMemoryArrowDatabase() + op_node = OperatorNode( + op, [stage1_stream], pipeline_database=op_db, cache_mode=CacheMode.LOG + ) + + # --- Async pipeline execution --- + ch_source = Channel(buffer_size=16) + ch_mid = Channel(buffer_size=16) + ch_out = Channel(buffer_size=16) + + async def source_producer(): + for tag, packet in make_stream(3).iter_packets(): + await ch_source.writer.send((tag, packet)) + await ch_source.writer.close() + + async with asyncio.TaskGroup() as tg: + tg.create_task(source_producer()) + tg.create_task(fn_node.async_execute(ch_source.reader, ch_mid.writer)) + tg.create_task(op_node.async_execute([ch_mid.reader], ch_out.writer)) + + final_results = await ch_out.reader.collect() + assert len(final_results) == 3 + final_values = sorted(pkt.as_dict()["result"] for _, pkt in final_results) + assert final_values == [0, 2, 4] + + # --- Sync retrieval from stage 1 DB --- + fn_records = fn_node.get_all_records() + assert fn_records is not None + assert fn_records.num_rows == 3 + assert sorted(fn_records.column("result").to_pylist()) == [0, 2, 4] + + # --- Sync retrieval from stage 2 DB --- + op_records = op_node.get_all_records() + assert op_records is not None + assert op_records.num_rows == 3 + assert sorted(op_records.column("result").to_pylist()) == [0, 2, 4] diff --git a/tests/test_channels/test_pipeline_async_integration.py b/tests/test_channels/test_pipeline_async_integration.py new file mode 100644 index 00000000..87f31ab5 --- /dev/null +++ b/tests/test_channels/test_pipeline_async_integration.py @@ -0,0 +1,288 @@ +""" +Integration test — end-to-end async pipeline. + +Shows the recommended workflow in a single, linear example: + +1. Define domain functions with ``@function_pod``. +2. Build a pipeline with the ``Pipeline`` context manager. +3. Run the pipeline asynchronously via ``AsyncPipelineOrchestrator``. +4. Retrieve persisted results synchronously from the pipeline's + persistent nodes (``pipeline.