diff --git a/README.md b/README.md index eb4f904..5cec1b8 100644 --- a/README.md +++ b/README.md @@ -69,9 +69,11 @@ class TaxonomyNode: ``` Note that tax_id in parameters passed in functions described below are string but for example in the case of NCBI need -to be essentially quoting integers: `562 -> "562"`. +to be essentially quoting integers: `562 -> "562"`. If you loaded a taxonomy via JSON and you had additional data in your file, you can access it via indexing, `node["readcount"]` for example. +`TaxonomyNode` is a **snapshot** — it reflects the tree state at the time it was fetched. Mutations via `set_data` do not update existing node references; re-fetch the node to see updated data. + #### `tax.clone() -> Taxonomy` Return a new taxonomy, equivalent to a deep copy. @@ -144,10 +146,77 @@ Remove the node from the tree, re-attaching parents as needed: only a single nod Add a new node to the tree at the parent provided. -#### `edit_node(tax_id: str, /, name: str, rank: str, parent_id: str, parent_dist: float)` +#### `tax.edit_node(tax_id: str, /, name: str, rank: str, parent_id: str, parent_dist: float)` Edit properties on a taxonomy node. +### Storing and aggregating data on nodes + +#### `tax.set_data(node_id: str, key: str, value) -> None` + +Store an arbitrary value on a node. Mutates the taxonomy in-place. + +```python +tax.set_data("562", "readcount", 42) +tax["562"]["readcount"] # 42 +``` + +#### `node.get(key: str, default=None)` + +Safe read from a node's data with an optional fallback (mirrors `dict.get`). + +```python +node.get("readcount") # None if absent +node.get("readcount", 0) # 0 if absent +``` + +#### `node.data` + +Returns all data stored on a node as a Python `dict`. + +```python +node.data # e.g. {"readcount": 42} +``` + +#### `tax.reduce_up(node_id: str, output_key: str, fn) -> Taxonomy` + +Post-order (leaves → root) aggregation over the subtree rooted at `node_id`. 
The function `fn(node, child_results) -> result` is called once per node; results are stored under `output_key` and a **new Taxonomy** is returned (original unchanged). + +```python +# Compute inclusive clade read counts +annotated = tax.reduce_up("1", "clade_reads", + lambda node, child_results: node.get("readcount", 0) + sum(child_results)) +annotated["1224"]["clade_reads"] # all reads in Proteobacteria + +# Count detected species per clade +annotated = tax.reduce_up("1", "detected_species", + lambda node, child_results: sum(child_results) + + (1 if node.rank == "species" and node.get("readcount", 0) > 0 else 0)) +``` + +#### `tax.map_down(node_id: str, output_key: str, initial, fn) -> Taxonomy` + +Pre-order (root → leaves) propagation over the subtree rooted at `node_id`. The function `fn(parent_result, node) -> result` is called once per node; the root receives `initial` as its parent result. Results are stored under `output_key` and a **new Taxonomy** is returned. + +```python +# Build full lineage string for every node +annotated = tax.map_down("1", "lineage", "", + lambda parent, node: f"{parent};{node.id}" if parent else node.id) + +# Compute depth of every node +annotated = tax.map_down("1", "depth", 0, + lambda parent_depth, node: parent_depth + 1) +``` + +`reduce_up` and `map_down` are chainable — results stored by one call are visible to the next: + +```python +annotated = tax.reduce_up("1", "clade_reads", + lambda node, child_results: node.get("readcount", 0) + sum(child_results)) +annotated = annotated.map_down("1", "relative_abundance", 1.0, + lambda _, node: node["clade_reads"] / annotated["1"]["clade_reads"]) +``` + #### `internal_index(tax_id: str)` Return internal integer index used by some applications. 
For the JSON node-link diff --git a/docs/aggregation-api.md b/docs/aggregation-api.md new file mode 100644 index 0000000..6ab81cf --- /dev/null +++ b/docs/aggregation-api.md @@ -0,0 +1,226 @@ +# Taxonomy: Functional Tree Operations — API Design + +## Context + +Users need to store arbitrary data alongside taxonomy nodes and perform aggregation and transformation operations across the tree. This is motivated by use cases like computing subtree read counts in metagenomic analysis. + +All operations are implemented in Rust and exposed to Python via PyO3. Operations on the tree return new `Taxonomy` objects (immutable/functional style), consistent with the existing `prune` method. `TaxonomyNode` objects remain independent value objects with no back-reference to the tree. + +______________________________________________________________________ + +## Summary + +| Operation | Traversal | Lambda | Complexity | +|---|---|---|---| +| `reduce_up` | post-order (leaves → root) | `f(node, [child_results]) -> result` | O(n) | +| `map_down` | pre-order (root → leaves) | `f(parent_result, node) -> result` | O(n) | + +n = number of nodes in the subtree rooted at `node_id`. + +______________________________________________________________________ + +## Data Access + +### Reading — existing API + +`TaxonomyNode` already exposes extra data via `__getitem__`. Data is populated from the underlying `data: Vec<HashMap<String, Value>>` field when a node is constructed. + +```python +node = tax["562"] +node["readcount"] # raises KeyError if key absent +node.get("readcount", 0) # returns default if absent — NEW +node.data # full data dict — NEW +``` + +**New methods needed on `TaxonomyNode`:** + +| Method | Complexity | Notes | +|---|---|---| +| `node.get(key, default=None)` | O(1) | safe read with fallback | +| `node.data` | O(d) | returns copy of data as Python dict, d = number of keys | + +**Note:** `TaxonomyNode` is a snapshot — it reflects the tree state at the time it was constructed. 
Calling `set_data` after fetching a node does not update existing node references. + +______________________________________________________________________ + +### Writing — new API + +```python +tax.set_data(node_id: str, key: str, value) -> None +``` + +- Mutates the taxonomy in-place (consistent with `add_node`, `edit_node`) +- **O(1)**: hash map lookup by `node_id`, hash map insert for `key` + +```python +tax.set_data("562", "readcount", 5) +tax["562"]["readcount"] # 5 +``` + +______________________________________________________________________ + +## Aggregation + +### `reduce_up` — Aggregate from leaves to root + +```python +tax.reduce_up(node_id: str, output_key: str, fn: Callable[[TaxonomyNode, List], result]) -> Taxonomy +``` + +- **O(n)** — visits every node in the subtree exactly once +- **Post-order** traversal: leaves visited before parents +- `fn(node, child_results) -> result` + - `node`: the current `TaxonomyNode` + - `child_results`: list of already-computed results from direct children (empty list for leaves) +- Stores result at **every node** under `output_key` +- Returns a **new Taxonomy** (original unchanged) +- No `initial` value — leaves handle the base case via `child_results == []` +- Chainable: results stored by one `reduce_up` are visible on nodes in the next + +Mirrors `functools.reduce` conceptually: reduces the tree bottom-up. + +```python +# Compute inclusive (clade) read counts — equivalent to Kraken's "clade_reads" +annotated = tax.reduce_up("1", "clade_reads", + lambda node, child_results: node.get("readcount", 0) + sum(child_results)) +annotated["562"]["clade_reads"] # all reads in the E. 
coli clade +annotated["1224"]["clade_reads"] # all reads in Proteobacteria + +# Count detected species per clade +tax.reduce_up("1", "detected_species", + lambda node, child_results: sum(child_results) + (1 if node.rank == "species" and node.get("readcount", 0) > 0 else 0)) + +# Compute relative abundance (chained — the second pass reads the "clade_reads" stored by the first) +annotated = tax.reduce_up("1", "clade_reads", + lambda node, child_results: node.get("readcount", 0) + sum(child_results)) +annotated = annotated.reduce_up("1", "relative_abundance", + lambda node, child_results: node.get("clade_reads", 0) / annotated["1"]["clade_reads"]) + +______________________________________________________________________ + +### `map_down` — Propagate values from root to leaves + +```python +tax.map_down(node_id: str, output_key: str, initial, fn: Callable[[parent_result, TaxonomyNode], result]) -> Taxonomy +``` + +- **O(n)** — visits every node in the subtree exactly once +- **Pre-order** traversal: parents visited before children +- `fn(parent_result, node) -> result` + - `parent_result`: result stored at the parent (or `initial` for the root node) + - `node`: the current `TaxonomyNode` +- Stores result at **every node** under `output_key` +- Returns a **new Taxonomy** +- Chainable with `reduce_up` and `map_down` + +Mirrors Python's `map` conceptually: transforms each node using context flowing from its parent. 
+ +```python +# Build full lineage string for every node (QIIME-style taxonomy strings) +annotated = tax.map_down("1", "lineage", "", + lambda parent_lineage, node: f"{parent_lineage};{node.name}" if parent_lineage else node.name) +# annotated["562"]["lineage"] +# → "Bacteria;Proteobacteria;Gammaproteobacteria;Enterobacterales;Enterobacteriaceae;Escherichia;Escherichia coli" + +# Compute depth of every node +annotated = tax.map_down("1", "depth", 0, + lambda parent_depth, node: parent_depth + 1) + +# Propagate cumulative branch length from root (assumes "branch_length" was set via set_data) +annotated = tax.map_down("1", "distance_from_root", 0.0, + lambda parent_dist, node: parent_dist + node["branch_length"]) +``` + +______________________________________________________________________ + +## Performance Notes + +The lambda receives a full `TaxonomyNode` on every call, which currently requires allocating and populating a Python object per node (string copies for `id`, `name`, `rank`, `parent`, plus all data keys). For large trees (e.g. NCBI ~2M nodes) this has meaningful overhead. Two future optimization paths: + +- **Zero-copy node**: pass a borrowed view backed by a pointer into the Rust tree (safe during traversal since the tree is not mutated), avoiding all allocations +- **Built-in Rust-native ops** (`sum`, `count`, `max`, `min`): bypass the lambda entirely for common cases + +Both are deferred until the API is validated. + +______________________________________________________________________ + +## Comparison to NetworkX and ete3 + +This library was written as a replacement for NetworkX for taxonomy use cases. Neither NetworkX nor ete3 has a built-in equivalent of `reduce_up` or `map_down` — both require manual traversal loops. 
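For readers who want to sanity-check the semantics before diving into the comparisons, here is a pure-Python reference sketch of both traversals over a plain dict-based tree. The `children`/`data` dicts and the simplified callback signatures (`fn(node_id, node_data, …)` instead of a `TaxonomyNode`) are illustrative only — the library itself passes node objects and runs the walk in Rust.

```python
# Reference semantics for reduce_up / map_down over a plain dict tree.
# `children`: node id -> list of child ids; `data`: node id -> dict of node data.

def reduce_up(children, data, root, output_key, fn):
    """Post-order: children are computed first, then the parent."""
    def visit(nid):
        child_results = [visit(c) for c in children.get(nid, [])]
        result = fn(nid, data[nid], child_results)
        data[nid][output_key] = result
        return result
    visit(root)
    return data

def map_down(children, data, root, output_key, initial, fn):
    """Pre-order: the parent is computed first, then its result flows down."""
    def visit(nid, parent_result):
        result = fn(parent_result, nid, data[nid])
        data[nid][output_key] = result
        for c in children.get(nid, []):
            visit(c, result)
    visit(root, initial)
    return data

# Tiny tree: root -> {a, b}, a -> {a1, a2}
children = {"root": ["a", "b"], "a": ["a1", "a2"], "b": [], "a1": [], "a2": []}
data = {n: {} for n in children}
data["a1"]["readcount"] = 2
data["a2"]["readcount"] = 3
data["b"]["readcount"] = 5

reduce_up(children, data, "root", "clade_reads",
          lambda nid, d, kids: d.get("readcount", 0) + sum(kids))
map_down(children, data, "root", "depth", 0,
         lambda parent_depth, nid, d: parent_depth + 1)
# data["a"]["clade_reads"] -> 5; data["root"]["clade_reads"] -> 10
# data["root"]["depth"] -> 1; data["a1"]["depth"] -> 3
```

The Rust implementation performs the same walk iteratively over internal indices, but the stored results and visit order should match this sketch.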
+ +### `reduce_up` + +**NetworkX:** + +```python +def reduce_up(G, root, fn): + for node_id in nx.dfs_postorder_nodes(G, root): + child_results = [G.nodes[c]["_result"] for c in G.successors(node_id)] + G.nodes[node_id]["_result"] = fn(G.nodes[node_id], child_results) +``` + +**ete3:** + +```python +for node in tree.traverse("postorder"): + child_results = [c.clade_reads for c in node.children] + node.clade_reads = node.readcount + sum(child_results) +``` + +**taxonomy:** + +```python +annotated = tax.reduce_up("1", "clade_reads", + lambda node, child_results: node.get("readcount", 0) + sum(child_results)) +``` + +______________________________________________________________________ + +### `map_down` + +**NetworkX:** + +```python +def map_down(G, root, initial, fn): + for node_id in nx.dfs_preorder_nodes(G, root): + parents = list(G.predecessors(node_id)) + parent_result = G.nodes[parents[0]]["_result"] if parents else initial + G.nodes[node_id]["_result"] = fn(parent_result, G.nodes[node_id]) +``` + +**ete3:** + +```python +for node in tree.traverse("preorder"): + parent_lineage = node.up.lineage if not node.is_root() else "" + node.lineage = f"{parent_lineage};{node.name}" if parent_lineage else node.name +``` + +**taxonomy:** + +```python +annotated = tax.map_down("1", "lineage", "", + lambda parent_lineage, node: f"{parent_lineage};{node.name}" if parent_lineage else node.name) +``` + +______________________________________________________________________ + +Key differences from NetworkX: + +- NetworkX uses `DiGraph` with dict-style node attributes; this library uses typed `TaxonomyNode` objects with rank, name, and parent built in +- NetworkX has no concept of taxonomic rank, lineage, or LCA — these require manual implementation +- This library is implemented in Rust; NetworkX is pure Python + +Key differences from ete3: + +- ete3 uses attribute access (`node.readcount`); this library uses `node["readcount"]` +- ete3 is pure Python; this library is 
implemented in Rust +- ete3 has richer phylogenetic features (branch support, evolutionary models); this library is optimized for large taxonomic trees (NCBI ~2M nodes) + +______________________________________________________________________ + +## Deferred + +- Built-in Rust-native `sum`, `count`, `max`, `min` (optimization, post-validation) +- `map(output_key, fn)` — transform data values per node without aggregation diff --git a/pyproject.toml b/pyproject.toml index e326eea..75765cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["maturin>=0.14,<0.15"] +requires = ["maturin>=1.0"] build-backend = "maturin" [project] diff --git a/src/base.rs b/src/base.rs index 77c377a..c172323 100644 --- a/src/base.rs +++ b/src/base.rs @@ -292,6 +292,7 @@ impl GeneralTaxonomy { self.parent_distances.remove(idx); self.ranks.remove(idx); self.names.remove(idx); + self.data.remove(idx); // everything after `tax_id` in parents needs to get decremented by 1 // because we've changed the actual array size diff --git a/src/python.rs b/src/python.rs index ff50dc4..d57e583 100644 --- a/src/python.rs +++ b/src/python.rs @@ -62,6 +62,27 @@ fn json_value_to_pyobject(val: &Value) -> PyObject { }) } +fn pyobject_to_json_value(obj: &PyAny) -> PyResult<Value> { + if obj.is_none() { + Ok(Value::Null) + } else if let Ok(b) = obj.extract::<bool>() { + Ok(Value::Bool(b)) + } else if let Ok(i) = obj.extract::<i64>() { + Ok(Value::Number(i.into())) + } else if let Ok(f) = obj.extract::<f64>() { + let n = serde_json::Number::from_f64(f).ok_or_else(|| { + PyErr::new::<PyValueError, _>("Cannot convert non-finite float to JSON") + })?; + Ok(Value::Number(n)) + } else if let Ok(s) = obj.extract::<String>() { + Ok(Value::String(s)) + } else { + Err(PyErr::new::<PyTypeError, _>( + "Cannot convert Python object to JSON value: unsupported type", + )) + } +} + /// The data returned when looking up a taxonomy by id or by name #[pyclass] #[derive(Debug, Clone, Eq, PartialEq)] @@ -124,6 +145,26 @@ impl TaxonomyNode { self.id, 
self.rank, self.name )) } + + #[pyo3(signature = (key, default=None))] + fn get(&self, key: &str, default: Option<&PyAny>, py: Python<'_>) -> PyResult<PyObject> { + if let Some(val) = self.extra.get(key) { + Ok(json_value_to_pyobject(val)) + } else { + Ok(default + .map(|d| d.to_object(py)) + .unwrap_or_else(|| py.None())) + } + } + + #[getter] + fn data(&self, py: Python<'_>) -> PyResult<PyObject> { + let d = PyDict::new(py); + for (k, v) in &self.extra { + d.set_item(k, json_value_to_pyobject(v))?; + } + Ok(d.to_object(py)) + } } /// The Taxonomy object provides the primary interface for exploring a @@ -524,6 +565,89 @@ impl Taxonomy { Ok(()) } + fn set_data(&mut self, node_id: &str, key: &str, value: &PyAny) -> PyResult<()> { + let idx = py_try!(self.tax.to_internal_index(node_id)); + let json_val = pyobject_to_json_value(value)?; + self.tax.data[idx].insert(key.to_string(), json_val); + Ok(()) + } + + fn reduce_up( + &self, + node_id: &str, + output_key: &str, + fn_: &PyAny, + py: Python<'_>, + ) -> PyResult<Taxonomy> { + let start_idx = py_try!(self.tax.to_internal_index(node_id)); + let mut results: HashMap<InternalIndex, PyObject> = HashMap::new(); + let traversal: Vec<(InternalIndex, bool)> = py_try!( + TaxonomyTrait::<InternalIndex>::traverse(&self.tax, start_idx) + ) + .collect(); + let mut new_tax = self.tax.clone(); + + for (idx, is_pre) in traversal { + if is_pre { + continue; + } + let tax_id = py_try!(self.tax.from_internal_index(idx)); + let node = self.as_node(tax_id)?; + let child_list = PyList::empty(py); + for &child_idx in &self.tax.children_lookup[idx] { + if let Some(child_result) = results.get(&child_idx) { + child_list.append(child_result)?; + } + } + let result: PyObject = fn_.call1((node.into_py(py), child_list))?.to_object(py); + let json_val = pyobject_to_json_value(result.as_ref(py))?; + new_tax.data[idx].insert(output_key.to_string(), json_val); + results.insert(idx, result); + } + + Ok(Taxonomy { tax: new_tax }) + } + + fn map_down( + &self, + node_id: &str, + output_key: &str, + initial: &PyAny, 
fn_: &PyAny, + py: Python<'_>, + ) -> PyResult<Taxonomy> { + let start_idx = py_try!(self.tax.to_internal_index(node_id)); + let mut parent_results: HashMap<InternalIndex, PyObject> = HashMap::new(); + let traversal: Vec<(InternalIndex, bool)> = py_try!( + TaxonomyTrait::<InternalIndex>::traverse(&self.tax, start_idx) + ) + .collect(); + let mut new_tax = self.tax.clone(); + + for (idx, is_pre) in traversal { + if !is_pre { + continue; + } + let tax_id = py_try!(self.tax.from_internal_index(idx)); + let node = self.as_node(tax_id)?; + let parent_result: &PyAny = if idx == start_idx { + initial + } else { + let pidx = self.tax.parent_ids[idx]; + parent_results + .get(&pidx) + .map(|r| r.as_ref(py)) + .unwrap_or(initial) + }; + let result: PyObject = fn_.call1((parent_result, node.into_py(py)))?.to_object(py); + let json_val = pyobject_to_json_value(result.as_ref(py))?; + new_tax.data[idx].insert(output_key.to_string(), json_val); + parent_results.insert(idx, result); + } + + Ok(Taxonomy { tax: new_tax }) + } + #[getter] fn root(&self) -> TaxonomyNode { let key: &str = self.tax.root(); diff --git a/taxonomy.pyi b/taxonomy.pyi index 91eb779..f3fc461 100644 --- a/taxonomy.pyi +++ b/taxonomy.pyi @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, Tuple, Iterator +from typing import Any, Callable, List, Optional, Tuple, Iterator class TaxonomyError(Exception): """Raised when an error occurs in the taxonomy library.""" @@ -18,6 +18,9 @@ class TaxonomyNode: def __getitem__(self, key: str) -> Any: ... def __eq__(self, other: object) -> bool: ... def __ne__(self, other: object) -> bool: ... + def get(self, key: str, default: Any = None) -> Any: ... + @property + def data(self) -> dict: ... class Taxonomy: """ @@ -167,6 +170,40 @@ class Taxonomy: """Edit properties on a taxonomy node.""" ... + def set_data(self, node_id: str, key: str, value: Any) -> None: + """Store an arbitrary value on a node. Mutates the taxonomy in-place.""" + ... 
+ + def reduce_up( + self, + node_id: str, + output_key: str, + fn: Callable[[Any, List[Any]], Any], + ) -> "Taxonomy": + """ + Post-order (leaves-to-root) aggregation over the subtree rooted at node_id. + + fn(node, child_results) -> result is called for every node. + Results are stored under output_key and a new Taxonomy is returned. + """ + ... + + def map_down( + self, + node_id: str, + output_key: str, + initial: Any, + fn: Callable[[Any, Any], Any], + ) -> "Taxonomy": + """ + Pre-order (root-to-leaves) propagation over the subtree rooted at node_id. + + fn(parent_result, node) -> result is called for every node. + The root receives initial as parent_result. + Results are stored under output_key and a new Taxonomy is returned. + """ + ... + def __repr__(self) -> str: ... def __len__(self) -> int: ... def __getitem__(self, tax_id: str) -> TaxonomyNode: ... diff --git a/test_python_aggregation.py b/test_python_aggregation.py new file mode 100644 index 0000000..55d7798 --- /dev/null +++ b/test_python_aggregation.py @@ -0,0 +1,247 @@ +"""Tests for taxonomy data storage and tree aggregation functions. + +Taxonomy: Cetacea, as catalogued by Ishmael in Moby Dick. +"Sightings" data represents whale encounters recorded in the novel. 
+ +Tree structure: + Cetacea + ├── Mysticeti (baleen whales) + │ ├── Eubalaena_glacialis (right whale, 5 sightings) + │ └── Balaenopteridae + │ ├── Balaenoptera_musculus (blue whale, 2 sightings) + │ └── Megaptera_novaeangliae (humpback whale, 3 sightings) + └── Odontoceti (toothed whales) + ├── Physeter_macrocephalus (sperm whale, 90 sightings) + └── Kogia_breviceps (pygmy sperm, 1 sighting) +""" + +import pytest +from taxonomy import Taxonomy + +WHALE_NEWICK = ( + "(" + "(Eubalaena_glacialis:1," + "(Balaenoptera_musculus:1,Megaptera_novaeangliae:1)Balaenopteridae:1" + ")Mysticeti:1," + "(Physeter_macrocephalus:1,Kogia_breviceps:1)Odontoceti:1" + ")Cetacea;" +) + +SPECIES_SIGHTINGS = { + "Physeter_macrocephalus": 90, # the white whale — dominates the novel + "Eubalaena_glacialis": 5, + "Balaenoptera_musculus": 2, + "Megaptera_novaeangliae": 3, + "Kogia_breviceps": 1, +} + +TOTAL_SIGHTINGS = sum(SPECIES_SIGHTINGS.values()) # 101 + + +@pytest.fixture +def whale_tax(): + """Whale taxonomy with sighting counts from Moby Dick.""" + tax = Taxonomy.from_newick(WHALE_NEWICK) + for node_id in SPECIES_SIGHTINGS: + tax.edit_node(node_id, rank="species") + for node_id, count in SPECIES_SIGHTINGS.items(): + tax.set_data(node_id, "sightings", count) + return tax + + +# ── set_data ───────────────────────────────────────────────────────────────── + + +def test_set_data(whale_tax): + assert whale_tax["Physeter_macrocephalus"]["sightings"] == 90 + + +def test_set_data_overwrite(whale_tax): + whale_tax.set_data("Physeter_macrocephalus", "sightings", 999) + assert whale_tax["Physeter_macrocephalus"]["sightings"] == 999 + + +def test_set_data_new_key(whale_tax): + whale_tax.set_data("Physeter_macrocephalus", "chapters_mentioned", 135) + assert whale_tax["Physeter_macrocephalus"]["chapters_mentioned"] == 135 + + +# ── node.get ───────────────────────────────────────────────────────────────── + + +def test_node_get_existing_key(whale_tax): + assert 
whale_tax["Physeter_macrocephalus"].get("sightings") == 90 + + +def test_node_get_missing_key_returns_none(whale_tax): + # Internal nodes have no sightings set + assert whale_tax["Mysticeti"].get("sightings") is None + + +def test_node_get_missing_key_custom_default(whale_tax): + assert whale_tax["Mysticeti"].get("sightings", 0) == 0 + + +# ── node.data ───────────────────────────────────────────────────────────────── + + +def test_node_data_returns_dict(whale_tax): + data = whale_tax["Physeter_macrocephalus"].data + assert isinstance(data, dict) + assert data["sightings"] == 90 + + +def test_node_data_empty_for_internal_node(whale_tax): + assert whale_tax["Mysticeti"].data == {} + + +def test_node_data_is_snapshot(whale_tax): + """TaxonomyNode is a snapshot — set_data does not update existing references.""" + node = whale_tax["Physeter_macrocephalus"] + whale_tax.set_data("Physeter_macrocephalus", "sightings", 999) + assert node["sightings"] == 90 # old snapshot + assert whale_tax["Physeter_macrocephalus"]["sightings"] == 999 # re-fetched + + +# ── reduce_up ───────────────────────────────────────────────────────────────── + + +def clade_sightings(node, child_results): + return node.get("sightings", 0) + sum(child_results) + + +def test_reduce_up_leaf(whale_tax): + """Leaf clade sightings equals its own sightings.""" + annotated = whale_tax.reduce_up( + "Physeter_macrocephalus", "clade_sightings", clade_sightings + ) + assert annotated["Physeter_macrocephalus"]["clade_sightings"] == 90 + + +def test_reduce_up_internal_node(whale_tax): + """Odontoceti clade = sperm whale + pygmy sperm whale.""" + annotated = whale_tax.reduce_up( + "Cetacea", "clade_sightings", clade_sightings + ) + assert annotated["Odontoceti"]["clade_sightings"] == 90 + 1 + + +def test_reduce_up_mysticeti(whale_tax): + """Mysticeti clade = right whale + blue whale + humpback.""" + annotated = whale_tax.reduce_up( + "Cetacea", "clade_sightings", clade_sightings + ) + assert 
annotated["Mysticeti"]["clade_sightings"] == 5 + 2 + 3 + + +def test_reduce_up_root(whale_tax): + """Root clade sightings equals total across all species.""" + annotated = whale_tax.reduce_up( + "Cetacea", "clade_sightings", clade_sightings + ) + assert annotated["Cetacea"]["clade_sightings"] == TOTAL_SIGHTINGS + + +def test_reduce_up_preserves_original(whale_tax): + """reduce_up returns a new taxonomy; original is unchanged.""" + whale_tax.reduce_up("Cetacea", "clade_sightings", clade_sightings) + assert whale_tax["Odontoceti"].get("clade_sightings") is None + + +def test_reduce_up_count_species(whale_tax): + """Count species with any sightings per clade.""" + annotated = whale_tax.reduce_up( + "Cetacea", + "detected_species", + lambda node, child_results: sum(child_results) + + (1 if node.rank == "species" and node.get("sightings", 0) > 0 else 0), + ) + assert annotated["Mysticeti"]["detected_species"] == 3 + assert annotated["Odontoceti"]["detected_species"] == 2 + assert annotated["Cetacea"]["detected_species"] == 5 + + +def test_reduce_up_max_subclade(whale_tax): + """Max clade sightings among all subclades — uses child_results explicitly.""" + annotated = whale_tax.reduce_up( + "Cetacea", "clade_sightings", clade_sightings + ) + annotated = annotated.reduce_up( + "Cetacea", + "dominant_subclade_sightings", + lambda node, child_results: max(child_results) + if child_results + else node.get("clade_sightings", 0), + ) + # Physeter (90) is the dominant individual species; that propagates up through Odontoceti + assert annotated["Cetacea"]["dominant_subclade_sightings"] == 90 + + +# ── map_down ────────────────────────────────────────────────────────────────── + + +def test_map_down_depth_root(whale_tax): + annotated = whale_tax.map_down( + "Cetacea", "depth", 0, lambda parent_depth, node: parent_depth + 1 + ) + assert annotated["Cetacea"]["depth"] == 1 + + +def test_map_down_depth_internal(whale_tax): + annotated = whale_tax.map_down( + "Cetacea", "depth", 0, 
lambda parent_depth, node: parent_depth + 1 + ) + assert annotated["Odontoceti"]["depth"] == 2 + assert annotated["Balaenopteridae"]["depth"] == 3 + + +def test_map_down_depth_leaf(whale_tax): + annotated = whale_tax.map_down( + "Cetacea", "depth", 0, lambda parent_depth, node: parent_depth + 1 + ) + assert annotated["Physeter_macrocephalus"]["depth"] == 3 + assert annotated["Balaenoptera_musculus"]["depth"] == 4 + + +def test_map_down_lineage(whale_tax): + annotated = whale_tax.map_down( + "Cetacea", + "lineage", + "", + lambda parent, node: f"{parent};{node.id}" if parent else node.id, + ) + assert annotated["Physeter_macrocephalus"]["lineage"] == ( + "Cetacea;Odontoceti;Physeter_macrocephalus" + ) + assert annotated["Balaenoptera_musculus"]["lineage"] == ( + "Cetacea;Mysticeti;Balaenopteridae;Balaenoptera_musculus" + ) + + +def test_map_down_preserves_original(whale_tax): + """map_down returns a new taxonomy; original is unchanged.""" + whale_tax.map_down("Cetacea", "depth", 0, lambda d, node: d + 1) + assert whale_tax["Cetacea"].get("depth") is None + + +# ── chaining ────────────────────────────────────────────────────────────────── + + +def test_chain_reduce_up_then_map_down(whale_tax): + """ + Compute clade sightings bottom-up, then propagate clade fraction top-down. + Each node stores what fraction of the root's sightings belong to its clade. + """ + annotated = whale_tax.reduce_up( + "Cetacea", "clade_sightings", clade_sightings + ) + annotated = annotated.map_down( + "Cetacea", + "clade_fraction", + 1.0, + lambda _, node: node["clade_sightings"] / TOTAL_SIGHTINGS, + ) + + assert annotated["Cetacea"]["clade_fraction"] == pytest.approx(1.0) + assert annotated["Odontoceti"]["clade_fraction"] == pytest.approx(91 / 101) + assert annotated["Mysticeti"]["clade_fraction"] == pytest.approx(10 / 101)