From fff7ca5962c88fc4a24d5c640cb1699142fba0e1 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Thu, 26 Mar 2026 10:43:27 +0100 Subject: [PATCH 01/13] Refactor DAG internals and lazy-load networkx --- docs/source/tutorials/visualizing_the_dag.md | 7 +- pyproject.toml | 6 +- src/_pytask/dag.py | 30 +-- src/_pytask/dag_command.py | 26 ++- src/_pytask/dag_graph.py | 223 +++++++++++++++++++ src/_pytask/dag_utils.py | 41 ++-- src/_pytask/execute.py | 23 +- src/_pytask/mark/__init__.py | 9 +- src/_pytask/session.py | 4 +- src/_pytask/shared.py | 4 +- tests/test_dag_command.py | 72 +++++- tests/test_dag_utils.py | 4 +- uv.lock | 25 ++- 13 files changed, 406 insertions(+), 68 deletions(-) create mode 100644 src/_pytask/dag_graph.py diff --git a/docs/source/tutorials/visualizing_the_dag.md b/docs/source/tutorials/visualizing_the_dag.md index f3554f3f..297db87e 100644 --- a/docs/source/tutorials/visualizing_the_dag.md +++ b/docs/source/tutorials/visualizing_the_dag.md @@ -1,11 +1,12 @@ # Visualizing the DAG To visualize the [DAG](../glossary.md#dag) of the project, first, install -[pygraphviz](https://github.com/pygraphviz/pygraphviz) and -[graphviz](https://graphviz.org/). For example, you can both install with pixi +[networkx](https://networkx.org/), +[pygraphviz](https://github.com/pygraphviz/pygraphviz), and +[graphviz](https://graphviz.org/). For example, you can install them with pixi ```console -$ pixi add pygraphviz graphviz +$ pixi add networkx pygraphviz graphviz ``` After that, pytask offers two interfaces to visualize your project's `DAG`. diff --git a/pyproject.toml b/pyproject.toml index 88c9f812..4906bc23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,6 @@ dependencies = [ "click>=8.1.8,!=8.2.0", "click-default-group>=1.2.4", "msgspec>=0.18.6", - "networkx>=2.4.0", "optree>=0.9.0", "packaging>=23.0.0", "pluggy>=1.3.0", @@ -36,6 +35,9 @@ dependencies = [ "universal-pathlib>=0.2.2", ] +[project.optional-dependencies] +dag = ["networkx>=2.4.0"] + [project.readme] file = "README.md" content-type = "text/markdown" @@ -54,6 +56,7 @@ docs = [ "ipywidgets>=8.1.6", "matplotlib>=3.5.0", "mkdocstrings[python]>=0.30.0", + "networkx>=2.4.0", "zensical>=0.0.23", ] docs-live = ["sphinx-autobuild>=2024.10.3"] @@ -71,6 +74,7 @@ test = [ "syrupy>=4.5.0", "aiohttp>=3.11.0", # For HTTPPath tests. "coiled>=1.42.0; python_version < '3.14'", + "networkx>=2.4.0", "pygraphviz>=1.12;platform_system=='Linux'", ] typing = ["ty>=0.0.8"] diff --git a/src/_pytask/dag.py b/src/_pytask/dag.py index 538462d8..7d6f19b9 100644 --- a/src/_pytask/dag.py +++ b/src/_pytask/dag.py @@ -6,7 +6,6 @@ import sys from typing import TYPE_CHECKING -import networkx as nx from rich.text import Text from rich.tree import Tree @@ -17,6 +16,9 @@ from _pytask.console import format_node_name from _pytask.console import format_task_name from _pytask.console import render_to_string +from _pytask.dag_graph import DiGraph +from _pytask.dag_graph import NoCycleError +from _pytask.dag_graph import find_cycle from _pytask.exceptions import ResolvingDependenciesError from _pytask.mark import select_by_after_keyword from _pytask.mark import select_tasks_by_marks_and_expressions @@ -37,7 +39,7 @@ __all__ = ["create_dag", "create_dag_from_session"] -def create_dag(session: Session) -> nx.DiGraph: +def create_dag(session: Session) -> DiGraph: """Create a directed acyclic graph (DAG) for the workflow.""" try: dag = create_dag_from_session(session) @@ -50,7 +52,7 @@ def create_dag(session: Session) -> nx.DiGraph: return dag -def create_dag_from_session(session: Session) -> nx.DiGraph: +def create_dag_from_session(session: Session) -> DiGraph: """Create a DAG from a session.""" dag = _create_dag_from_tasks(tasks=session.tasks) _check_if_dag_has_cycles(dag) @@ -60,11 +62,11 @@ def create_dag_from_session(session: Session) -> nx.DiGraph: return dag -def _create_dag_from_tasks(tasks: list[PTask]) -> nx.DiGraph: +def _create_dag_from_tasks(tasks: list[PTask]) -> DiGraph: """Create the DAG from tasks, dependencies and products.""" def _add_dependency( - dag: nx.DiGraph, task: PTask, node: PNode | PProvisionalNode + dag: DiGraph, task: PTask, node: PNode | PProvisionalNode ) -> None: """Add a dependency to the DAG.""" dag.add_node(node.signature, node=node) @@ -76,14 +78,12 @@ def _add_dependency( if isinstance(node, PythonNode) and isinstance(node.value, PythonNode): dag.add_edge(node.value.signature, node.signature) - def _add_product( - dag: nx.DiGraph, task: PTask, node: PNode | PProvisionalNode - ) -> None: + def _add_product(dag: DiGraph, task: PTask, node: PNode | PProvisionalNode) -> None: """Add a product to the DAG.""" dag.add_node(node.signature, node=node) dag.add_edge(task.signature, node.signature) - dag = nx.DiGraph() + dag = DiGraph() for task in tasks: dag.add_node(task.signature, task=task) @@ -105,7 +105,7 @@ def _add_product( return dag -def _modify_dag(session: Session, dag: nx.DiGraph) -> nx.DiGraph: +def _modify_dag(session: Session, dag: DiGraph) -> DiGraph: """Create dependencies between tasks when using ``@task(after=...)``.""" temporary_id_to_task = { task.attributes["collection_id"]: task @@ -129,11 +129,11 @@ def _modify_dag(session: Session, dag: nx.DiGraph) -> nx.DiGraph: return dag -def _check_if_dag_has_cycles(dag: nx.DiGraph) -> None: +def _check_if_dag_has_cycles(dag: DiGraph) -> None: """Check if DAG has cycles.""" try: - cycles = nx.algorithms.cycles.find_cycle(dag) - except nx.NetworkXNoCycle: + cycles = find_cycle(dag) + except NoCycleError: pass else: msg = ( @@ -145,7 +145,7 @@ def _check_if_dag_has_cycles(dag: nx.DiGraph) -> None: raise ResolvingDependenciesError(msg) -def _format_cycles(dag: nx.DiGraph, cycles: list[tuple[str, ...]]) -> str: +def _format_cycles(dag: DiGraph, cycles: list[tuple[str, str]]) -> str: """Format cycles as a paths connected by arrows.""" chain = [ x for i, x in enumerate(itertools.chain.from_iterable(cycles)) if i % 2 == 0 @@ -176,7 +176,7 @@ def _format_dictionary_to_tree(dict_: dict[str, list[str]], title: str) -> str: return render_to_string(tree, console=console, strip_styles=True) -def _check_if_tasks_have_the_same_products(dag: nx.DiGraph, paths: list[Path]) -> None: +def _check_if_tasks_have_the_same_products(dag: DiGraph, paths: list[Path]) -> None: nodes_created_by_multiple_tasks = [] for node in dag.nodes: diff --git a/src/_pytask/dag_command.py b/src/_pytask/dag_command.py index f7ed3c6f..e0497536 100644 --- a/src/_pytask/dag_command.py +++ b/src/_pytask/dag_command.py @@ -5,11 +5,11 @@ import enum import sys from pathlib import Path +from typing import TYPE_CHECKING from typing import Any from typing import cast import click -import networkx as nx from rich.text import Text from _pytask.click import ColoredCommand @@ -30,6 +30,11 @@ from _pytask.shared import reduce_names_of_multiple_nodes from _pytask.traceback import Traceback +if TYPE_CHECKING: + import networkx as nx + + from _pytask.dag_graph import DiGraph + class _RankDirection(enum.Enum): TB = "TB" @@ -92,6 +97,7 @@ def dag(**raw_config: Any) -> int: else: try: session.hook.pytask_log_session_header(session=session) + import_optional_dependency("networkx") import_optional_dependency("pygraphviz") check_for_optional_program( session.config["layout"], @@ -100,7 +106,7 @@ def dag(**raw_config: Any) -> int: ) session.hook.pytask_collect(session=session) session.dag = create_dag(session=session) - dag = _refine_dag(session) + dag = _refine_dag(session).to_networkx() _write_graph(dag, session.config["output_path"], session.config["layout"]) except CollectionError: # pragma: no cover @@ -163,6 +169,7 @@ def build_dag(raw_config: dict[str, Any]) -> nx.DiGraph: else: session.hook.pytask_log_session_header(session=session) + import_optional_dependency("networkx") import_optional_dependency("pygraphviz") check_for_optional_program( session.config["layout"], @@ -172,10 +179,10 @@ def build_dag(raw_config: dict[str, Any]) -> nx.DiGraph: session.hook.pytask_collect(session=session) session.dag = create_dag(session=session) session.hook.pytask_unconfigure(session=session) - return _refine_dag(session) + return _refine_dag(session).to_networkx() -def _refine_dag(session: Session) -> nx.DiGraph: +def _refine_dag(session: Session) -> DiGraph: """Refine the dag for plotting.""" dag = _shorten_node_labels(session.dag, session.config["paths"]) dag = _clean_dag(dag) @@ -185,31 +192,32 @@ def _refine_dag(session: Session) -> nx.DiGraph: return dag -def _shorten_node_labels(dag: nx.DiGraph, paths: list[Path]) -> nx.DiGraph: +def _shorten_node_labels(dag: DiGraph, paths: list[Path]) -> DiGraph: """Shorten the node labels in the graph for a better experience.""" node_names = dag.nodes short_names = reduce_names_of_multiple_nodes(node_names, dag, paths) short_names = [i.plain if isinstance(i, Text) else i for i in short_names] old_to_new = dict(zip(node_names, short_names, strict=False)) - return nx.relabel_nodes(dag, old_to_new) + return dag.relabel_nodes(old_to_new) -def _clean_dag(dag: nx.DiGraph) -> nx.DiGraph: +def _clean_dag(dag: DiGraph) -> DiGraph: """Clean the DAG.""" for node in dag.nodes: dag.nodes[node].clear() return dag -def _style_dag(dag: nx.DiGraph) -> nx.DiGraph: +def _style_dag(dag: DiGraph) -> DiGraph: """Style the DAG.""" shapes = {name: "hexagon" if "::task_" in name else "box" for name in dag.nodes} - nx.set_node_attributes(dag, shapes, "shape") + dag.set_node_attributes(shapes, "shape") return dag def _write_graph(dag: nx.DiGraph, path: Path, layout: str) -> None: """Write the graph to disk.""" + nx = cast("Any", import_optional_dependency("networkx")) path.parent.mkdir(exist_ok=True, parents=True) graph = nx.nx_agraph.to_agraph(dag) graph.draw(path, prog=layout) diff --git a/src/_pytask/dag_graph.py b/src/_pytask/dag_graph.py new file mode 100644 index 00000000..8fa3d8c6 --- /dev/null +++ b/src/_pytask/dag_graph.py @@ -0,0 +1,223 @@ +"""Internal DAG implementation used by pytask.""" + +from __future__ import annotations + +import itertools +from typing import TYPE_CHECKING +from typing import Any +from typing import cast + +from _pytask.compat import import_optional_dependency + +if TYPE_CHECKING: + from collections.abc import Iterator + + +class NoCycleError(Exception): + """Raised when no cycle is found in a graph.""" + + +class NodeView: + """A minimal mapping-like view over node attributes.""" + + def __init__(self, node_attributes: dict[str, dict[str, Any]]) -> None: + self._node_attributes = node_attributes + + def __getitem__(self, node: str) -> dict[str, Any]: + return self._node_attributes[node] + + def __iter__(self) -> Iterator[str]: + return iter(self._node_attributes) + + def __len__(self) -> int: + return len(self._node_attributes) + + def __contains__(self, node: object) -> bool: + return node in self._node_attributes + + +class UndirectedGraph: + """A minimal undirected graph used for validation tests.""" + + def __init__( + self, + node_attributes: dict[str, dict[str, Any]], + adjacency: dict[str, dict[str, None]], + graph_attributes: dict[str, Any], + ) -> None: + self._node_attributes = { + node: attributes.copy() for node, attributes in node_attributes.items() + } + self._adjacency = { + node: neighbors.copy() for node, neighbors in adjacency.items() + } + self.graph = graph_attributes.copy() + self.nodes = NodeView(self._node_attributes) + + def is_directed(self) -> bool: + return False + + +class DiGraph: + """A minimal directed graph tailored to pytask's needs.""" + + def __init__(self) -> None: + self._node_attributes: dict[str, dict[str, Any]] = {} + self._successors: dict[str, dict[str, None]] = {} + self._predecessors: dict[str, dict[str, None]] = {} + self.graph: dict[str, Any] = {} + self.nodes = NodeView(self._node_attributes) + + def add_node(self, node_name: str, **attributes: Any) -> None: + if node_name not in self._node_attributes: + self._node_attributes[node_name] = {} + self._successors[node_name] = {} + self._predecessors[node_name] = {} + self._node_attributes[node_name].update(attributes) + + def add_edge(self, source: str, target: str) -> None: + self.add_node(source) + self.add_node(target) + self._successors[source][target] = None + self._predecessors[target][source] = None + + def successors(self, node: str) -> Iterator[str]: + return iter(self._successors[node]) + + def predecessors(self, node: str) -> Iterator[str]: + return iter(self._predecessors[node]) + + def in_degree(self) -> Iterator[tuple[str, int]]: + for node, predecessors_ in self._predecessors.items(): + yield node, len(predecessors_) + + def remove_nodes_from(self, nodes: list[str] | set[str] | tuple[str, ...]) -> None: + for node in nodes: + if node not in self._node_attributes: + continue + for predecessor in tuple(self._predecessors[node]): + self._successors[predecessor].pop(node, None) + for successor in tuple(self._successors[node]): + self._predecessors[successor].pop(node, None) + del self._node_attributes[node] + del self._successors[node] + del self._predecessors[node] + + def is_directed(self) -> bool: + return True + + def reverse(self) -> DiGraph: + graph = DiGraph() + graph.graph = self.graph.copy() + for node, attributes in self._node_attributes.items(): + graph.add_node(node, **attributes.copy()) + for source, successors in self._successors.items(): + for target in successors: + graph.add_edge(target, source) + return graph + + def relabel_nodes(self, mapping: dict[str, str]) -> DiGraph: + graph = DiGraph() + graph.graph = self.graph.copy() + + new_labels = [mapping.get(node, node) for node in self._node_attributes] + if len(new_labels) != len(set(new_labels)): + msg = "Relabeling nodes requires unique target labels." + raise ValueError(msg) + + for node, attributes in self._node_attributes.items(): + graph.add_node(mapping.get(node, node), **attributes.copy()) + for source, successors in self._successors.items(): + new_source = mapping.get(source, source) + for target in successors: + graph.add_edge(new_source, mapping.get(target, target)) + return graph + + def set_node_attributes(self, values: dict[str, Any], name: str) -> None: + for node, value in values.items(): + if node in self._node_attributes: + self._node_attributes[node][name] = value + + def to_undirected(self) -> UndirectedGraph: + adjacency = { + node: { + **self._predecessors[node], + **self._successors[node], + } + for node in self._node_attributes + } + return UndirectedGraph(self._node_attributes, adjacency, self.graph) + + def to_networkx(self) -> Any: + nx = cast("Any", import_optional_dependency("networkx")) + graph = nx.DiGraph() + graph.graph = self.graph.copy() + for node, attributes in self._node_attributes.items(): + graph.add_node(node, **attributes.copy()) + for source, successors in self._successors.items(): + for target in successors: + graph.add_edge(source, target) + return graph + + +def descendants(dag: DiGraph, node: str) -> set[str]: + """Return all descendants of a node.""" + return _traverse(dag, node, dag.successors) + + +def ancestors(dag: DiGraph, node: str) -> set[str]: + """Return all ancestors of a node.""" + return _traverse(dag, node, dag.predecessors) + + +def _traverse( + _dag: DiGraph, + node: str, + adjacency: Any, +) -> set[str]: + visited: set[str] = set() + stack = list(adjacency(node)) + + while stack: + current = stack.pop() + if current in visited: + continue + visited.add(current) + stack.extend(adjacency(current)) + + return visited + + +def find_cycle(dag: DiGraph) -> list[tuple[str, str]]: + """Find one cycle in the graph.""" + visited: set[str] = set() + active: set[str] = set() + path: list[str] = [] + + def _visit(node: str) -> list[tuple[str, str]] | None: + visited.add(node) + active.add(node) + path.append(node) + + for successor in dag.successors(node): + if successor not in visited: + cycle = _visit(successor) + if cycle is not None: + return cycle + elif successor in active: + start = path.index(successor) + cycle_nodes = [*path[start:], successor] + return list(itertools.pairwise(cycle_nodes)) + + active.remove(node) + path.pop() + return None + + for node in dag.nodes: + if node in visited: + continue + cycle = _visit(node) + if cycle is not None: + return cycle + + raise NoCycleError diff --git a/src/_pytask/dag_utils.py b/src/_pytask/dag_utils.py index fb7bbfe9..47b248f7 100644 --- a/src/_pytask/dag_utils.py +++ b/src/_pytask/dag_utils.py @@ -7,8 +7,11 @@ from dataclasses import field from typing import TYPE_CHECKING -import networkx as nx - +from _pytask.dag_graph import DiGraph +from _pytask.dag_graph import NoCycleError +from _pytask.dag_graph import ancestors +from _pytask.dag_graph import descendants +from _pytask.dag_graph import find_cycle from _pytask.mark_utils import has_mark if TYPE_CHECKING: @@ -18,37 +21,37 @@ from _pytask.node_protocols import PTask -def descending_tasks(task_name: str, dag: nx.DiGraph) -> Generator[str, None, None]: +def descending_tasks(task_name: str, dag: DiGraph) -> Generator[str, None, None]: """Yield only descending tasks.""" - for descendant in nx.descendants(dag, task_name): + for descendant in descendants(dag, task_name): if "task" in dag.nodes[descendant]: yield descendant def task_and_descending_tasks( - task_name: str, dag: nx.DiGraph + task_name: str, dag: DiGraph ) -> Generator[str, None, None]: """Yield task and descending tasks.""" yield task_name yield from descending_tasks(task_name, dag) -def preceding_tasks(task_name: str, dag: nx.DiGraph) -> Generator[str, None, None]: +def preceding_tasks(task_name: str, dag: DiGraph) -> Generator[str, None, None]: """Yield only preceding tasks.""" - for ancestor in nx.ancestors(dag, task_name): + for ancestor in ancestors(dag, task_name): if "task" in dag.nodes[ancestor]: yield ancestor def task_and_preceding_tasks( - task_name: str, dag: nx.DiGraph + task_name: str, dag: DiGraph ) -> Generator[str, None, None]: """Yield task and preceding tasks.""" yield task_name yield from preceding_tasks(task_name, dag) -def node_and_neighbors(dag: nx.DiGraph, node: str) -> Iterable[str]: +def node_and_neighbors(dag: DiGraph, node: str) -> Iterable[str]: """Yield node and neighbors which are first degree predecessors and successors. We cannot use ``dag.neighbors`` as it only considers successors as neighbors in a @@ -77,13 +80,13 @@ class TopologicalSorter: """ - dag: nx.DiGraph + dag: DiGraph priorities: dict[str, int] = field(default_factory=dict) _nodes_processing: set[str] = field(default_factory=set) _nodes_done: set[str] = field(default_factory=set) @classmethod - def from_dag(cls, dag: nx.DiGraph) -> TopologicalSorter: + def from_dag(cls, dag: DiGraph) -> TopologicalSorter: """Instantiate from a DAG.""" cls.check_dag(dag) @@ -93,14 +96,18 @@ def from_dag(cls, dag: nx.DiGraph) -> TopologicalSorter: priorities = _extract_priorities_from_tasks(tasks) task_signatures = {task.signature for task in tasks} - task_dict = {s: nx.ancestors(dag, s) & task_signatures for s in task_signatures} - task_dag = nx.DiGraph(task_dict).reverse() + task_dag = DiGraph() + for signature in task_signatures: + task_dag.add_node(signature) + for signature in task_signatures: + for ancestor_ in ancestors(dag, signature) & task_signatures: + task_dag.add_edge(ancestor_, signature) return cls(dag=task_dag, priorities=priorities) @classmethod def from_dag_and_sorter( - cls, dag: nx.DiGraph, sorter: TopologicalSorter + cls, dag: DiGraph, sorter: TopologicalSorter ) -> TopologicalSorter: """Instantiate a sorter from another sorter and a DAG.""" new_sorter = cls.from_dag(dag) @@ -109,14 +116,14 @@ def from_dag_and_sorter( return new_sorter @staticmethod - def check_dag(dag: nx.DiGraph) -> None: + def check_dag(dag: DiGraph) -> None: if not dag.is_directed(): msg = "Only directed graphs have a topological order." raise ValueError(msg) try: - nx.algorithms.cycles.find_cycle(dag) - except nx.NetworkXNoCycle: + find_cycle(dag) + except NoCycleError: pass else: msg = "The DAG contains cycles." diff --git a/src/_pytask/execute.py b/src/_pytask/execute.py index 9b3fd8c3..8dd8a1e5 100644 --- a/src/_pytask/execute.py +++ b/src/_pytask/execute.py @@ -7,6 +7,7 @@ import time from typing import TYPE_CHECKING from typing import Any +from typing import cast from rich.text import Text @@ -172,9 +173,11 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C if not needs_to_be_executed: predecessors = set(dag.predecessors(task.signature)) | {task.signature} for node_signature in node_and_neighbors(dag, task.signature): - node = dag.nodes[node_signature].get("task") or dag.nodes[ - node_signature - ].get("node") + node = cast( + "PTask | PNode | PProvisionalNode", + dag.nodes[node_signature].get("task") + or dag.nodes[node_signature].get("node"), + ) # Skip provisional nodes that are products since they do not have a state. if node_signature not in predecessors and isinstance( @@ -182,7 +185,7 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C ): continue - node_state = node.state() + node_state = cast("Any", node).state() if node_signature in predecessors and not node_state: msg = f"{task.name!r} requires missing node {node.name!r}." @@ -196,7 +199,10 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C # Check if node changed and collect detailed info if in explain mode if session.config["explain"]: has_changed, reason, details = get_node_change_info( - session=session, task=task, node=node, state=node_state + session=session, + task=task, + node=cast("PTask | PNode", node), + state=node_state, ) if has_changed: needs_to_be_executed = True @@ -214,7 +220,7 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C reason_typed: ReasonType = reason # type: ignore[assignment] change_reasons.append( create_change_reason( - node=node, + node=cast("PTask | PNode", node), node_type=node_type, reason=reason_typed, old_hash=details.get("old_hash"), @@ -223,7 +229,10 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C ) else: has_changed = has_node_changed( - session=session, task=task, node=node, state=node_state + session=session, + task=task, + node=cast("PTask | PNode", node), + state=node_state, ) if has_changed: needs_to_be_executed = True diff --git a/src/_pytask/mark/__init__.py b/src/_pytask/mark/__init__.py index 9cee0c2b..2f6f0727 100644 --- a/src/_pytask/mark/__init__.py +++ b/src/_pytask/mark/__init__.py @@ -30,8 +30,7 @@ from collections.abc import Set as AbstractSet from typing import NoReturn - import networkx as nx - + from _pytask.dag_graph import DiGraph from _pytask.node_protocols import PTask @@ -153,7 +152,7 @@ def __call__(self, subname: str) -> bool: return any(subname in name for name in names) -def select_by_keyword(session: Session, dag: nx.DiGraph) -> set[str] | None: +def select_by_keyword(session: Session, dag: DiGraph) -> set[str] | None: """Deselect tests by keywords.""" keywordexpr = session.config["expression"] if not keywordexpr: @@ -208,7 +207,7 @@ def __call__(self, name: str) -> bool: return name in self.own_mark_names -def select_by_mark(session: Session, dag: nx.DiGraph) -> set[str] | None: +def select_by_mark(session: Session, dag: DiGraph) -> set[str] | None: """Deselect tests by marks.""" matchexpr = session.config["marker_expression"] if not matchexpr: @@ -237,7 +236,7 @@ def _deselect_others_with_mark( task.markers.append(mark) -def select_tasks_by_marks_and_expressions(session: Session, dag: nx.DiGraph) -> None: +def select_tasks_by_marks_and_expressions(session: Session, dag: DiGraph) -> None: """Modify the tasks which are executed with expressions and markers.""" remaining = select_by_keyword(session, dag) if remaining is not None: diff --git a/src/_pytask/session.py b/src/_pytask/session.py index 79f7f06c..09008742 100644 --- a/src/_pytask/session.py +++ b/src/_pytask/session.py @@ -7,9 +7,9 @@ from typing import TYPE_CHECKING from typing import Any -import networkx as nx from pluggy import HookRelay +from _pytask.dag_graph import DiGraph from _pytask.outcomes import ExitCode if TYPE_CHECKING: @@ -51,7 +51,7 @@ class Session: config: dict[str, Any] = field(default_factory=dict) collection_reports: list[CollectionReport] = field(default_factory=list) - dag: nx.DiGraph = field(default_factory=nx.DiGraph) + dag: DiGraph = field(default_factory=DiGraph) hook: HookRelay = field(default_factory=HookRelay) tasks: list[PTask] = field(default_factory=list) dag_report: DagReport | None = None diff --git a/src/_pytask/shared.py b/src/_pytask/shared.py index 770a6b8c..ccd11988 100644 --- a/src/_pytask/shared.py +++ b/src/_pytask/shared.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from enum import Enum - import networkx as nx + from _pytask.dag_graph import DiGraph __all__ = [ @@ -79,7 +79,7 @@ def parse_paths(x: Path | list[Path]) -> list[Path]: def reduce_names_of_multiple_nodes( - names: list[str], dag: nx.DiGraph, paths: Sequence[Path] + names: Iterable[str], dag: DiGraph, paths: Sequence[Path] ) -> list[str]: """Reduce the names of multiple nodes in the DAG.""" short_names = [] diff --git a/tests/test_dag_command.py b/tests/test_dag_command.py index c9573c83..9b40748b 100644 --- a/tests/test_dag_command.py +++ b/tests/test_dag_command.py @@ -120,7 +120,11 @@ def task_example(path=Path("input.txt")): ... monkeypatch.setattr( "_pytask.compat.import_module", - lambda x: _raise_exc(ImportError("pygraphviz not found")), # noqa: ARG005 + lambda x: ( + _raise_exc(ImportError("pygraphviz not found")) + if x == "pygraphviz" + else importlib.import_module(x) + ), ) result = runner.invoke( @@ -136,6 +140,33 @@ def task_example(path=Path("input.txt")): ... assert not tmp_path.joinpath("dag.png").exists() +def test_raise_error_with_graph_via_cli_missing_networkx(monkeypatch, tmp_path, runner): + source = """ + from pathlib import Path + + def task_example(path=Path("input.txt")): ... + """ + tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source)) + tmp_path.joinpath("input.txt").touch() + + monkeypatch.setattr( + "_pytask.compat.import_module", + lambda x: ( + _raise_exc(ImportError("networkx not found")) if x == "networkx" else None + ), + ) + + result = runner.invoke( + cli, + ["dag", tmp_path.as_posix(), "-o", tmp_path.joinpath("dag.png"), "-l", "dot"], + ) + + assert result.exit_code == ExitCode.FAILED + assert "pytask requires the optional dependency 'networkx'." in result.output + assert "Traceback" not in result.output + assert not tmp_path.joinpath("dag.png").exists() + + def test_raise_error_with_graph_via_task_missing_optional_dependency( monkeypatch, tmp_path, runner ): @@ -154,7 +185,11 @@ def task_create_graph(): monkeypatch.setattr( "_pytask.compat.import_module", - lambda x: _raise_exc(ImportError("pygraphviz not found")), # noqa: ARG005 + lambda x: ( + _raise_exc(ImportError("pygraphviz not found")) + if x == "pygraphviz" + else importlib.import_module(x) + ), ) result = runner.invoke(cli, [tmp_path.as_posix()]) @@ -167,6 +202,39 @@ def task_create_graph(): assert not tmp_path.joinpath("dag.png").exists() +def test_raise_error_with_graph_via_task_missing_networkx( + monkeypatch, tmp_path, runner +): + source = """ + import pytask + from pathlib import Path + import networkx as nx + + def task_create_graph(): + dag = pytask.build_dag({"paths": Path(__file__).parent}) + graph = nx.nx_agraph.to_agraph(dag) + path = Path(__file__).parent.joinpath("dag.png") + graph.draw(path, prog="dot") + """ + tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source)) + + monkeypatch.setattr( + "_pytask.compat.import_module", + lambda x: ( + _raise_exc(ImportError("networkx not found")) + if x == "networkx" + else importlib.import_module(x) + ), + ) + + result = runner.invoke(cli, [tmp_path.as_posix()]) + + assert result.exit_code == ExitCode.FAILED + assert "pytask requires the optional dependency 'networkx'." in result.output + assert "Traceback" in result.output + assert not tmp_path.joinpath("dag.png").exists() + + def test_raise_error_with_graph_via_cli_missing_optional_program( monkeypatch, tmp_path, runner ): diff --git a/tests/test_dag_utils.py b/tests/test_dag_utils.py index d11f2af2..db6be94f 100644 --- a/tests/test_dag_utils.py +++ b/tests/test_dag_utils.py @@ -3,9 +3,9 @@ from contextlib import ExitStack as does_not_raise # noqa: N813 from pathlib import Path -import networkx as nx import pytest +from _pytask.dag_graph import DiGraph from _pytask.dag_utils import TopologicalSorter from _pytask.dag_utils import _extract_priorities_from_tasks from _pytask.dag_utils import descending_tasks @@ -19,7 +19,7 @@ @pytest.fixture def dag(): """Create a dag with five nodes in a line.""" - dag = nx.DiGraph() + dag = DiGraph() for i in range(4): task = Task(base_name=str(i), path=Path(), function=noop) next_task = Task(base_name=str(i + 1), path=Path(), function=noop) diff --git a/uv.lock b/uv.lock index 1df51b41..f9dae641 100644 --- a/uv.lock +++ b/uv.lock @@ -1279,6 +1279,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/92/db/b4c12cff13ebac2786f4f217f06588bccd8b53d260453404ef22b121fc3a/greenlet-3.2.3-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:1afd685acd5597349ee6d7a88a8bec83ce13c106ac78c196ee9dde7c04fe87be", size = 268977, upload-time = "2025-06-05T16:10:24.001Z" }, { url = "https://files.pythonhosted.org/packages/52/61/75b4abd8147f13f70986df2801bf93735c1bd87ea780d70e3b3ecda8c165/greenlet-3.2.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:761917cac215c61e9dc7324b2606107b3b292a8349bdebb31503ab4de3f559ac", size = 627351, upload-time = "2025-06-05T16:38:50.685Z" }, { url = "https://files.pythonhosted.org/packages/35/aa/6894ae299d059d26254779a5088632874b80ee8cf89a88bca00b0709d22f/greenlet-3.2.3-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:a433dbc54e4a37e4fff90ef34f25a8c00aed99b06856f0119dcf09fbafa16392", size = 638599, upload-time = "2025-06-05T16:41:34.057Z" }, + { url = "https://files.pythonhosted.org/packages/30/64/e01a8261d13c47f3c082519a5e9dbf9e143cc0498ed20c911d04e54d526c/greenlet-3.2.3-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:72e77ed69312bab0434d7292316d5afd6896192ac4327d44f3d613ecb85b037c", size = 634482, upload-time = "2025-06-05T16:48:16.26Z" }, { url = "https://files.pythonhosted.org/packages/47/48/ff9ca8ba9772d083a4f5221f7b4f0ebe8978131a9ae0909cf202f94cd879/greenlet-3.2.3-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:68671180e3849b963649254a882cd544a3c75bfcd2c527346ad8bb53494444db", size = 633284, upload-time = "2025-06-05T16:13:01.599Z" }, { url = "https://files.pythonhosted.org/packages/e9/45/626e974948713bc15775b696adb3eb0bd708bec267d6d2d5c47bb47a6119/greenlet-3.2.3-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:49c8cfb18fb419b3d08e011228ef8a25882397f3a859b9fe1436946140b6756b", size = 582206, upload-time = "2025-06-05T16:12:48.51Z" }, { url = "https://files.pythonhosted.org/packages/b1/8e/8b6f42c67d5df7db35b8c55c9a850ea045219741bb14416255616808c690/greenlet-3.2.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:efc6dc8a792243c31f2f5674b670b3a95d46fa1c6a912b8e310d6f542e7b0712", size = 1111412, upload-time = "2025-06-05T16:36:45.479Z" }, @@ -1287,6 +1288,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fc/2e/d4fcb2978f826358b673f779f78fa8a32ee37df11920dc2bb5589cbeecef/greenlet-3.2.3-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:784ae58bba89fa1fa5733d170d42486580cab9decda3484779f4759345b29822", size = 270219, upload-time = "2025-06-05T16:10:10.414Z" }, { url = "https://files.pythonhosted.org/packages/16/24/929f853e0202130e4fe163bc1d05a671ce8dcd604f790e14896adac43a52/greenlet-3.2.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0921ac4ea42a5315d3446120ad48f90c3a6b9bb93dd9b3cf4e4d84a66e42de83", size = 630383, upload-time = "2025-06-05T16:38:51.785Z" }, { url = "https://files.pythonhosted.org/packages/d1/b2/0320715eb61ae70c25ceca2f1d5ae620477d246692d9cc284c13242ec31c/greenlet-3.2.3-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:d2971d93bb99e05f8c2c0c2f4aa9484a18d98c4c3bd3c62b65b7e6ae33dfcfaf", size = 642422, upload-time = "2025-06-05T16:41:35.259Z" }, + { url = "https://files.pythonhosted.org/packages/bd/49/445fd1a210f4747fedf77615d941444349c6a3a4a1135bba9701337cd966/greenlet-3.2.3-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:c667c0bf9d406b77a15c924ef3285e1e05250948001220368e039b6aa5b5034b", size = 638375, upload-time = "2025-06-05T16:48:18.235Z" }, { url = "https://files.pythonhosted.org/packages/7e/c8/ca19760cf6eae75fa8dc32b487e963d863b3ee04a7637da77b616703bc37/greenlet-3.2.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:592c12fb1165be74592f5de0d70f82bc5ba552ac44800d632214b76089945147", size = 637627, upload-time = "2025-06-05T16:13:02.858Z" }, { url = "https://files.pythonhosted.org/packages/65/89/77acf9e3da38e9bcfca881e43b02ed467c1dedc387021fc4d9bd9928afb8/greenlet-3.2.3-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:29e184536ba333003540790ba29829ac14bb645514fbd7e32af331e8202a62a5", size = 585502, upload-time = "2025-06-05T16:12:49.642Z" }, { url = "https://files.pythonhosted.org/packages/97/c6/ae244d7c95b23b7130136e07a9cc5aadd60d59b5951180dc7dc7e8edaba7/greenlet-3.2.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:93c0bb79844a367782ec4f429d07589417052e621aa39a5ac1fb99c5aa308edc", size = 1114498, upload-time = "2025-06-05T16:36:46.598Z" }, @@ -1295,6 +1297,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f3/94/ad0d435f7c48debe960c53b8f60fb41c2026b1d0fa4a99a1cb17c3461e09/greenlet-3.2.3-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:25ad29caed5783d4bd7a85c9251c651696164622494c00802a139c00d639242d", size = 271992, upload-time = "2025-06-05T16:11:23.467Z" }, { url = "https://files.pythonhosted.org/packages/93/5d/7c27cf4d003d6e77749d299c7c8f5fd50b4f251647b5c2e97e1f20da0ab5/greenlet-3.2.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:88cd97bf37fe24a6710ec6a3a7799f3f81d9cd33317dcf565ff9950c83f55e0b", size = 638820, upload-time = "2025-06-05T16:38:52.882Z" }, { url = "https://files.pythonhosted.org/packages/c6/7e/807e1e9be07a125bb4c169144937910bf59b9d2f6d931578e57f0bce0ae2/greenlet-3.2.3-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:baeedccca94880d2f5666b4fa16fc20ef50ba1ee353ee2d7092b383a243b0b0d", size = 653046, upload-time = "2025-06-05T16:41:36.343Z" }, + { url = "https://files.pythonhosted.org/packages/9d/ab/158c1a4ea1068bdbc78dba5a3de57e4c7aeb4e7fa034320ea94c688bfb61/greenlet-3.2.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:be52af4b6292baecfa0f397f3edb3c6092ce071b499dd6fe292c9ac9f2c8f264", size = 647701, upload-time = "2025-06-05T16:48:19.604Z" }, { url = "https://files.pythonhosted.org/packages/cc/0d/93729068259b550d6a0288da4ff72b86ed05626eaf1eb7c0d3466a2571de/greenlet-3.2.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0cc73378150b8b78b0c9fe2ce56e166695e67478550769536a6742dca3651688", size = 649747, upload-time = "2025-06-05T16:13:04.628Z" }, { url = "https://files.pythonhosted.org/packages/f6/f6/c82ac1851c60851302d8581680573245c8fc300253fc1ff741ae74a6c24d/greenlet-3.2.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:706d016a03e78df129f68c4c9b4c4f963f7d73534e48a24f5f5a7101ed13dbbb", size = 605461, upload-time = "2025-06-05T16:12:50.792Z" }, { url = "https://files.pythonhosted.org/packages/98/82/d022cf25ca39cf1200650fc58c52af32c90f80479c25d1cbf57980ec3065/greenlet-3.2.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:419e60f80709510c343c57b4bb5a339d8767bf9aef9b8ce43f4f143240f88b7c", size = 1121190, upload-time = "2025-06-05T16:36:48.59Z" }, @@ -1303,6 +1306,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b1/cf/f5c0b23309070ae93de75c90d29300751a5aacefc0a3ed1b1d8edb28f08b/greenlet-3.2.3-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:500b8689aa9dd1ab26872a34084503aeddefcb438e2e7317b89b11eaea1901ad", size = 270732, upload-time = "2025-06-05T16:10:08.26Z" }, { url = "https://files.pythonhosted.org/packages/48/ae/91a957ba60482d3fecf9be49bc3948f341d706b52ddb9d83a70d42abd498/greenlet-3.2.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:a07d3472c2a93117af3b0136f246b2833fdc0b542d4a9799ae5f41c28323faef", size = 639033, upload-time = "2025-06-05T16:38:53.983Z" }, { url = "https://files.pythonhosted.org/packages/6f/df/20ffa66dd5a7a7beffa6451bdb7400d66251374ab40b99981478c69a67a8/greenlet-3.2.3-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:8704b3768d2f51150626962f4b9a9e4a17d2e37c8a8d9867bbd9fa4eb938d3b3", size = 652999, upload-time = "2025-06-05T16:41:37.89Z" }, + { url = "https://files.pythonhosted.org/packages/51/b4/ebb2c8cb41e521f1d72bf0465f2f9a2fd803f674a88db228887e6847077e/greenlet-3.2.3-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:5035d77a27b7c62db6cf41cf786cfe2242644a7a337a0e155c80960598baab95", size = 647368, upload-time = "2025-06-05T16:48:21.467Z" }, { url = "https://files.pythonhosted.org/packages/8e/6a/1e1b5aa10dced4ae876a322155705257748108b7fd2e4fae3f2a091fe81a/greenlet-3.2.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2d8aa5423cd4a396792f6d4580f88bdc6efcb9205891c9d40d20f6e670992efb", size = 650037, upload-time = "2025-06-05T16:13:06.402Z" }, { url = "https://files.pythonhosted.org/packages/26/f2/ad51331a157c7015c675702e2d5230c243695c788f8f75feba1af32b3617/greenlet-3.2.3-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2c724620a101f8170065d7dded3f962a2aea7a7dae133a009cada42847e04a7b", size = 608402, upload-time = "2025-06-05T16:12:51.91Z" }, { url = "https://files.pythonhosted.org/packages/26/bc/862bd2083e6b3aff23300900a956f4ea9a4059de337f5c8734346b9b34fc/greenlet-3.2.3-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:873abe55f134c48e1f2a6f53f7d1419192a3d1a4e873bace00499a4e45ea6af0", size = 1119577, upload-time = "2025-06-05T16:36:49.787Z" }, @@ -1311,6 +1315,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d8/ca/accd7aa5280eb92b70ed9e8f7fd79dc50a2c21d8c73b9a0856f5b564e222/greenlet-3.2.3-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:3d04332dddb10b4a211b68111dabaee2e1a073663d117dc10247b5b1642bac86", size = 271479, upload-time = "2025-06-05T16:10:47.525Z" }, { url = "https://files.pythonhosted.org/packages/55/71/01ed9895d9eb49223280ecc98a557585edfa56b3d0e965b9fa9f7f06b6d9/greenlet-3.2.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8186162dffde068a465deab08fc72c767196895c39db26ab1c17c0b77a6d8b97", size = 683952, upload-time = "2025-06-05T16:38:55.125Z" }, { url = "https://files.pythonhosted.org/packages/ea/61/638c4bdf460c3c678a0a1ef4c200f347dff80719597e53b5edb2fb27ab54/greenlet-3.2.3-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:f4bfbaa6096b1b7a200024784217defedf46a07c2eee1a498e94a1b5f8ec5728", size = 696917, upload-time = "2025-06-05T16:41:38.959Z" }, + { url = "https://files.pythonhosted.org/packages/22/cc/0bd1a7eb759d1f3e3cc2d1bc0f0b487ad3cc9f34d74da4b80f226fde4ec3/greenlet-3.2.3-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:ed6cfa9200484d234d8394c70f5492f144b20d4533f69262d530a1a082f6ee9a", size = 692443, upload-time = "2025-06-05T16:48:23.113Z" }, { url = "https://files.pythonhosted.org/packages/67/10/b2a4b63d3f08362662e89c103f7fe28894a51ae0bc890fabf37d1d780e52/greenlet-3.2.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:02b0df6f63cd15012bed5401b47829cfd2e97052dc89da3cfaf2c779124eb892", size = 692995, upload-time = "2025-06-05T16:13:07.972Z" }, { url = "https://files.pythonhosted.org/packages/5a/c6/ad82f148a4e3ce9564056453a71529732baf5448ad53fc323e37efe34f66/greenlet-3.2.3-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:86c2d68e87107c1792e2e8d5399acec2487a4e993ab76c792408e59394d52141", size = 655320, upload-time = "2025-06-05T16:12:53.453Z" }, { url = "https://files.pythonhosted.org/packages/5c/4f/aab73ecaa6b3086a4c89863d94cf26fa84cbff63f52ce9bc4342b3087a06/greenlet-3.2.3-cp314-cp314-win_amd64.whl", hash = "sha256:8c47aae8fbbfcf82cc13327ae802ba13c9c36753b67e760023fd116bc124a62a", size = 301236, upload-time = "2025-06-05T16:15:20.111Z" }, @@ -1324,6 +1329,7 @@ dependencies = [ { name = "griffecli" }, { name = "griffelib" }, ] +sdist = { url = "https://files.pythonhosted.org/packages/04/56/28a0accac339c164b52a92c6cfc45a903acc0c174caa5c1713803467b533/griffe-2.0.0.tar.gz", hash = "sha256:c68979cd8395422083a51ea7cf02f9c119d889646d99b7b656ee43725de1b80f", size = 293906, upload-time = "2026-03-23T21:06:53.402Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/8b/94/ee21d41e7eb4f823b94603b9d40f86d3c7fde80eacc2c3c71845476dddaa/griffe-2.0.0-py3-none-any.whl", hash = "sha256:5418081135a391c3e6e757a7f3f156f1a1a746cc7b4023868ff7d5e2f9a980aa", size = 5214, upload-time = "2026-02-09T19:09:44.105Z" }, ] @@ -1336,6 +1342,7 @@ dependencies = [ { name = "colorama" }, { name = "griffelib" }, ] +sdist = { url = "https://files.pythonhosted.org/packages/a4/f8/2e129fd4a86e52e58eefe664de05e7d502decf766e7316cc9e70fdec3e18/griffecli-2.0.0.tar.gz", hash = "sha256:312fa5ebb4ce6afc786356e2d0ce85b06c1c20d45abc42d74f0cda65e159f6ef", size = 56213, upload-time = "2026-03-23T21:06:54.8Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/e6/ed/d93f7a447bbf7a935d8868e9617cbe1cadf9ee9ee6bd275d3040fbf93d60/griffecli-2.0.0-py3-none-any.whl", hash = "sha256:9f7cd9ee9b21d55e91689358978d2385ae65c22f307a63fb3269acf3f21e643d", size = 9345, upload-time = "2026-02-09T19:09:42.554Z" }, ] @@ -1344,6 +1351,7 @@ wheels = [ name = "griffelib" version = "2.0.0" source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ad/06/eccbd311c9e2b3ca45dbc063b93134c57a1ccc7607c5e545264ad092c4a9/griffelib-2.0.0.tar.gz", hash = "sha256:e504d637a089f5cab9b5daf18f7645970509bf4f53eda8d79ed71cce8bd97934", size = 166312, upload-time = "2026-03-23T21:06:55.954Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/4d/51/c936033e16d12b627ea334aaaaf42229c37620d0f15593456ab69ab48161/griffelib-2.0.0-py3-none-any.whl", hash = "sha256:01284878c966508b6d6f1dbff9b6fa607bc062d8261c5c7253cb285b06422a7f", size = 142004, upload-time = "2026-02-09T19:09:40.561Z" }, ] @@ -2876,8 +2884,6 @@ dependencies = [ { name = "click" }, { name = "click-default-group" }, { name = "msgspec", extra = ["toml"] }, - { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "optree" }, { name = "packaging" }, { name = "pluggy" }, @@ -2888,6 +2894,12 @@ dependencies = [ { name = "universal-pathlib" }, ] +[package.optional-dependencies] +dag = [ + { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, +] + [package.dev-dependencies] docs = [ { name = "furo" }, @@ -2895,6 +2907,8 @@ docs = [ { name = "ipywidgets" }, { name = "matplotlib" }, { name = "mkdocstrings", extra = ["python"] }, + { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "zensical" }, ] docs-live = [ @@ -2912,6 +2926,8 @@ test = [ { name = "coiled", marker = "python_full_version < '3.14'" }, { name = "deepdiff" }, { name = "nbmake", marker = "python_full_version < '3.14' or sys_platform != 'win32'" }, + { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "pexpect" }, { name = "pygments" }, { name = "pygraphviz", marker = "sys_platform == 'linux'" }, @@ -2930,7 +2946,7 @@ requires-dist = [ { name = "click-default-group", specifier = ">=1.2.4" }, { name = "msgspec", specifier = ">=0.18.6" }, { name = "msgspec", extras = ["toml"], specifier = ">=0.18.6" }, - { name = "networkx", specifier = ">=2.4.0" }, + { name = "networkx", marker = "extra == 'dag'", specifier = ">=2.4.0" }, { name = "optree", specifier = ">=0.9.0" }, { name = "packaging", specifier = ">=23.0.0" }, { name = "pluggy", specifier = ">=1.3.0" }, @@ -2940,6 +2956,7 @@ requires-dist = [ { name = "typing-extensions", marker = "python_full_version < '3.11'", specifier = ">=4.8.0" }, { name = "universal-pathlib", specifier = ">=0.2.2" }, ] +provides-extras = ["dag"] [package.metadata.requires-dev] docs = [ @@ -2948,6 +2965,7 @@ docs = [ { name = "ipywidgets", specifier = ">=8.1.6" }, { name = "matplotlib", specifier = ">=3.5.0" }, { name = "mkdocstrings", extras = ["python"], specifier = ">=0.30.0" }, + { name = "networkx", specifier = ">=2.4.0" }, { name = "zensical", specifier = ">=0.0.23" }, ] docs-live = [{ name = "sphinx-autobuild", specifier = ">=2024.10.3" }] @@ -2962,6 +2980,7 @@ test = [ { name = "coiled", marker = "python_full_version < '3.14'", specifier = ">=1.42.0" }, { name = "deepdiff", specifier = ">=7.0.0" }, { name = "nbmake", marker = "python_full_version < '3.14' or sys_platform != 'win32'", specifier = ">=1.5.5" }, + { name = "networkx", specifier = ">=2.4.0" }, { name = "pexpect", specifier = ">=4.9.0" }, { name = "pygments", specifier = ">=2.18.0" }, { name = "pygraphviz", marker = "sys_platform == 'linux'", specifier = ">=1.12" }, From 4b73493bf72e43a85972b333575aa7926f0cb1d8 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Thu, 26 Mar 2026 10:56:37 +0100 Subject: [PATCH 02/13] Clarify DAG scheduling and docs --- docs/source/tutorials/visualizing_the_dag.md | 17 ++++++++++--- src/_pytask/dag_utils.py | 3 +++ src/_pytask/execute.py | 26 +++++++++++++------- 3 files changed, 33 insertions(+), 13 deletions(-) diff --git a/docs/source/tutorials/visualizing_the_dag.md b/docs/source/tutorials/visualizing_the_dag.md index 297db87e..1089178d 100644 --- a/docs/source/tutorials/visualizing_the_dag.md +++ b/docs/source/tutorials/visualizing_the_dag.md @@ -3,11 +3,20 @@ To visualize the [DAG](../glossary.md#dag) of the project, first, install [networkx](https://networkx.org/), [pygraphviz](https://github.com/pygraphviz/pygraphviz), and -[graphviz](https://graphviz.org/). For example, you can install them with pixi +[graphviz](https://graphviz.org/). -```console -$ pixi add networkx pygraphviz graphviz -``` +=== "uv" + + ```console + $ uv add networkx + $ uv add --optional dag pygraphviz + ``` + +=== "pixi" + + ```console + $ pixi add networkx pygraphviz graphviz + ``` After that, pytask offers two interfaces to visualize your project's `DAG`. diff --git a/src/_pytask/dag_utils.py b/src/_pytask/dag_utils.py index 47b248f7..130db30e 100644 --- a/src/_pytask/dag_utils.py +++ b/src/_pytask/dag_utils.py @@ -100,6 +100,9 @@ def from_dag(cls, dag: DiGraph) -> TopologicalSorter: for signature in task_signatures: task_dag.add_node(signature) for signature in task_signatures: + # The scheduler graph uses edges from predecessor -> successor so that + # zero in-degree means "ready to run". This is the same orientation the + # previous networkx-based implementation reached after calling reverse(). for ancestor_ in ancestors(dag, signature) & task_signatures: task_dag.add_edge(ancestor_, signature) diff --git a/src/_pytask/execute.py b/src/_pytask/execute.py index 8dd8a1e5..8f0e94a2 100644 --- a/src/_pytask/execute.py +++ b/src/_pytask/execute.py @@ -173,11 +173,9 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C if not needs_to_be_executed: predecessors = set(dag.predecessors(task.signature)) | {task.signature} for node_signature in node_and_neighbors(dag, task.signature): - node = cast( - "PTask | PNode | PProvisionalNode", - dag.nodes[node_signature].get("task") - or dag.nodes[node_signature].get("node"), - ) + node = dag.nodes[node_signature].get("task") or dag.nodes[ + node_signature + ].get("node") # Skip provisional nodes that are products since they do not have a state. if node_signature not in predecessors and isinstance( @@ -185,7 +183,17 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C ): continue - node_state = cast("Any", node).state() + # Provisional dependencies should have been resolved before task setup. + if isinstance(node, PProvisionalNode): + msg = ( + f"Task {task.name!r} still references provisional node " + f"{node.name!r} during execution setup." + ) + raise ExecutionError(msg) + + node = cast("PTask | PNode", node) + + node_state = node.state() if node_signature in predecessors and not node_state: msg = f"{task.name!r} requires missing node {node.name!r}." @@ -201,7 +209,7 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C has_changed, reason, details = get_node_change_info( session=session, task=task, - node=cast("PTask | PNode", node), + node=node, state=node_state, ) if has_changed: @@ -220,7 +228,7 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C reason_typed: ReasonType = reason # type: ignore[assignment] change_reasons.append( create_change_reason( - node=cast("PTask | PNode", node), + node=node, node_type=node_type, reason=reason_typed, old_hash=details.get("old_hash"), @@ -231,7 +239,7 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C has_changed = has_node_changed( session=session, task=task, - node=cast("PTask | PNode", node), + node=node, state=node_state, ) if has_changed: From 5099f820b7c0d4a16930ea833c0a93afed41685e Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Thu, 26 Mar 2026 12:01:31 +0100 Subject: [PATCH 03/13] Simplify internal DAG types --- src/_pytask/dag_graph.py | 68 +++++++--------------------------------- tests/test_dag_utils.py | 12 +++++-- 2 files changed, 20 insertions(+), 60 deletions(-) diff --git a/src/_pytask/dag_graph.py b/src/_pytask/dag_graph.py index 8fa3d8c6..95f7cb6e 100644 --- a/src/_pytask/dag_graph.py +++ b/src/_pytask/dag_graph.py @@ -3,6 +3,8 @@ from __future__ import annotations import itertools +from dataclasses import dataclass +from dataclasses import field from typing import TYPE_CHECKING from typing import Any from typing import cast @@ -17,56 +19,18 @@ class NoCycleError(Exception): """Raised when no cycle is found in a graph.""" -class NodeView: - """A minimal mapping-like view over node attributes.""" - - def __init__(self, node_attributes: dict[str, dict[str, Any]]) -> None: - self._node_attributes = node_attributes - - def __getitem__(self, node: str) -> dict[str, Any]: - return self._node_attributes[node] - - def __iter__(self) -> Iterator[str]: - return iter(self._node_attributes) - - def __len__(self) -> int: - return len(self._node_attributes) - - def __contains__(self, node: object) -> bool: - return node in self._node_attributes - - -class UndirectedGraph: - """A minimal undirected graph used for validation tests.""" - - def __init__( - self, - node_attributes: dict[str, dict[str, Any]], - adjacency: dict[str, dict[str, None]], - graph_attributes: dict[str, Any], - ) -> None: - self._node_attributes = { - node: attributes.copy() for node, attributes in node_attributes.items() - } - self._adjacency = { - node: neighbors.copy() for node, neighbors in adjacency.items() - } - self.graph = graph_attributes.copy() - self.nodes = NodeView(self._node_attributes) - - def is_directed(self) -> bool: - return False - - +@dataclass class DiGraph: """A minimal directed graph tailored to pytask's needs.""" - def __init__(self) -> None: - self._node_attributes: dict[str, dict[str, Any]] = {} - self._successors: dict[str, dict[str, None]] = {} - self._predecessors: dict[str, dict[str, None]] = {} - self.graph: dict[str, Any] = {} - self.nodes = NodeView(self._node_attributes) + _node_attributes: dict[str, dict[str, Any]] = field(default_factory=dict) + _successors: dict[str, dict[str, None]] = field(default_factory=dict) + _predecessors: dict[str, dict[str, None]] = field(default_factory=dict) + graph: dict[str, Any] = field(default_factory=dict) + + @property + def nodes(self) -> dict[str, dict[str, Any]]: + return self._node_attributes def add_node(self, node_name: str, **attributes: Any) -> None: if node_name not in self._node_attributes: @@ -138,16 +102,6 @@ def set_node_attributes(self, values: dict[str, Any], name: str) -> None: if node in self._node_attributes: self._node_attributes[node][name] = value - def to_undirected(self) -> UndirectedGraph: - adjacency = { - node: { - **self._predecessors[node], - **self._successors[node], - } - for node in self._node_attributes - } - return UndirectedGraph(self._node_attributes, adjacency, self.graph) - def to_networkx(self) -> Any: nx = cast("Any", import_optional_dependency("networkx")) graph = nx.DiGraph() diff --git a/tests/test_dag_utils.py b/tests/test_dag_utils.py index db6be94f..5f02aeda 100644 --- a/tests/test_dag_utils.py +++ b/tests/test_dag_utils.py @@ -1,6 +1,7 @@ from __future__ import annotations from contextlib import ExitStack as does_not_raise # noqa: N813 +from dataclasses import dataclass from pathlib import Path import pytest @@ -143,10 +144,15 @@ def test_extract_priorities_from_tasks(tasks, expectation, expected): assert result == expected -def test_raise_error_for_undirected_graphs(dag): - undirected_graph = dag.to_undirected() +@dataclass +class _UndirectedGraphStub: + def is_directed(self): + return False + + +def test_raise_error_for_undirected_graphs(): with pytest.raises(ValueError, match="Only directed graphs have a"): - TopologicalSorter.from_dag(undirected_graph) + TopologicalSorter.from_dag(_UndirectedGraphStub()) # type: ignore[arg-type] def test_raise_error_for_cycle_in_graph(dag): From 13a3466b0c26ca7e7e0f61e8c2d70e1e3a7be17a Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Thu, 26 Mar 2026 15:43:35 +0100 Subject: [PATCH 04/13] Add changelog note for internal DAG refactor --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4de47bb9..32f0f762 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and ## Unreleased +- [#830](https://github.com/pytask-dev/pytask/pull/830) replaces the internal + `networkx` dependency with a pytask-owned DAG implementation, lazy-loads + `networkx` only for DAG export and visualization, and makes the `networkx` + dependency optional for core builds. - [#822](https://github.com/pytask-dev/pytask/pull/822) fixes unstable signatures for remote `UPath`-backed `PathNode`s and `PickleNode`s so unchanged remote inputs are no longer reported as missing from the state database on subsequent runs. From 6028419f9bb00c211fd4ad6b718fb172dec49c72 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Fri, 27 Mar 2026 09:10:43 +0100 Subject: [PATCH 05/13] Move DAG styling to networkx export --- src/_pytask/dag_command.py | 23 +++++++++++------------ src/_pytask/dag_graph.py | 13 +++---------- 2 files changed, 14 insertions(+), 22 deletions(-) diff --git a/src/_pytask/dag_command.py b/src/_pytask/dag_command.py index e0497536..d4369d2c 100644 --- a/src/_pytask/dag_command.py +++ b/src/_pytask/dag_command.py @@ -106,7 +106,7 @@ def dag(**raw_config: Any) -> int: ) session.hook.pytask_collect(session=session) session.dag = create_dag(session=session) - dag = _refine_dag(session).to_networkx() + dag = _to_visualization_graph(session) _write_graph(dag, session.config["output_path"], session.config["layout"]) except CollectionError: # pragma: no cover @@ -179,16 +179,22 @@ def build_dag(raw_config: dict[str, Any]) -> nx.DiGraph: session.hook.pytask_collect(session=session) session.dag = create_dag(session=session) session.hook.pytask_unconfigure(session=session) - return _refine_dag(session).to_networkx() + return _to_visualization_graph(session) def _refine_dag(session: Session) -> DiGraph: """Refine the dag for plotting.""" dag = _shorten_node_labels(session.dag, session.config["paths"]) - dag = _clean_dag(dag) - dag = _style_dag(dag) - dag.graph["graph"] = {"rankdir": session.config["rank_direction"].name} + return _clean_dag(dag) + +def _to_visualization_graph(session: Session) -> nx.DiGraph: + """Convert the internal DAG to a styled networkx graph for visualization.""" + nx = cast("Any", import_optional_dependency("networkx")) + dag = _refine_dag(session).to_networkx() + dag.graph["graph"] = {"rankdir": session.config["rank_direction"].name} + shapes = {name: "hexagon" if "::task_" in name else "box" for name in dag.nodes} + nx.set_node_attributes(dag, shapes, "shape") return dag @@ -208,13 +214,6 @@ def _clean_dag(dag: DiGraph) -> DiGraph: return dag -def _style_dag(dag: DiGraph) -> DiGraph: - """Style the DAG.""" - shapes = {name: "hexagon" if "::task_" in name else "box" for name in dag.nodes} - dag.set_node_attributes(shapes, "shape") - return dag - - def _write_graph(dag: nx.DiGraph, path: Path, layout: str) -> None: """Write the graph to disk.""" nx = cast("Any", import_optional_dependency("networkx")) diff --git a/src/_pytask/dag_graph.py b/src/_pytask/dag_graph.py index 95f7cb6e..56bb7e0c 100644 --- a/src/_pytask/dag_graph.py +++ b/src/_pytask/dag_graph.py @@ -12,6 +12,8 @@ from _pytask.compat import import_optional_dependency if TYPE_CHECKING: + from collections.abc import Callable + from collections.abc import Iterable from collections.abc import Iterator @@ -26,7 +28,6 @@ class DiGraph: _node_attributes: dict[str, dict[str, Any]] = field(default_factory=dict) _successors: dict[str, dict[str, None]] = field(default_factory=dict) _predecessors: dict[str, dict[str, None]] = field(default_factory=dict) - graph: dict[str, Any] = field(default_factory=dict) @property def nodes(self) -> dict[str, dict[str, Any]]: @@ -72,7 +73,6 @@ def is_directed(self) -> bool: def reverse(self) -> DiGraph: graph = DiGraph() - graph.graph = self.graph.copy() for node, attributes in self._node_attributes.items(): graph.add_node(node, **attributes.copy()) for source, successors in self._successors.items(): @@ -82,7 +82,6 @@ def reverse(self) -> DiGraph: def relabel_nodes(self, mapping: dict[str, str]) -> DiGraph: graph = DiGraph() - graph.graph = self.graph.copy() new_labels = [mapping.get(node, node) for node in self._node_attributes] if len(new_labels) != len(set(new_labels)): @@ -97,15 +96,9 @@ def relabel_nodes(self, mapping: dict[str, str]) -> DiGraph: graph.add_edge(new_source, mapping.get(target, target)) return graph - def set_node_attributes(self, values: dict[str, Any], name: str) -> None: - for node, value in values.items(): - if node in self._node_attributes: - self._node_attributes[node][name] = value - def to_networkx(self) -> Any: nx = cast("Any", import_optional_dependency("networkx")) graph = nx.DiGraph() - graph.graph = self.graph.copy() for node, attributes in self._node_attributes.items(): graph.add_node(node, **attributes.copy()) for source, successors in self._successors.items(): @@ -127,7 +120,7 @@ def ancestors(dag: DiGraph, node: str) -> set[str]: def _traverse( _dag: DiGraph, node: str, - adjacency: Any, + adjacency: Callable[[str], Iterable[str]], ) -> set[str]: visited: set[str] = set() stack = list(adjacency(node)) From 7c7fef8f0306295a7c1b8fd623446316e006671b Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Fri, 27 Mar 2026 22:42:17 +0100 Subject: [PATCH 06/13] Make internal DAG generic and typed --- src/_pytask/dag.py | 62 ++++++------- src/_pytask/dag_command.py | 11 +-- src/_pytask/dag_graph.py | 162 +++++++++++++++++++++++----------- src/_pytask/dag_utils.py | 36 ++++---- src/_pytask/database_utils.py | 7 +- src/_pytask/execute.py | 15 ++-- src/_pytask/lockfile.py | 8 +- src/_pytask/mark/__init__.py | 9 +- src/_pytask/persist.py | 25 +++--- src/_pytask/profile.py | 2 +- src/_pytask/session.py | 3 +- src/_pytask/shared.py | 5 +- src/_pytask/skipping.py | 4 +- tests/test_dag_utils.py | 35 ++++---- 14 files changed, 229 insertions(+), 155 deletions(-) diff --git a/src/_pytask/dag.py b/src/_pytask/dag.py index 7d6f19b9..d9e281fb 100644 --- a/src/_pytask/dag.py +++ b/src/_pytask/dag.py @@ -16,6 +16,7 @@ from _pytask.console import format_node_name from _pytask.console import format_task_name from _pytask.console import render_to_string +from _pytask.dag_graph import DagNode from _pytask.dag_graph import DiGraph from _pytask.dag_graph import NoCycleError from _pytask.dag_graph import find_cycle @@ -39,7 +40,7 @@ __all__ = ["create_dag", "create_dag_from_session"] -def create_dag(session: Session) -> DiGraph: +def create_dag(session: Session) -> DiGraph[str, DagNode]: """Create a directed acyclic graph (DAG) for the workflow.""" try: dag = create_dag_from_session(session) @@ -52,7 +53,7 @@ def create_dag(session: Session) -> DiGraph: return dag -def create_dag_from_session(session: Session) -> DiGraph: +def create_dag_from_session(session: Session) -> DiGraph[str, DagNode]: """Create a DAG from a session.""" dag = _create_dag_from_tasks(tasks=session.tasks) _check_if_dag_has_cycles(dag) @@ -62,14 +63,20 @@ def create_dag_from_session(session: Session) -> DiGraph: return dag -def _create_dag_from_tasks(tasks: list[PTask]) -> DiGraph: +def _create_dag_from_tasks(tasks: list[PTask]) -> DiGraph[str, DagNode]: """Create the DAG from tasks, dependencies and products.""" + def _add_node_data( + dag: DiGraph[str, DagNode], node: PNode | PProvisionalNode + ) -> None: + dag.add_node(node.signature, DagNode.from_node(node)) + if isinstance(node, PythonNode) and isinstance(node.value, PythonNode): + _add_node_data(dag, node.value) + def _add_dependency( - dag: DiGraph, task: PTask, node: PNode | PProvisionalNode + dag: DiGraph[str, DagNode], task: PTask, node: PNode | PProvisionalNode ) -> None: """Add a dependency to the DAG.""" - dag.add_node(node.signature, node=node) dag.add_edge(node.signature, task.signature) # If a node is a PythonNode wrapped in another PythonNode, it is a product from @@ -78,34 +85,26 @@ def _add_dependency( if isinstance(node, PythonNode) and isinstance(node.value, PythonNode): dag.add_edge(node.value.signature, node.signature) - def _add_product(dag: DiGraph, task: PTask, node: PNode | PProvisionalNode) -> None: + def _add_product( + dag: DiGraph[str, DagNode], task: PTask, node: PNode | PProvisionalNode + ) -> None: """Add a product to the DAG.""" - dag.add_node(node.signature, node=node) dag.add_edge(task.signature, node.signature) - dag = DiGraph() + dag = DiGraph[str, DagNode]() for task in tasks: - dag.add_node(task.signature, task=task) + dag.add_node(task.signature, DagNode.from_task(task)) + tree_map(lambda x: _add_node_data(dag, x), task.depends_on) + tree_map(lambda x: _add_node_data(dag, x), task.produces) + for task in tasks: tree_map(lambda x: _add_dependency(dag, task, x), task.depends_on) tree_map(lambda x: _add_product(dag, task, x), task.produces) - - # If a node is a PythonNode wrapped in another PythonNode, it is a product from - # another task that is a dependency in the current task. Thus, draw an edge - # connecting the two nodes. - tree_map( - lambda x: ( - dag.add_edge(x.value.signature, x.signature) - if isinstance(x, PythonNode) and isinstance(x.value, PythonNode) - else None - ), - task.depends_on, - ) return dag -def _modify_dag(session: Session, dag: DiGraph) -> DiGraph: +def _modify_dag(session: Session, dag: DiGraph[str, DagNode]) -> DiGraph[str, DagNode]: """Create dependencies between tasks when using ``@task(after=...)``.""" temporary_id_to_task = { task.attributes["collection_id"]: task @@ -129,7 +128,7 @@ def _modify_dag(session: Session, dag: DiGraph) -> DiGraph: return dag -def _check_if_dag_has_cycles(dag: DiGraph) -> None: +def _check_if_dag_has_cycles(dag: DiGraph[str, DagNode]) -> None: """Check if DAG has cycles.""" try: cycles = find_cycle(dag) @@ -145,7 +144,7 @@ def _check_if_dag_has_cycles(dag: DiGraph) -> None: raise ResolvingDependenciesError(msg) -def _format_cycles(dag: DiGraph, cycles: list[tuple[str, str]]) -> str: +def _format_cycles(dag: DiGraph[str, DagNode], cycles: list[tuple[str, str]]) -> str: """Format cycles as a paths connected by arrows.""" chain = [ x for i, x in enumerate(itertools.chain.from_iterable(cycles)) if i % 2 == 0 @@ -154,7 +153,7 @@ def _format_cycles(dag: DiGraph, cycles: list[tuple[str, str]]) -> str: lines: list[str] = [] for x in chain: - node = dag.nodes[x].get("task") or dag.nodes[x].get("node") + node = dag.nodes[x].value if isinstance(node, PTask): short_name = format_task_name(node, editor_url_scheme="no_link").plain elif isinstance(node, (PNode, PProvisionalNode)): @@ -176,12 +175,13 @@ def _format_dictionary_to_tree(dict_: dict[str, list[str]], title: str) -> str: return render_to_string(tree, console=console, strip_styles=True) -def _check_if_tasks_have_the_same_products(dag: DiGraph, paths: list[Path]) -> None: +def _check_if_tasks_have_the_same_products( + dag: DiGraph[str, DagNode], paths: list[Path] +) -> None: nodes_created_by_multiple_tasks = [] for node in dag.nodes: - is_node = "node" in dag.nodes[node] - if is_node: + if dag.nodes[node].node is not None: parents = list(dag.predecessors(node)) if len(parents) > 1: nodes_created_by_multiple_tasks.append(node) @@ -189,11 +189,13 @@ def _check_if_tasks_have_the_same_products(dag: DiGraph, paths: list[Path]) -> N if nodes_created_by_multiple_tasks: dictionary = {} for node in nodes_created_by_multiple_tasks: - short_node_name = format_node_name(dag.nodes[node]["node"], paths).plain + short_node_name = format_node_name( + dag.nodes[node].node_or_raise(), paths + ).plain short_predecessors = reduce_names_of_multiple_nodes( dag.predecessors(node), dag, paths ) - dictionary[short_node_name] = short_predecessors + dictionary[short_node_name] = sorted(short_predecessors) text = _format_dictionary_to_tree(dictionary, "Products from multiple tasks:") msg = ( f"There are some tasks which produce the same output. See the following " diff --git a/src/_pytask/dag_command.py b/src/_pytask/dag_command.py index d4369d2c..f1f152c3 100644 --- a/src/_pytask/dag_command.py +++ b/src/_pytask/dag_command.py @@ -33,6 +33,7 @@ if TYPE_CHECKING: import networkx as nx + from _pytask.dag_graph import DagNode from _pytask.dag_graph import DiGraph @@ -182,7 +183,7 @@ def build_dag(raw_config: dict[str, Any]) -> nx.DiGraph: return _to_visualization_graph(session) -def _refine_dag(session: Session) -> DiGraph: +def _refine_dag(session: Session) -> DiGraph[str, DagNode]: """Refine the dag for plotting.""" dag = _shorten_node_labels(session.dag, session.config["paths"]) return _clean_dag(dag) @@ -198,7 +199,9 @@ def _to_visualization_graph(session: Session) -> nx.DiGraph: return dag -def _shorten_node_labels(dag: DiGraph, paths: list[Path]) -> DiGraph: +def _shorten_node_labels( + dag: DiGraph[str, DagNode], paths: list[Path] +) -> DiGraph[str, DagNode]: """Shorten the node labels in the graph for a better experience.""" node_names = dag.nodes short_names = reduce_names_of_multiple_nodes(node_names, dag, paths) @@ -207,10 +210,8 @@ def _shorten_node_labels(dag: DiGraph, paths: list[Path]) -> DiGraph: return dag.relabel_nodes(old_to_new) -def _clean_dag(dag: DiGraph) -> DiGraph: +def _clean_dag(dag: DiGraph[str, DagNode]) -> DiGraph[str, DagNode]: """Clean the DAG.""" - for node in dag.nodes: - dag.nodes[node].clear() return dag diff --git a/src/_pytask/dag_graph.py b/src/_pytask/dag_graph.py index 56bb7e0c..7b2d4eab 100644 --- a/src/_pytask/dag_graph.py +++ b/src/_pytask/dag_graph.py @@ -3,10 +3,13 @@ from __future__ import annotations import itertools +from collections.abc import Hashable from dataclasses import dataclass from dataclasses import field from typing import TYPE_CHECKING from typing import Any +from typing import Generic +from typing import TypeVar from typing import cast from _pytask.compat import import_optional_dependency @@ -15,81 +18,136 @@ from collections.abc import Callable from collections.abc import Iterable from collections.abc import Iterator + from collections.abc import Mapping + + from _pytask.node_protocols import PNode + from _pytask.node_protocols import PProvisionalNode + from _pytask.node_protocols import PTask + + +NodeIdT = TypeVar("NodeIdT", bound=Hashable) +PayloadT = TypeVar("PayloadT") class NoCycleError(Exception): """Raised when no cycle is found in a graph.""" +@dataclass(slots=True) +class DagNode: + """Payload stored for nodes in pytask's internal DAG.""" + + task: PTask | None = None + node: PNode | PProvisionalNode | None = None + + def __post_init__(self) -> None: + if (self.task is None) == (self.node is None): + msg = "A DAG node must store exactly one of 'task' or 'node'." + raise ValueError(msg) + + @classmethod + def from_task(cls, task: PTask) -> DagNode: + """Create a DAG node from a task.""" + return cls(task=task) + + @classmethod + def from_node(cls, node: PNode | PProvisionalNode) -> DagNode: + """Create a DAG node from a dependency or product node.""" + return cls(node=node) + + @property + def value(self) -> PTask | PNode | PProvisionalNode: + """Return the wrapped task or node.""" + if self.task is not None: + return self.task + return cast("PNode | PProvisionalNode", self.node) + + def task_or_raise(self) -> PTask: + """Return the wrapped task.""" + if self.task is None: + msg = "Expected DAG payload to contain a task." + raise TypeError(msg) + return self.task + + def node_or_raise(self) -> PNode | PProvisionalNode: + """Return the wrapped dependency or product node.""" + if self.node is None: + msg = "Expected DAG payload to contain a node." + raise TypeError(msg) + return self.node + + @dataclass -class DiGraph: +class DiGraph(Generic[NodeIdT, PayloadT]): """A minimal directed graph tailored to pytask's needs.""" - _node_attributes: dict[str, dict[str, Any]] = field(default_factory=dict) - _successors: dict[str, dict[str, None]] = field(default_factory=dict) - _predecessors: dict[str, dict[str, None]] = field(default_factory=dict) + _node_data: dict[NodeIdT, PayloadT] = field(default_factory=dict) + _successors: dict[NodeIdT, set[NodeIdT]] = field(default_factory=dict) + _predecessors: dict[NodeIdT, set[NodeIdT]] = field(default_factory=dict) @property - def nodes(self) -> dict[str, dict[str, Any]]: - return self._node_attributes - - def add_node(self, node_name: str, **attributes: Any) -> None: - if node_name not in self._node_attributes: - self._node_attributes[node_name] = {} - self._successors[node_name] = {} - self._predecessors[node_name] = {} - self._node_attributes[node_name].update(attributes) - - def add_edge(self, source: str, target: str) -> None: - self.add_node(source) - self.add_node(target) - self._successors[source][target] = None - self._predecessors[target][source] = None - - def successors(self, node: str) -> Iterator[str]: + def nodes(self) -> dict[NodeIdT, PayloadT]: + return self._node_data + + def add_node(self, node_name: NodeIdT, data: PayloadT) -> None: + if node_name not in self._node_data: + self._successors[node_name] = set() + self._predecessors[node_name] = set() + self._node_data[node_name] = data + + def add_edge(self, source: NodeIdT, target: NodeIdT) -> None: + if source not in self._node_data or target not in self._node_data: + msg = "Both nodes must exist before adding an edge." + raise KeyError(msg) + self._successors[source].add(target) + self._predecessors[target].add(source) + + def successors(self, node: NodeIdT) -> Iterator[NodeIdT]: return iter(self._successors[node]) - def predecessors(self, node: str) -> Iterator[str]: + def predecessors(self, node: NodeIdT) -> Iterator[NodeIdT]: return iter(self._predecessors[node]) - def in_degree(self) -> Iterator[tuple[str, int]]: + def in_degree(self) -> Iterator[tuple[NodeIdT, int]]: for node, predecessors_ in self._predecessors.items(): yield node, len(predecessors_) - def remove_nodes_from(self, nodes: list[str] | set[str] | tuple[str, ...]) -> None: + def remove_nodes_from(self, nodes: Iterable[NodeIdT]) -> None: for node in nodes: - if node not in self._node_attributes: + if node not in self._node_data: continue for predecessor in tuple(self._predecessors[node]): - self._successors[predecessor].pop(node, None) + self._successors[predecessor].discard(node) for successor in tuple(self._successors[node]): - self._predecessors[successor].pop(node, None) - del self._node_attributes[node] + self._predecessors[successor].discard(node) + del self._node_data[node] del self._successors[node] del self._predecessors[node] def is_directed(self) -> bool: return True - def reverse(self) -> DiGraph: - graph = DiGraph() - for node, attributes in self._node_attributes.items(): - graph.add_node(node, **attributes.copy()) + def reverse(self) -> DiGraph[NodeIdT, PayloadT]: + graph = DiGraph[NodeIdT, PayloadT]() + for node, data in self._node_data.items(): + graph.add_node(node, data) for source, successors in self._successors.items(): for target in successors: graph.add_edge(target, source) return graph - def relabel_nodes(self, mapping: dict[str, str]) -> DiGraph: - graph = DiGraph() + def relabel_nodes( + self, mapping: Mapping[NodeIdT, NodeIdT] + ) -> DiGraph[NodeIdT, PayloadT]: + graph = DiGraph[NodeIdT, PayloadT]() - new_labels = [mapping.get(node, node) for node in self._node_attributes] + new_labels = [mapping.get(node, node) for node in self._node_data] if len(new_labels) != len(set(new_labels)): msg = "Relabeling nodes requires unique target labels." raise ValueError(msg) - for node, attributes in self._node_attributes.items(): - graph.add_node(mapping.get(node, node), **attributes.copy()) + for node, data in self._node_data.items(): + graph.add_node(mapping.get(node, node), data) for source, successors in self._successors.items(): new_source = mapping.get(source, source) for target in successors: @@ -99,30 +157,30 @@ def relabel_nodes(self, mapping: dict[str, str]) -> DiGraph: def to_networkx(self) -> Any: nx = cast("Any", import_optional_dependency("networkx")) graph = nx.DiGraph() - for node, attributes in self._node_attributes.items(): - graph.add_node(node, **attributes.copy()) + for node in self._node_data: + graph.add_node(node) for source, successors in self._successors.items(): for target in successors: graph.add_edge(source, target) return graph -def descendants(dag: DiGraph, node: str) -> set[str]: +def descendants(dag: DiGraph[NodeIdT, PayloadT], node: NodeIdT) -> set[NodeIdT]: """Return all descendants of a node.""" return _traverse(dag, node, dag.successors) -def ancestors(dag: DiGraph, node: str) -> set[str]: +def ancestors(dag: DiGraph[NodeIdT, PayloadT], node: NodeIdT) -> set[NodeIdT]: """Return all ancestors of a node.""" return _traverse(dag, node, dag.predecessors) def _traverse( - _dag: DiGraph, - node: str, - adjacency: Callable[[str], Iterable[str]], -) -> set[str]: - visited: set[str] = set() + _dag: DiGraph[NodeIdT, PayloadT], + node: NodeIdT, + adjacency: Callable[[NodeIdT], Iterable[NodeIdT]], +) -> set[NodeIdT]: + visited: set[NodeIdT] = set() stack = list(adjacency(node)) while stack: @@ -135,13 +193,15 @@ def _traverse( return visited -def find_cycle(dag: DiGraph) -> list[tuple[str, str]]: +def find_cycle( + dag: DiGraph[NodeIdT, PayloadT], +) -> list[tuple[NodeIdT, NodeIdT]]: """Find one cycle in the graph.""" - visited: set[str] = set() - active: set[str] = set() - path: list[str] = [] + visited: set[NodeIdT] = set() + active: set[NodeIdT] = set() + path: list[NodeIdT] = [] - def _visit(node: str) -> list[tuple[str, str]] | None: + def _visit(node: NodeIdT) -> list[tuple[NodeIdT, NodeIdT]] | None: visited.add(node) active.add(node) path.append(node) diff --git a/src/_pytask/dag_utils.py b/src/_pytask/dag_utils.py index 130db30e..f5f590d6 100644 --- a/src/_pytask/dag_utils.py +++ b/src/_pytask/dag_utils.py @@ -6,7 +6,9 @@ from dataclasses import dataclass from dataclasses import field from typing import TYPE_CHECKING +from typing import Any +from _pytask.dag_graph import DagNode from _pytask.dag_graph import DiGraph from _pytask.dag_graph import NoCycleError from _pytask.dag_graph import ancestors @@ -21,37 +23,41 @@ from _pytask.node_protocols import PTask -def descending_tasks(task_name: str, dag: DiGraph) -> Generator[str, None, None]: +def descending_tasks( + task_name: str, dag: DiGraph[str, DagNode] +) -> Generator[str, None, None]: """Yield only descending tasks.""" for descendant in descendants(dag, task_name): - if "task" in dag.nodes[descendant]: + if dag.nodes[descendant].task is not None: yield descendant def task_and_descending_tasks( - task_name: str, dag: DiGraph + task_name: str, dag: DiGraph[str, DagNode] ) -> Generator[str, None, None]: """Yield task and descending tasks.""" yield task_name yield from descending_tasks(task_name, dag) -def preceding_tasks(task_name: str, dag: DiGraph) -> Generator[str, None, None]: +def preceding_tasks( + task_name: str, dag: DiGraph[str, DagNode] +) -> Generator[str, None, None]: """Yield only preceding tasks.""" for ancestor in ancestors(dag, task_name): - if "task" in dag.nodes[ancestor]: + if dag.nodes[ancestor].task is not None: yield ancestor def task_and_preceding_tasks( - task_name: str, dag: DiGraph + task_name: str, dag: DiGraph[str, DagNode] ) -> Generator[str, None, None]: """Yield task and preceding tasks.""" yield task_name yield from preceding_tasks(task_name, dag) -def node_and_neighbors(dag: DiGraph, node: str) -> Iterable[str]: +def node_and_neighbors(dag: DiGraph[str, DagNode], node: str) -> Iterable[str]: """Yield node and neighbors which are first degree predecessors and successors. We cannot use ``dag.neighbors`` as it only considers successors as neighbors in a @@ -80,25 +86,23 @@ class TopologicalSorter: """ - dag: DiGraph + dag: DiGraph[str, None] priorities: dict[str, int] = field(default_factory=dict) _nodes_processing: set[str] = field(default_factory=set) _nodes_done: set[str] = field(default_factory=set) @classmethod - def from_dag(cls, dag: DiGraph) -> TopologicalSorter: + def from_dag(cls, dag: DiGraph[str, DagNode]) -> TopologicalSorter: """Instantiate from a DAG.""" cls.check_dag(dag) - tasks = [ - dag.nodes[node]["task"] for node in dag.nodes if "task" in dag.nodes[node] - ] + tasks = [node.task for node in dag.nodes.values() if node.task is not None] priorities = _extract_priorities_from_tasks(tasks) task_signatures = {task.signature for task in tasks} - task_dag = DiGraph() + task_dag = DiGraph[str, None]() for signature in task_signatures: - task_dag.add_node(signature) + task_dag.add_node(signature, None) for signature in task_signatures: # The scheduler graph uses edges from predecessor -> successor so that # zero in-degree means "ready to run". This is the same orientation the @@ -110,7 +114,7 @@ def from_dag(cls, dag: DiGraph) -> TopologicalSorter: @classmethod def from_dag_and_sorter( - cls, dag: DiGraph, sorter: TopologicalSorter + cls, dag: DiGraph[str, DagNode], sorter: TopologicalSorter ) -> TopologicalSorter: """Instantiate a sorter from another sorter and a DAG.""" new_sorter = cls.from_dag(dag) @@ -119,7 +123,7 @@ def from_dag_and_sorter( return new_sorter @staticmethod - def check_dag(dag: DiGraph) -> None: + def check_dag(dag: DiGraph[str, Any]) -> None: if not dag.is_directed(): msg = "Only directed graphs have a topological order." raise ValueError(msg) diff --git a/src/_pytask/database_utils.py b/src/_pytask/database_utils.py index 4db66a59..41a93e5c 100644 --- a/src/_pytask/database_utils.py +++ b/src/_pytask/database_utils.py @@ -15,6 +15,7 @@ from sqlalchemy.orm import sessionmaker from _pytask.dag_utils import node_and_neighbors +from _pytask.node_protocols import PProvisionalNode if TYPE_CHECKING: from sqlalchemy.engine import Engine @@ -115,8 +116,12 @@ def update_states_in_database(session: Session, task_signature: str) -> None: if _ENGINE is None: return for name in node_and_neighbors(session.dag, task_signature): - node = session.dag.nodes[name].get("task") or session.dag.nodes[name]["node"] + node = session.dag.nodes[name].value + if isinstance(node, PProvisionalNode): + continue hash_ = node.state() + if hash_ is None: + continue _create_or_update_state(task_signature, node.signature, hash_) diff --git a/src/_pytask/execute.py b/src/_pytask/execute.py index 8f0e94a2..0a8b7552 100644 --- a/src/_pytask/execute.py +++ b/src/_pytask/execute.py @@ -7,7 +7,6 @@ import time from typing import TYPE_CHECKING from typing import Any -from typing import cast from rich.text import Text @@ -90,7 +89,7 @@ def pytask_execute_build(session: Session) -> bool | None: if isinstance(session.scheduler, TopologicalSorter): while session.scheduler.is_active(): task_name = session.scheduler.get_ready()[0] - task = session.dag.nodes[task_name]["task"] + task = session.dag.nodes[task_name].task_or_raise() report = session.hook.pytask_execute_task_protocol( session=session, task=task ) @@ -173,9 +172,7 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C if not needs_to_be_executed: predecessors = set(dag.predecessors(task.signature)) | {task.signature} for node_signature in node_and_neighbors(dag, task.signature): - node = dag.nodes[node_signature].get("task") or dag.nodes[ - node_signature - ].get("node") + node = dag.nodes[node_signature].value # Skip provisional nodes that are products since they do not have a state. if node_signature not in predecessors and isinstance( @@ -191,8 +188,6 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C ) raise ExecutionError(msg) - node = cast("PTask | PNode", node) - node_state = node.state() if node_signature in predecessors and not node_state: @@ -258,7 +253,7 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C # Create directory for product if it does not exist. Maybe this should be a `setup` # method for the node classes. for product in dag.successors(task.signature): - node = dag.nodes[product]["node"] + node = dag.nodes[product].node_or_raise() if isinstance(node, PPathNode): node.path.parent.mkdir(parents=True, exist_ok=True) if isinstance(node, DirectoryNode) and node.root_dir: @@ -352,7 +347,7 @@ def pytask_execute_task_process_report( report.outcome = TaskOutcome.WOULD_BE_EXECUTED for descending_task_name in descending_tasks(task.signature, session.dag): - descending_task = session.dag.nodes[descending_task_name]["task"] + descending_task = session.dag.nodes[descending_task_name].task_or_raise() descending_task.markers.append( Mark( "would_be_executed", @@ -365,7 +360,7 @@ def pytask_execute_task_process_report( ) else: for descending_task_name in descending_tasks(task.signature, session.dag): - descending_task = session.dag.nodes[descending_task_name]["task"] + descending_task = session.dag.nodes[descending_task_name].task_or_raise() descending_task.markers.append( Mark( "skip_ancestor_failed", diff --git a/src/_pytask/lockfile.py b/src/_pytask/lockfile.py index ba18ccbd..0c6bb73c 100644 --- a/src/_pytask/lockfile.py +++ b/src/_pytask/lockfile.py @@ -225,9 +225,7 @@ def _build_task_entry(session: Session, task: PTask, root: Path) -> _TaskEntry | depends_on: dict[str, str] = {} for node_signature in predecessors: - node = ( - dag.nodes[node_signature].get("task") or dag.nodes[node_signature]["node"] - ) + node = dag.nodes[node_signature].value if not isinstance(node, (PNode, PTask)): continue state = node.state() @@ -242,9 +240,7 @@ def _build_task_entry(session: Session, task: PTask, root: Path) -> _TaskEntry | produces: dict[str, str] = {} for node_signature in successors: - node = ( - dag.nodes[node_signature].get("task") or dag.nodes[node_signature]["node"] - ) + node = dag.nodes[node_signature].value if not isinstance(node, (PNode, PTask)): continue state = node.state() diff --git a/src/_pytask/mark/__init__.py b/src/_pytask/mark/__init__.py index 2f6f0727..9a724ce3 100644 --- a/src/_pytask/mark/__init__.py +++ b/src/_pytask/mark/__init__.py @@ -30,6 +30,7 @@ from collections.abc import Set as AbstractSet from typing import NoReturn + from _pytask.dag_graph import DagNode from _pytask.dag_graph import DiGraph from _pytask.node_protocols import PTask @@ -152,7 +153,7 @@ def __call__(self, subname: str) -> bool: return any(subname in name for name in names) -def select_by_keyword(session: Session, dag: DiGraph) -> set[str] | None: +def select_by_keyword(session: Session, dag: DiGraph[str, DagNode]) -> set[str] | None: """Deselect tests by keywords.""" keywordexpr = session.config["expression"] if not keywordexpr: @@ -207,7 +208,7 @@ def __call__(self, name: str) -> bool: return name in self.own_mark_names -def select_by_mark(session: Session, dag: DiGraph) -> set[str] | None: +def select_by_mark(session: Session, dag: DiGraph[str, DagNode]) -> set[str] | None: """Deselect tests by marks.""" matchexpr = session.config["marker_expression"] if not matchexpr: @@ -236,7 +237,9 @@ def _deselect_others_with_mark( task.markers.append(mark) -def select_tasks_by_marks_and_expressions(session: Session, dag: DiGraph) -> None: +def select_tasks_by_marks_and_expressions( + session: Session, dag: DiGraph[str, DagNode] +) -> None: """Modify the tasks which are executed with expressions and markers.""" remaining = select_by_keyword(session, dag) if remaining is not None: diff --git a/src/_pytask/persist.py b/src/_pytask/persist.py index 40d958b3..6b72ffff 100644 --- a/src/_pytask/persist.py +++ b/src/_pytask/persist.py @@ -8,6 +8,7 @@ from _pytask.dag_utils import node_and_neighbors from _pytask.database_utils import update_states_in_database as _db_update_states from _pytask.mark_utils import has_mark +from _pytask.node_protocols import PProvisionalNode from _pytask.outcomes import Persisted from _pytask.outcomes import TaskOutcome from _pytask.pluginmanager import hookimpl @@ -16,6 +17,7 @@ from _pytask.state import update_states if TYPE_CHECKING: + from _pytask.node_protocols import PNode from _pytask.node_protocols import PTask from _pytask.reports import ExecutionReport from _pytask.session import Session @@ -47,12 +49,14 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: """ if has_mark(task, "persist"): - all_states = [ - ( - session.dag.nodes[name].get("task") or session.dag.nodes[name]["node"] - ).state() - for name in node_and_neighbors(session.dag, task.signature) - ] + stateful_nodes: list[tuple[PTask | PNode, str | None]] = [] + for name in node_and_neighbors(session.dag, task.signature): + node = session.dag.nodes[name].value + if isinstance(node, PProvisionalNode): + continue + stateful_nodes.append((node, node.state())) + + all_states = [state for _, state in stateful_nodes] all_nodes_exist = all(all_states) if all_nodes_exist: @@ -60,15 +64,10 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: has_node_changed( session=session, task=task, - node=session.dag.nodes[name].get("task") - or session.dag.nodes[name]["node"], + node=node, state=state, ) - for name, state in zip( - node_and_neighbors(session.dag, task.signature), - all_states, - strict=False, - ) + for node, state in stateful_nodes ) if any_node_changed: collect_provisional_products(session, task) diff --git a/src/_pytask/profile.py b/src/_pytask/profile.py index 74b50e15..12eba84e 100644 --- a/src/_pytask/profile.py +++ b/src/_pytask/profile.py @@ -210,7 +210,7 @@ def pytask_profile_add_info_on_task( if successors: sum_bytes = 0 for successor in successors: - node = session.dag.nodes[successor]["node"] + node = session.dag.nodes[successor].node_or_raise() if isinstance(node, PPathNode): with suppress(FileNotFoundError): sum_bytes += node.path.stat().st_size diff --git a/src/_pytask/session.py b/src/_pytask/session.py index 09008742..8f1575ca 100644 --- a/src/_pytask/session.py +++ b/src/_pytask/session.py @@ -9,6 +9,7 @@ from pluggy import HookRelay +from _pytask.dag_graph import DagNode from _pytask.dag_graph import DiGraph from _pytask.outcomes import ExitCode @@ -51,7 +52,7 @@ class Session: config: dict[str, Any] = field(default_factory=dict) collection_reports: list[CollectionReport] = field(default_factory=list) - dag: DiGraph = field(default_factory=DiGraph) + dag: DiGraph[str, DagNode] = field(default_factory=DiGraph) hook: HookRelay = field(default_factory=HookRelay) tasks: list[PTask] = field(default_factory=list) dag_report: DagReport | None = None diff --git a/src/_pytask/shared.py b/src/_pytask/shared.py index ccd11988..0a3aefee 100644 --- a/src/_pytask/shared.py +++ b/src/_pytask/shared.py @@ -23,6 +23,7 @@ if TYPE_CHECKING: from enum import Enum + from _pytask.dag_graph import DagNode from _pytask.dag_graph import DiGraph @@ -79,12 +80,12 @@ def parse_paths(x: Path | list[Path]) -> list[Path]: def reduce_names_of_multiple_nodes( - names: Iterable[str], dag: DiGraph, paths: Sequence[Path] + names: Iterable[str], dag: DiGraph[str, DagNode], paths: Sequence[Path] ) -> list[str]: """Reduce the names of multiple nodes in the DAG.""" short_names = [] for name in names: - node = dag.nodes[name].get("node") or dag.nodes[name].get("task") + node = dag.nodes[name].value if isinstance(node, PTask): short_name = format_task_name(node, editor_url_scheme="no_link").plain diff --git a/src/_pytask/skipping.py b/src/_pytask/skipping.py index a7678154..71c7fcf0 100644 --- a/src/_pytask/skipping.py +++ b/src/_pytask/skipping.py @@ -97,7 +97,9 @@ def pytask_execute_task_process_report( report.outcome = TaskOutcome.SKIP for descending_task_name in descending_tasks(task.signature, session.dag): - descending_task = session.dag.nodes[descending_task_name]["task"] + descending_task = session.dag.nodes[ + descending_task_name + ].task_or_raise() descending_task.markers.append( Mark( "skip", diff --git a/tests/test_dag_utils.py b/tests/test_dag_utils.py index 5f02aeda..ae37d9ee 100644 --- a/tests/test_dag_utils.py +++ b/tests/test_dag_utils.py @@ -6,6 +6,7 @@ import pytest +from _pytask.dag_graph import DagNode from _pytask.dag_graph import DiGraph from _pytask.dag_utils import TopologicalSorter from _pytask.dag_utils import _extract_priorities_from_tasks @@ -20,12 +21,12 @@ @pytest.fixture def dag(): """Create a dag with five nodes in a line.""" - dag = DiGraph() + dag = DiGraph[str, DagNode]() for i in range(4): task = Task(base_name=str(i), path=Path(), function=noop) next_task = Task(base_name=str(i + 1), path=Path(), function=noop) - dag.add_node(task.signature, task=task) - dag.add_node(next_task.signature, task=next_task) + dag.add_node(task.signature, DagNode.from_task(task)) + dag.add_node(next_task.signature, DagNode.from_task(next_task)) dag.add_edge(task.signature, next_task.signature) return dag @@ -38,43 +39,47 @@ def test_sort_tasks_topologically(dag): task_name = sorter.get_ready()[0] topo_ordering.append(task_name) sorter.done(task_name) - topo_names = [dag.nodes[sig]["task"].name for sig in topo_ordering] + topo_names = [dag.nodes[sig].task_or_raise().name for sig in topo_ordering] assert topo_names == [f".::{i}" for i in range(5)] def test_descending_tasks(dag): for i in range(5): task = next( - dag.nodes[sig]["task"] + dag.nodes[sig].task_or_raise() for sig in dag.nodes - if dag.nodes[sig]["task"].name == f".::{i}" + if dag.nodes[sig].task_or_raise().name == f".::{i}" ) descendants = descending_tasks(task.signature, dag) - descendant_names = sorted(dag.nodes[sig]["task"].name for sig in descendants) + descendant_names = sorted( + dag.nodes[sig].task_or_raise().name for sig in descendants + ) assert descendant_names == [f".::{i}" for i in range(i + 1, 5)] def test_task_and_descending_tasks(dag): for i in range(5): task = next( - dag.nodes[sig]["task"] + dag.nodes[sig].task_or_raise() for sig in dag.nodes - if dag.nodes[sig]["task"].name == f".::{i}" + if dag.nodes[sig].task_or_raise().name == f".::{i}" ) descendants = task_and_descending_tasks(task.signature, dag) - descendant_names = sorted(dag.nodes[sig]["task"].name for sig in descendants) + descendant_names = sorted( + dag.nodes[sig].task_or_raise().name for sig in descendants + ) assert descendant_names == [f".::{i}" for i in range(i, 5)] def test_node_and_neighbors(dag): for i in range(1, 4): task = next( - dag.nodes[sig]["task"] + dag.nodes[sig].task_or_raise() for sig in dag.nodes - if dag.nodes[sig]["task"].name == f".::{i}" + if dag.nodes[sig].task_or_raise().name == f".::{i}" ) nodes = node_and_neighbors(dag, task.signature) - node_names = [dag.nodes[sig]["task"].name for sig in nodes] + node_names = [dag.nodes[sig].task_or_raise().name for sig in nodes] assert node_names == [f".::{j}" for j in range(i - 1, i + 2)] @@ -171,7 +176,7 @@ def test_ask_for_invalid_number_of_ready_tasks(dag): def test_instantiate_sorter_from_other_sorter(dag): - name_to_sig = {dag.nodes[sig]["task"].name: sig for sig in dag.nodes} + name_to_sig = {dag.nodes[sig].task_or_raise().name: sig for sig in dag.nodes} scheduler = TopologicalSorter.from_dag(dag) for _ in range(2): @@ -180,7 +185,7 @@ def test_instantiate_sorter_from_other_sorter(dag): assert scheduler._nodes_done == {name_to_sig[name] for name in (".::0", ".::1")} task = Task(base_name="5", path=Path(), function=noop) - dag.add_node(task.signature, task=Task(base_name="5", path=Path(), function=noop)) + dag.add_node(task.signature, DagNode.from_task(task)) dag.add_edge(name_to_sig[".::4"], task.signature) new_scheduler = TopologicalSorter.from_dag_and_sorter(dag, scheduler) From 1695d968d9da90d36356c4ecfdb51508fb1bf309 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Fri, 27 Mar 2026 22:57:26 +0100 Subject: [PATCH 07/13] Rename internal graph types to DAG --- src/_pytask/dag.py | 38 ++++++++----------- src/_pytask/dag_command.py | 11 ++---- src/_pytask/dag_graph.py | 71 ++++++++++++++++-------------------- src/_pytask/dag_utils.py | 34 ++++++----------- src/_pytask/mark/__init__.py | 11 ++---- src/_pytask/session.py | 5 +-- src/_pytask/shared.py | 5 +-- tests/test_dag_utils.py | 12 +++--- 8 files changed, 76 insertions(+), 111 deletions(-) diff --git a/src/_pytask/dag.py b/src/_pytask/dag.py index d9e281fb..eaf05af5 100644 --- a/src/_pytask/dag.py +++ b/src/_pytask/dag.py @@ -16,8 +16,8 @@ from _pytask.console import format_node_name from _pytask.console import format_task_name from _pytask.console import render_to_string -from _pytask.dag_graph import DagNode -from _pytask.dag_graph import DiGraph +from _pytask.dag_graph import DAG +from _pytask.dag_graph import DAGNode from _pytask.dag_graph import NoCycleError from _pytask.dag_graph import find_cycle from _pytask.exceptions import ResolvingDependenciesError @@ -40,7 +40,7 @@ __all__ = ["create_dag", "create_dag_from_session"] -def create_dag(session: Session) -> DiGraph[str, DagNode]: +def create_dag(session: Session) -> DAG: """Create a directed acyclic graph (DAG) for the workflow.""" try: dag = create_dag_from_session(session) @@ -53,7 +53,7 @@ def create_dag(session: Session) -> DiGraph[str, DagNode]: return dag -def create_dag_from_session(session: Session) -> DiGraph[str, DagNode]: +def create_dag_from_session(session: Session) -> DAG: """Create a DAG from a session.""" dag = _create_dag_from_tasks(tasks=session.tasks) _check_if_dag_has_cycles(dag) @@ -63,19 +63,15 @@ def create_dag_from_session(session: Session) -> DiGraph[str, DagNode]: return dag -def _create_dag_from_tasks(tasks: list[PTask]) -> DiGraph[str, DagNode]: +def _create_dag_from_tasks(tasks: list[PTask]) -> DAG: """Create the DAG from tasks, dependencies and products.""" - def _add_node_data( - dag: DiGraph[str, DagNode], node: PNode | PProvisionalNode - ) -> None: - dag.add_node(node.signature, DagNode.from_node(node)) + def _add_node_data(dag: DAG, node: PNode | PProvisionalNode) -> None: + dag.add_node(node.signature, DAGNode.from_node(node)) if isinstance(node, PythonNode) and isinstance(node.value, PythonNode): _add_node_data(dag, node.value) - def _add_dependency( - dag: DiGraph[str, DagNode], task: PTask, node: PNode | PProvisionalNode - ) -> None: + def _add_dependency(dag: DAG, task: PTask, node: PNode | PProvisionalNode) -> None: """Add a dependency to the DAG.""" dag.add_edge(node.signature, task.signature) @@ -85,16 +81,14 @@ def _add_dependency( if isinstance(node, PythonNode) and isinstance(node.value, PythonNode): dag.add_edge(node.value.signature, node.signature) - def _add_product( - dag: DiGraph[str, DagNode], task: PTask, node: PNode | PProvisionalNode - ) -> None: + def _add_product(dag: DAG, task: PTask, node: PNode | PProvisionalNode) -> None: """Add a product to the DAG.""" dag.add_edge(task.signature, node.signature) - dag = DiGraph[str, DagNode]() + dag = DAG() for task in tasks: - dag.add_node(task.signature, DagNode.from_task(task)) + dag.add_node(task.signature, DAGNode.from_task(task)) tree_map(lambda x: _add_node_data(dag, x), task.depends_on) tree_map(lambda x: _add_node_data(dag, x), task.produces) @@ -104,7 +98,7 @@ def _add_product( return dag -def _modify_dag(session: Session, dag: DiGraph[str, DagNode]) -> DiGraph[str, DagNode]: +def _modify_dag(session: Session, dag: DAG) -> DAG: """Create dependencies between tasks when using ``@task(after=...)``.""" temporary_id_to_task = { task.attributes["collection_id"]: task @@ -128,7 +122,7 @@ def _modify_dag(session: Session, dag: DiGraph[str, DagNode]) -> DiGraph[str, Da return dag -def _check_if_dag_has_cycles(dag: DiGraph[str, DagNode]) -> None: +def _check_if_dag_has_cycles(dag: DAG) -> None: """Check if DAG has cycles.""" try: cycles = find_cycle(dag) @@ -144,7 +138,7 @@ def _check_if_dag_has_cycles(dag: DiGraph[str, DagNode]) -> None: raise ResolvingDependenciesError(msg) -def _format_cycles(dag: DiGraph[str, DagNode], cycles: list[tuple[str, str]]) -> str: +def _format_cycles(dag: DAG, cycles: list[tuple[str, str]]) -> str: """Format cycles as a paths connected by arrows.""" chain = [ x for i, x in enumerate(itertools.chain.from_iterable(cycles)) if i % 2 == 0 @@ -175,9 +169,7 @@ def _format_dictionary_to_tree(dict_: dict[str, list[str]], title: str) -> str: return render_to_string(tree, console=console, strip_styles=True) -def _check_if_tasks_have_the_same_products( - dag: DiGraph[str, DagNode], paths: list[Path] -) -> None: +def _check_if_tasks_have_the_same_products(dag: DAG, paths: list[Path]) -> None: nodes_created_by_multiple_tasks = [] for node in dag.nodes: diff --git a/src/_pytask/dag_command.py b/src/_pytask/dag_command.py index f1f152c3..7ed12e32 100644 --- a/src/_pytask/dag_command.py +++ b/src/_pytask/dag_command.py @@ -33,8 +33,7 @@ if TYPE_CHECKING: import networkx as nx - from _pytask.dag_graph import DagNode - from _pytask.dag_graph import DiGraph + from _pytask.dag_graph import DAG class _RankDirection(enum.Enum): @@ -183,7 +182,7 @@ def build_dag(raw_config: dict[str, Any]) -> nx.DiGraph: return _to_visualization_graph(session) -def _refine_dag(session: Session) -> DiGraph[str, DagNode]: +def _refine_dag(session: Session) -> DAG: """Refine the dag for plotting.""" dag = _shorten_node_labels(session.dag, session.config["paths"]) return _clean_dag(dag) @@ -199,9 +198,7 @@ def _to_visualization_graph(session: Session) -> nx.DiGraph: return dag -def _shorten_node_labels( - dag: DiGraph[str, DagNode], paths: list[Path] -) -> DiGraph[str, DagNode]: +def _shorten_node_labels(dag: DAG, paths: list[Path]) -> DAG: """Shorten the node labels in the graph for a better experience.""" node_names = dag.nodes short_names = reduce_names_of_multiple_nodes(node_names, dag, paths) @@ -210,7 +207,7 @@ def _shorten_node_labels( return dag.relabel_nodes(old_to_new) -def _clean_dag(dag: DiGraph[str, DagNode]) -> DiGraph[str, DagNode]: +def _clean_dag(dag: DAG) -> DAG: """Clean the DAG.""" return dag diff --git a/src/_pytask/dag_graph.py b/src/_pytask/dag_graph.py index 7b2d4eab..d06bdbdd 100644 --- a/src/_pytask/dag_graph.py +++ b/src/_pytask/dag_graph.py @@ -3,13 +3,10 @@ from __future__ import annotations import itertools -from collections.abc import Hashable from dataclasses import dataclass from dataclasses import field from typing import TYPE_CHECKING from typing import Any -from typing import Generic -from typing import TypeVar from typing import cast from _pytask.compat import import_optional_dependency @@ -25,16 +22,12 @@ from _pytask.node_protocols import PTask -NodeIdT = TypeVar("NodeIdT", bound=Hashable) -PayloadT = TypeVar("PayloadT") - - class NoCycleError(Exception): """Raised when no cycle is found in a graph.""" @dataclass(slots=True) -class DagNode: +class DAGNode: """Payload stored for nodes in pytask's internal DAG.""" task: PTask | None = None @@ -46,12 +39,12 @@ def __post_init__(self) -> None: raise ValueError(msg) @classmethod - def from_task(cls, task: PTask) -> DagNode: + def from_task(cls, task: PTask) -> DAGNode: """Create a DAG node from a task.""" return cls(task=task) @classmethod - def from_node(cls, node: PNode | PProvisionalNode) -> DagNode: + def from_node(cls, node: PNode | PProvisionalNode) -> DAGNode: """Create a DAG node from a dependency or product node.""" return cls(node=node) @@ -78,41 +71,41 @@ def node_or_raise(self) -> PNode | PProvisionalNode: @dataclass -class DiGraph(Generic[NodeIdT, PayloadT]): +class DAG: """A minimal directed graph tailored to pytask's needs.""" - _node_data: dict[NodeIdT, PayloadT] = field(default_factory=dict) - _successors: dict[NodeIdT, set[NodeIdT]] = field(default_factory=dict) - _predecessors: dict[NodeIdT, set[NodeIdT]] = field(default_factory=dict) + _node_data: dict[str, DAGNode] = field(default_factory=dict) + _successors: dict[str, set[str]] = field(default_factory=dict) + _predecessors: dict[str, set[str]] = field(default_factory=dict) @property - def nodes(self) -> dict[NodeIdT, PayloadT]: + def nodes(self) -> dict[str, DAGNode]: return self._node_data - def add_node(self, node_name: NodeIdT, data: PayloadT) -> None: + def add_node(self, node_name: str, data: DAGNode) -> None: if node_name not in self._node_data: self._successors[node_name] = set() self._predecessors[node_name] = set() self._node_data[node_name] = data - def add_edge(self, source: NodeIdT, target: NodeIdT) -> None: + def add_edge(self, source: str, target: str) -> None: if source not in self._node_data or target not in self._node_data: msg = "Both nodes must exist before adding an edge." raise KeyError(msg) self._successors[source].add(target) self._predecessors[target].add(source) - def successors(self, node: NodeIdT) -> Iterator[NodeIdT]: + def successors(self, node: str) -> Iterator[str]: return iter(self._successors[node]) - def predecessors(self, node: NodeIdT) -> Iterator[NodeIdT]: + def predecessors(self, node: str) -> Iterator[str]: return iter(self._predecessors[node]) - def in_degree(self) -> Iterator[tuple[NodeIdT, int]]: + def in_degree(self) -> Iterator[tuple[str, int]]: for node, predecessors_ in self._predecessors.items(): yield node, len(predecessors_) - def remove_nodes_from(self, nodes: Iterable[NodeIdT]) -> None: + def remove_nodes_from(self, nodes: Iterable[str]) -> None: for node in nodes: if node not in self._node_data: continue @@ -127,8 +120,8 @@ def remove_nodes_from(self, nodes: Iterable[NodeIdT]) -> None: def is_directed(self) -> bool: return True - def reverse(self) -> DiGraph[NodeIdT, PayloadT]: - graph = DiGraph[NodeIdT, PayloadT]() + def reverse(self) -> DAG: + graph = DAG() for node, data in self._node_data.items(): graph.add_node(node, data) for source, successors in self._successors.items(): @@ -136,10 +129,8 @@ def reverse(self) -> DiGraph[NodeIdT, PayloadT]: graph.add_edge(target, source) return graph - def relabel_nodes( - self, mapping: Mapping[NodeIdT, NodeIdT] - ) -> DiGraph[NodeIdT, PayloadT]: - graph = DiGraph[NodeIdT, PayloadT]() + def relabel_nodes(self, mapping: Mapping[str, str]) -> DAG: + graph = DAG() new_labels = [mapping.get(node, node) for node in self._node_data] if len(new_labels) != len(set(new_labels)): @@ -165,22 +156,22 @@ def to_networkx(self) -> Any: return graph -def descendants(dag: DiGraph[NodeIdT, PayloadT], node: NodeIdT) -> set[NodeIdT]: +def descendants(dag: DAG, node: str) -> set[str]: """Return all descendants of a node.""" return _traverse(dag, node, dag.successors) -def ancestors(dag: DiGraph[NodeIdT, PayloadT], node: NodeIdT) -> set[NodeIdT]: +def ancestors(dag: DAG, node: str) -> set[str]: """Return all ancestors of a node.""" return _traverse(dag, node, dag.predecessors) def _traverse( - _dag: DiGraph[NodeIdT, PayloadT], - node: NodeIdT, - adjacency: Callable[[NodeIdT], Iterable[NodeIdT]], -) -> set[NodeIdT]: - visited: set[NodeIdT] = set() + _dag: DAG, + node: str, + adjacency: Callable[[str], Iterable[str]], +) -> set[str]: + visited: set[str] = set() stack = list(adjacency(node)) while stack: @@ -194,14 +185,14 @@ def _traverse( def find_cycle( - dag: DiGraph[NodeIdT, PayloadT], -) -> list[tuple[NodeIdT, NodeIdT]]: + dag: DAG, +) -> list[tuple[str, str]]: """Find one cycle in the graph.""" - visited: set[NodeIdT] = set() - active: set[NodeIdT] = set() - path: list[NodeIdT] = [] + visited: set[str] = set() + active: set[str] = set() + path: list[str] = [] - def _visit(node: NodeIdT) -> list[tuple[NodeIdT, NodeIdT]] | None: + def _visit(node: str) -> list[tuple[str, str]] | None: visited.add(node) active.add(node) path.append(node) diff --git a/src/_pytask/dag_utils.py b/src/_pytask/dag_utils.py index f5f590d6..240a43bd 100644 --- a/src/_pytask/dag_utils.py +++ b/src/_pytask/dag_utils.py @@ -6,10 +6,8 @@ from dataclasses import dataclass from dataclasses import field from typing import TYPE_CHECKING -from typing import Any -from _pytask.dag_graph import DagNode -from _pytask.dag_graph import DiGraph +from _pytask.dag_graph import DAG from _pytask.dag_graph import NoCycleError from _pytask.dag_graph import ancestors from _pytask.dag_graph import descendants @@ -23,41 +21,33 @@ from _pytask.node_protocols import PTask -def descending_tasks( - task_name: str, dag: DiGraph[str, DagNode] -) -> Generator[str, None, None]: +def descending_tasks(task_name: str, dag: DAG) -> Generator[str, None, None]: """Yield only descending tasks.""" for descendant in descendants(dag, task_name): if dag.nodes[descendant].task is not None: yield descendant -def task_and_descending_tasks( - task_name: str, dag: DiGraph[str, DagNode] -) -> Generator[str, None, None]: +def task_and_descending_tasks(task_name: str, dag: DAG) -> Generator[str, None, None]: """Yield task and descending tasks.""" yield task_name yield from descending_tasks(task_name, dag) -def preceding_tasks( - task_name: str, dag: DiGraph[str, DagNode] -) -> Generator[str, None, None]: +def preceding_tasks(task_name: str, dag: DAG) -> Generator[str, None, None]: """Yield only preceding tasks.""" for ancestor in ancestors(dag, task_name): if dag.nodes[ancestor].task is not None: yield ancestor -def task_and_preceding_tasks( - task_name: str, dag: DiGraph[str, DagNode] -) -> Generator[str, None, None]: +def task_and_preceding_tasks(task_name: str, dag: DAG) -> Generator[str, None, None]: """Yield task and preceding tasks.""" yield task_name yield from preceding_tasks(task_name, dag) -def node_and_neighbors(dag: DiGraph[str, DagNode], node: str) -> Iterable[str]: +def node_and_neighbors(dag: DAG, node: str) -> Iterable[str]: """Yield node and neighbors which are first degree predecessors and successors. We cannot use ``dag.neighbors`` as it only considers successors as neighbors in a @@ -86,13 +76,13 @@ class TopologicalSorter: """ - dag: DiGraph[str, None] + dag: DAG priorities: dict[str, int] = field(default_factory=dict) _nodes_processing: set[str] = field(default_factory=set) _nodes_done: set[str] = field(default_factory=set) @classmethod - def from_dag(cls, dag: DiGraph[str, DagNode]) -> TopologicalSorter: + def from_dag(cls, dag: DAG) -> TopologicalSorter: """Instantiate from a DAG.""" cls.check_dag(dag) @@ -100,9 +90,9 @@ def from_dag(cls, dag: DiGraph[str, DagNode]) -> TopologicalSorter: priorities = _extract_priorities_from_tasks(tasks) task_signatures = {task.signature for task in tasks} - task_dag = DiGraph[str, None]() + task_dag = DAG() for signature in task_signatures: - task_dag.add_node(signature, None) + task_dag.add_node(signature, dag.nodes[signature]) for signature in task_signatures: # The scheduler graph uses edges from predecessor -> successor so that # zero in-degree means "ready to run". This is the same orientation the @@ -114,7 +104,7 @@ def from_dag(cls, dag: DiGraph[str, DagNode]) -> TopologicalSorter: @classmethod def from_dag_and_sorter( - cls, dag: DiGraph[str, DagNode], sorter: TopologicalSorter + cls, dag: DAG, sorter: TopologicalSorter ) -> TopologicalSorter: """Instantiate a sorter from another sorter and a DAG.""" new_sorter = cls.from_dag(dag) @@ -123,7 +113,7 @@ def from_dag_and_sorter( return new_sorter @staticmethod - def check_dag(dag: DiGraph[str, Any]) -> None: + def check_dag(dag: DAG) -> None: if not dag.is_directed(): msg = "Only directed graphs have a topological order." raise ValueError(msg) diff --git a/src/_pytask/mark/__init__.py b/src/_pytask/mark/__init__.py index 9a724ce3..03e23dd9 100644 --- a/src/_pytask/mark/__init__.py +++ b/src/_pytask/mark/__init__.py @@ -30,8 +30,7 @@ from collections.abc import Set as AbstractSet from typing import NoReturn - from _pytask.dag_graph import DagNode - from _pytask.dag_graph import DiGraph + from _pytask.dag_graph import DAG from _pytask.node_protocols import PTask @@ -153,7 +152,7 @@ def __call__(self, subname: str) -> bool: return any(subname in name for name in names) -def select_by_keyword(session: Session, dag: DiGraph[str, DagNode]) -> set[str] | None: +def select_by_keyword(session: Session, dag: DAG) -> set[str] | None: """Deselect tests by keywords.""" keywordexpr = session.config["expression"] if not keywordexpr: @@ -208,7 +207,7 @@ def __call__(self, name: str) -> bool: return name in self.own_mark_names -def select_by_mark(session: Session, dag: DiGraph[str, DagNode]) -> set[str] | None: +def select_by_mark(session: Session, dag: DAG) -> set[str] | None: """Deselect tests by marks.""" matchexpr = session.config["marker_expression"] if not matchexpr: @@ -237,9 +236,7 @@ def _deselect_others_with_mark( task.markers.append(mark) -def select_tasks_by_marks_and_expressions( - session: Session, dag: DiGraph[str, DagNode] -) -> None: +def select_tasks_by_marks_and_expressions(session: Session, dag: DAG) -> None: """Modify the tasks which are executed with expressions and markers.""" remaining = select_by_keyword(session, dag) if remaining is not None: diff --git a/src/_pytask/session.py b/src/_pytask/session.py index 8f1575ca..5c9bbf8b 100644 --- a/src/_pytask/session.py +++ b/src/_pytask/session.py @@ -9,8 +9,7 @@ from pluggy import HookRelay -from _pytask.dag_graph import DagNode -from _pytask.dag_graph import DiGraph +from _pytask.dag_graph import DAG from _pytask.outcomes import ExitCode if TYPE_CHECKING: @@ -52,7 +51,7 @@ class Session: config: dict[str, Any] = field(default_factory=dict) collection_reports: list[CollectionReport] = field(default_factory=list) - dag: DiGraph[str, DagNode] = field(default_factory=DiGraph) + dag: DAG = field(default_factory=DAG) hook: HookRelay = field(default_factory=HookRelay) tasks: list[PTask] = field(default_factory=list) dag_report: DagReport | None = None diff --git a/src/_pytask/shared.py b/src/_pytask/shared.py index 0a3aefee..1b3c84e3 100644 --- a/src/_pytask/shared.py +++ b/src/_pytask/shared.py @@ -23,8 +23,7 @@ if TYPE_CHECKING: from enum import Enum - from _pytask.dag_graph import DagNode - from _pytask.dag_graph import DiGraph + from _pytask.dag_graph import DAG __all__ = [ @@ -80,7 +79,7 @@ def parse_paths(x: Path | list[Path]) -> list[Path]: def reduce_names_of_multiple_nodes( - names: Iterable[str], dag: DiGraph[str, DagNode], paths: Sequence[Path] + names: Iterable[str], dag: DAG, paths: Sequence[Path] ) -> list[str]: """Reduce the names of multiple nodes in the DAG.""" short_names = [] diff --git a/tests/test_dag_utils.py b/tests/test_dag_utils.py index ae37d9ee..75a4bcba 100644 --- a/tests/test_dag_utils.py +++ b/tests/test_dag_utils.py @@ -6,8 +6,8 @@ import pytest -from _pytask.dag_graph import DagNode -from _pytask.dag_graph import DiGraph +from _pytask.dag_graph import DAG +from _pytask.dag_graph import DAGNode from _pytask.dag_utils import TopologicalSorter from _pytask.dag_utils import _extract_priorities_from_tasks from _pytask.dag_utils import descending_tasks @@ -21,12 +21,12 @@ @pytest.fixture def dag(): """Create a dag with five nodes in a line.""" - dag = DiGraph[str, DagNode]() + dag = DAG() for i in range(4): task = Task(base_name=str(i), path=Path(), function=noop) next_task = Task(base_name=str(i + 1), path=Path(), function=noop) - dag.add_node(task.signature, DagNode.from_task(task)) - dag.add_node(next_task.signature, DagNode.from_task(next_task)) + dag.add_node(task.signature, DAGNode.from_task(task)) + dag.add_node(next_task.signature, DAGNode.from_task(next_task)) dag.add_edge(task.signature, next_task.signature) return dag @@ -185,7 +185,7 @@ def test_instantiate_sorter_from_other_sorter(dag): assert scheduler._nodes_done == {name_to_sig[name] for name in (".::0", ".::1")} task = Task(base_name="5", path=Path(), function=noop) - dag.add_node(task.signature, DagNode.from_task(task)) + dag.add_node(task.signature, DAGNode.from_task(task)) dag.add_edge(name_to_sig[".::4"], task.signature) new_scheduler = TopologicalSorter.from_dag_and_sorter(dag, scheduler) From fb103b67043acf46098fab2b71715d3be81bbfc3 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Fri, 27 Mar 2026 23:09:56 +0100 Subject: [PATCH 08/13] Store DAG payloads directly --- src/_pytask/dag.py | 17 +++++----- src/_pytask/dag_graph.py | 58 +++++------------------------------ src/_pytask/dag_utils.py | 9 +++--- src/_pytask/database_utils.py | 2 +- src/_pytask/execute.py | 26 +++++++++++++--- src/_pytask/lockfile.py | 4 +-- src/_pytask/persist.py | 2 +- src/_pytask/profile.py | 2 +- src/_pytask/shared.py | 2 +- src/_pytask/skipping.py | 12 +++++--- tests/test_dag_utils.py | 33 +++++++------------- 11 files changed, 67 insertions(+), 100 deletions(-) diff --git a/src/_pytask/dag.py b/src/_pytask/dag.py index eaf05af5..a9a3f6cf 100644 --- a/src/_pytask/dag.py +++ b/src/_pytask/dag.py @@ -17,7 +17,6 @@ from _pytask.console import format_task_name from _pytask.console import render_to_string from _pytask.dag_graph import DAG -from _pytask.dag_graph import DAGNode from _pytask.dag_graph import NoCycleError from _pytask.dag_graph import find_cycle from _pytask.exceptions import ResolvingDependenciesError @@ -67,7 +66,7 @@ def _create_dag_from_tasks(tasks: list[PTask]) -> DAG: """Create the DAG from tasks, dependencies and products.""" def _add_node_data(dag: DAG, node: PNode | PProvisionalNode) -> None: - dag.add_node(node.signature, DAGNode.from_node(node)) + dag.add_node(node.signature, node) if isinstance(node, PythonNode) and isinstance(node.value, PythonNode): _add_node_data(dag, node.value) @@ -88,7 +87,7 @@ def _add_product(dag: DAG, task: PTask, node: PNode | PProvisionalNode) -> None: dag = DAG() for task in tasks: - dag.add_node(task.signature, DAGNode.from_task(task)) + dag.add_node(task.signature, task) tree_map(lambda x: _add_node_data(dag, x), task.depends_on) tree_map(lambda x: _add_node_data(dag, x), task.produces) @@ -147,7 +146,7 @@ def _format_cycles(dag: DAG, cycles: list[tuple[str, str]]) -> str: lines: list[str] = [] for x in chain: - node = dag.nodes[x].value + node = dag.nodes[x] if isinstance(node, PTask): short_name = format_task_name(node, editor_url_scheme="no_link").plain elif isinstance(node, (PNode, PProvisionalNode)): @@ -173,7 +172,7 @@ def _check_if_tasks_have_the_same_products(dag: DAG, paths: list[Path]) -> None: nodes_created_by_multiple_tasks = [] for node in dag.nodes: - if dag.nodes[node].node is not None: + if isinstance(dag.nodes[node], (PNode, PProvisionalNode)): parents = list(dag.predecessors(node)) if len(parents) > 1: nodes_created_by_multiple_tasks.append(node) @@ -181,9 +180,11 @@ def _check_if_tasks_have_the_same_products(dag: DAG, paths: list[Path]) -> None: if nodes_created_by_multiple_tasks: dictionary = {} for node in nodes_created_by_multiple_tasks: - short_node_name = format_node_name( - dag.nodes[node].node_or_raise(), paths - ).plain + payload = dag.nodes[node] + if not isinstance(payload, (PNode, PProvisionalNode)): + msg = f"Expected product node for signature {node!r}." + raise TypeError(msg) + short_node_name = format_node_name(payload, paths).plain short_predecessors = reduce_names_of_multiple_nodes( dag.predecessors(node), dag, paths ) diff --git a/src/_pytask/dag_graph.py b/src/_pytask/dag_graph.py index d06bdbdd..0fd6aa3d 100644 --- a/src/_pytask/dag_graph.py +++ b/src/_pytask/dag_graph.py @@ -10,6 +10,9 @@ from typing import cast from _pytask.compat import import_optional_dependency +from _pytask.node_protocols import PNode +from _pytask.node_protocols import PProvisionalNode +from _pytask.node_protocols import PTask if TYPE_CHECKING: from collections.abc import Callable @@ -17,72 +20,27 @@ from collections.abc import Iterator from collections.abc import Mapping - from _pytask.node_protocols import PNode - from _pytask.node_protocols import PProvisionalNode - from _pytask.node_protocols import PTask + +DAGEntry = PTask | PNode | PProvisionalNode class NoCycleError(Exception): """Raised when no cycle is found in a graph.""" -@dataclass(slots=True) -class DAGNode: - """Payload stored for nodes in pytask's internal DAG.""" - - task: PTask | None = None - node: PNode | PProvisionalNode | None = None - - def __post_init__(self) -> None: - if (self.task is None) == (self.node is None): - msg = "A DAG node must store exactly one of 'task' or 'node'." - raise ValueError(msg) - - @classmethod - def from_task(cls, task: PTask) -> DAGNode: - """Create a DAG node from a task.""" - return cls(task=task) - - @classmethod - def from_node(cls, node: PNode | PProvisionalNode) -> DAGNode: - """Create a DAG node from a dependency or product node.""" - return cls(node=node) - - @property - def value(self) -> PTask | PNode | PProvisionalNode: - """Return the wrapped task or node.""" - if self.task is not None: - return self.task - return cast("PNode | PProvisionalNode", self.node) - - def task_or_raise(self) -> PTask: - """Return the wrapped task.""" - if self.task is None: - msg = "Expected DAG payload to contain a task." - raise TypeError(msg) - return self.task - - def node_or_raise(self) -> PNode | PProvisionalNode: - """Return the wrapped dependency or product node.""" - if self.node is None: - msg = "Expected DAG payload to contain a node." - raise TypeError(msg) - return self.node - - @dataclass class DAG: """A minimal directed graph tailored to pytask's needs.""" - _node_data: dict[str, DAGNode] = field(default_factory=dict) + _node_data: dict[str, DAGEntry] = field(default_factory=dict) _successors: dict[str, set[str]] = field(default_factory=dict) _predecessors: dict[str, set[str]] = field(default_factory=dict) @property - def nodes(self) -> dict[str, DAGNode]: + def nodes(self) -> dict[str, DAGEntry]: return self._node_data - def add_node(self, node_name: str, data: DAGNode) -> None: + def add_node(self, node_name: str, data: DAGEntry) -> None: if node_name not in self._node_data: self._successors[node_name] = set() self._predecessors[node_name] = set() diff --git a/src/_pytask/dag_utils.py b/src/_pytask/dag_utils.py index 240a43bd..6fda8469 100644 --- a/src/_pytask/dag_utils.py +++ b/src/_pytask/dag_utils.py @@ -13,18 +13,17 @@ from _pytask.dag_graph import descendants from _pytask.dag_graph import find_cycle from _pytask.mark_utils import has_mark +from _pytask.node_protocols import PTask if TYPE_CHECKING: from collections.abc import Generator from collections.abc import Iterable - from _pytask.node_protocols import PTask - def descending_tasks(task_name: str, dag: DAG) -> Generator[str, None, None]: """Yield only descending tasks.""" for descendant in descendants(dag, task_name): - if dag.nodes[descendant].task is not None: + if isinstance(dag.nodes[descendant], PTask): yield descendant @@ -37,7 +36,7 @@ def task_and_descending_tasks(task_name: str, dag: DAG) -> Generator[str, None, def preceding_tasks(task_name: str, dag: DAG) -> Generator[str, None, None]: """Yield only preceding tasks.""" for ancestor in ancestors(dag, task_name): - if dag.nodes[ancestor].task is not None: + if isinstance(dag.nodes[ancestor], PTask): yield ancestor @@ -86,7 +85,7 @@ def from_dag(cls, dag: DAG) -> TopologicalSorter: """Instantiate from a DAG.""" cls.check_dag(dag) - tasks = [node.task for node in dag.nodes.values() if node.task is not None] + tasks = [node for node in dag.nodes.values() if isinstance(node, PTask)] priorities = _extract_priorities_from_tasks(tasks) task_signatures = {task.signature for task in tasks} diff --git a/src/_pytask/database_utils.py b/src/_pytask/database_utils.py index 41a93e5c..ab5bd71c 100644 --- a/src/_pytask/database_utils.py +++ b/src/_pytask/database_utils.py @@ -116,7 +116,7 @@ def update_states_in_database(session: Session, task_signature: str) -> None: if _ENGINE is None: return for name in node_and_neighbors(session.dag, task_signature): - node = session.dag.nodes[name].value + node = session.dag.nodes[name] if isinstance(node, PProvisionalNode): continue hash_ = node.state() diff --git a/src/_pytask/execute.py b/src/_pytask/execute.py index 0a8b7552..16d2efad 100644 --- a/src/_pytask/execute.py +++ b/src/_pytask/execute.py @@ -89,7 +89,10 @@ def pytask_execute_build(session: Session) -> bool | None: if isinstance(session.scheduler, TopologicalSorter): while session.scheduler.is_active(): task_name = session.scheduler.get_ready()[0] - task = session.dag.nodes[task_name].task_or_raise() + task = session.dag.nodes[task_name] + if not isinstance(task, PTask): + msg = f"Expected task node for signature {task_name!r}." + raise TypeError(msg) report = session.hook.pytask_execute_task_protocol( session=session, task=task ) @@ -172,7 +175,7 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C if not needs_to_be_executed: predecessors = set(dag.predecessors(task.signature)) | {task.signature} for node_signature in node_and_neighbors(dag, task.signature): - node = dag.nodes[node_signature].value + node = dag.nodes[node_signature] # Skip provisional nodes that are products since they do not have a state. if node_signature not in predecessors and isinstance( @@ -253,7 +256,10 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C # Create directory for product if it does not exist. Maybe this should be a `setup` # method for the node classes. for product in dag.successors(task.signature): - node = dag.nodes[product].node_or_raise() + node = dag.nodes[product] + if not isinstance(node, (PNode, PProvisionalNode)): + msg = f"Expected product node for signature {product!r}." + raise TypeError(msg) if isinstance(node, PPathNode): node.path.parent.mkdir(parents=True, exist_ok=True) if isinstance(node, DirectoryNode) and node.root_dir: @@ -347,7 +353,12 @@ def pytask_execute_task_process_report( report.outcome = TaskOutcome.WOULD_BE_EXECUTED for descending_task_name in descending_tasks(task.signature, session.dag): - descending_task = session.dag.nodes[descending_task_name].task_or_raise() + descending_task = session.dag.nodes[descending_task_name] + if not isinstance(descending_task, PTask): + msg = ( + f"Expected descending task for signature {descending_task_name!r}." + ) + raise TypeError(msg) descending_task.markers.append( Mark( "would_be_executed", @@ -360,7 +371,12 @@ def pytask_execute_task_process_report( ) else: for descending_task_name in descending_tasks(task.signature, session.dag): - descending_task = session.dag.nodes[descending_task_name].task_or_raise() + descending_task = session.dag.nodes[descending_task_name] + if not isinstance(descending_task, PTask): + msg = ( + f"Expected descending task for signature {descending_task_name!r}." + ) + raise TypeError(msg) descending_task.markers.append( Mark( "skip_ancestor_failed", diff --git a/src/_pytask/lockfile.py b/src/_pytask/lockfile.py index 0c6bb73c..65d83484 100644 --- a/src/_pytask/lockfile.py +++ b/src/_pytask/lockfile.py @@ -225,7 +225,7 @@ def _build_task_entry(session: Session, task: PTask, root: Path) -> _TaskEntry | depends_on: dict[str, str] = {} for node_signature in predecessors: - node = dag.nodes[node_signature].value + node = dag.nodes[node_signature] if not isinstance(node, (PNode, PTask)): continue state = node.state() @@ -240,7 +240,7 @@ def _build_task_entry(session: Session, task: PTask, root: Path) -> _TaskEntry | produces: dict[str, str] = {} for node_signature in successors: - node = dag.nodes[node_signature].value + node = dag.nodes[node_signature] if not isinstance(node, (PNode, PTask)): continue state = node.state() diff --git a/src/_pytask/persist.py b/src/_pytask/persist.py index 6b72ffff..0ab1ee56 100644 --- a/src/_pytask/persist.py +++ b/src/_pytask/persist.py @@ -51,7 +51,7 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: if has_mark(task, "persist"): stateful_nodes: list[tuple[PTask | PNode, str | None]] = [] for name in node_and_neighbors(session.dag, task.signature): - node = session.dag.nodes[name].value + node = session.dag.nodes[name] if isinstance(node, PProvisionalNode): continue stateful_nodes.append((node, node.state())) diff --git a/src/_pytask/profile.py b/src/_pytask/profile.py index 12eba84e..8e0cb97b 100644 --- a/src/_pytask/profile.py +++ b/src/_pytask/profile.py @@ -210,7 +210,7 @@ def pytask_profile_add_info_on_task( if successors: sum_bytes = 0 for successor in successors: - node = session.dag.nodes[successor].node_or_raise() + node = session.dag.nodes[successor] if isinstance(node, PPathNode): with suppress(FileNotFoundError): sum_bytes += node.path.stat().st_size diff --git a/src/_pytask/shared.py b/src/_pytask/shared.py index 1b3c84e3..3c8eedd7 100644 --- a/src/_pytask/shared.py +++ b/src/_pytask/shared.py @@ -84,7 +84,7 @@ def reduce_names_of_multiple_nodes( """Reduce the names of multiple nodes in the DAG.""" short_names = [] for name in names: - node = dag.nodes[name].value + node = dag.nodes[name] if isinstance(node, PTask): short_name = format_task_name(node, editor_url_scheme="no_link").plain diff --git a/src/_pytask/skipping.py b/src/_pytask/skipping.py index 71c7fcf0..ac9946d8 100644 --- a/src/_pytask/skipping.py +++ b/src/_pytask/skipping.py @@ -9,6 +9,7 @@ from _pytask.mark import Mark from _pytask.mark_utils import get_marks from _pytask.mark_utils import has_mark +from _pytask.node_protocols import PTask from _pytask.outcomes import Skipped from _pytask.outcomes import SkippedAncestorFailed from _pytask.outcomes import SkippedUnchanged @@ -17,7 +18,6 @@ from _pytask.provisional_utils import collect_provisional_products if TYPE_CHECKING: - from _pytask.node_protocols import PTask from _pytask.reports import ExecutionReport from _pytask.session import Session @@ -97,9 +97,13 @@ def pytask_execute_task_process_report( report.outcome = TaskOutcome.SKIP for descending_task_name in descending_tasks(task.signature, session.dag): - descending_task = session.dag.nodes[ - descending_task_name - ].task_or_raise() + descending_task = session.dag.nodes[descending_task_name] + if not isinstance(descending_task, PTask): + msg = ( + f"Expected descending task for signature " + f"{descending_task_name!r}." + ) + raise TypeError(msg) descending_task.markers.append( Mark( "skip", diff --git a/tests/test_dag_utils.py b/tests/test_dag_utils.py index 75a4bcba..d6d18f66 100644 --- a/tests/test_dag_utils.py +++ b/tests/test_dag_utils.py @@ -7,7 +7,6 @@ import pytest from _pytask.dag_graph import DAG -from _pytask.dag_graph import DAGNode from _pytask.dag_utils import TopologicalSorter from _pytask.dag_utils import _extract_priorities_from_tasks from _pytask.dag_utils import descending_tasks @@ -25,8 +24,8 @@ def dag(): for i in range(4): task = Task(base_name=str(i), path=Path(), function=noop) next_task = Task(base_name=str(i + 1), path=Path(), function=noop) - dag.add_node(task.signature, DAGNode.from_task(task)) - dag.add_node(next_task.signature, DAGNode.from_task(next_task)) + dag.add_node(task.signature, task) + dag.add_node(next_task.signature, next_task) dag.add_edge(task.signature, next_task.signature) return dag @@ -39,47 +38,37 @@ def test_sort_tasks_topologically(dag): task_name = sorter.get_ready()[0] topo_ordering.append(task_name) sorter.done(task_name) - topo_names = [dag.nodes[sig].task_or_raise().name for sig in topo_ordering] + topo_names = [dag.nodes[sig].name for sig in topo_ordering] assert topo_names == [f".::{i}" for i in range(5)] def test_descending_tasks(dag): for i in range(5): task = next( - dag.nodes[sig].task_or_raise() - for sig in dag.nodes - if dag.nodes[sig].task_or_raise().name == f".::{i}" + dag.nodes[sig] for sig in dag.nodes if dag.nodes[sig].name == f".::{i}" ) descendants = descending_tasks(task.signature, dag) - descendant_names = sorted( - dag.nodes[sig].task_or_raise().name for sig in descendants - ) + descendant_names = sorted(dag.nodes[sig].name for sig in descendants) assert descendant_names == [f".::{i}" for i in range(i + 1, 5)] def test_task_and_descending_tasks(dag): for i in range(5): task = next( - dag.nodes[sig].task_or_raise() - for sig in dag.nodes - if dag.nodes[sig].task_or_raise().name == f".::{i}" + dag.nodes[sig] for sig in dag.nodes if dag.nodes[sig].name == f".::{i}" ) descendants = task_and_descending_tasks(task.signature, dag) - descendant_names = sorted( - dag.nodes[sig].task_or_raise().name for sig in descendants - ) + descendant_names = sorted(dag.nodes[sig].name for sig in descendants) assert descendant_names == [f".::{i}" for i in range(i, 5)] def test_node_and_neighbors(dag): for i in range(1, 4): task = next( - dag.nodes[sig].task_or_raise() - for sig in dag.nodes - if dag.nodes[sig].task_or_raise().name == f".::{i}" + dag.nodes[sig] for sig in dag.nodes if dag.nodes[sig].name == f".::{i}" ) nodes = node_and_neighbors(dag, task.signature) - node_names = [dag.nodes[sig].task_or_raise().name for sig in nodes] + node_names = [dag.nodes[sig].name for sig in nodes] assert node_names == [f".::{j}" for j in range(i - 1, i + 2)] @@ -176,7 +165,7 @@ def test_ask_for_invalid_number_of_ready_tasks(dag): def test_instantiate_sorter_from_other_sorter(dag): - name_to_sig = {dag.nodes[sig].task_or_raise().name: sig for sig in dag.nodes} + name_to_sig = {dag.nodes[sig].name: sig for sig in dag.nodes} scheduler = TopologicalSorter.from_dag(dag) for _ in range(2): @@ -185,7 +174,7 @@ def test_instantiate_sorter_from_other_sorter(dag): assert scheduler._nodes_done == {name_to_sig[name] for name in (".::0", ".::1")} task = Task(base_name="5", path=Path(), function=noop) - dag.add_node(task.signature, DAGNode.from_task(task)) + dag.add_node(task.signature, task) dag.add_edge(name_to_sig[".::4"], task.signature) new_scheduler = TopologicalSorter.from_dag_and_sorter(dag, scheduler) From a8ff3f180b579dbe70bf43af9610091ff7687692 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Fri, 27 Mar 2026 23:19:15 +0100 Subject: [PATCH 09/13] Move DAG traversal onto graph --- src/_pytask/dag_graph.py | 57 +++++++++++++++------------------------- src/_pytask/dag_utils.py | 12 +++------ tests/test_dag_utils.py | 12 --------- 3 files changed, 24 insertions(+), 57 deletions(-) diff --git a/src/_pytask/dag_graph.py b/src/_pytask/dag_graph.py index 0fd6aa3d..0e3df198 100644 --- a/src/_pytask/dag_graph.py +++ b/src/_pytask/dag_graph.py @@ -75,17 +75,13 @@ def remove_nodes_from(self, nodes: Iterable[str]) -> None: del self._successors[node] del self._predecessors[node] - def is_directed(self) -> bool: - return True + def descendants(self, node: str) -> set[str]: + """Return all descendants of a node.""" + return self._traverse(node, self.successors) - def reverse(self) -> DAG: - graph = DAG() - for node, data in self._node_data.items(): - graph.add_node(node, data) - for source, successors in self._successors.items(): - for target in successors: - graph.add_edge(target, source) - return graph + def ancestors(self, node: str) -> set[str]: + """Return all ancestors of a node.""" + return self._traverse(node, self.predecessors) def relabel_nodes(self, mapping: Mapping[str, str]) -> DAG: graph = DAG() @@ -113,33 +109,22 @@ def to_networkx(self) -> Any: graph.add_edge(source, target) return graph + def _traverse( + self, + node: str, + adjacency: Callable[[str], Iterable[str]], + ) -> set[str]: + visited: set[str] = set() + stack = list(adjacency(node)) + + while stack: + current = stack.pop() + if current in visited: + continue + visited.add(current) + stack.extend(adjacency(current)) -def descendants(dag: DAG, node: str) -> set[str]: - """Return all descendants of a node.""" - return _traverse(dag, node, dag.successors) - - -def ancestors(dag: DAG, node: str) -> set[str]: - """Return all ancestors of a node.""" - return _traverse(dag, node, dag.predecessors) - - -def _traverse( - _dag: DAG, - node: str, - adjacency: Callable[[str], Iterable[str]], -) -> set[str]: - visited: set[str] = set() - stack = list(adjacency(node)) - - while stack: - current = stack.pop() - if current in visited: - continue - visited.add(current) - stack.extend(adjacency(current)) - - return visited + return visited def find_cycle( diff --git a/src/_pytask/dag_utils.py b/src/_pytask/dag_utils.py index 6fda8469..aea91dbf 100644 --- a/src/_pytask/dag_utils.py +++ b/src/_pytask/dag_utils.py @@ -9,8 +9,6 @@ from _pytask.dag_graph import DAG from _pytask.dag_graph import NoCycleError -from _pytask.dag_graph import ancestors -from _pytask.dag_graph import descendants from _pytask.dag_graph import find_cycle from _pytask.mark_utils import has_mark from _pytask.node_protocols import PTask @@ -22,7 +20,7 @@ def descending_tasks(task_name: str, dag: DAG) -> Generator[str, None, None]: """Yield only descending tasks.""" - for descendant in descendants(dag, task_name): + for descendant in dag.descendants(task_name): if isinstance(dag.nodes[descendant], PTask): yield descendant @@ -35,7 +33,7 @@ def task_and_descending_tasks(task_name: str, dag: DAG) -> Generator[str, None, def preceding_tasks(task_name: str, dag: DAG) -> Generator[str, None, None]: """Yield only preceding tasks.""" - for ancestor in ancestors(dag, task_name): + for ancestor in dag.ancestors(task_name): if isinstance(dag.nodes[ancestor], PTask): yield ancestor @@ -96,7 +94,7 @@ def from_dag(cls, dag: DAG) -> TopologicalSorter: # The scheduler graph uses edges from predecessor -> successor so that # zero in-degree means "ready to run". This is the same orientation the # previous networkx-based implementation reached after calling reverse(). - for ancestor_ in ancestors(dag, signature) & task_signatures: + for ancestor_ in dag.ancestors(signature) & task_signatures: task_dag.add_edge(ancestor_, signature) return cls(dag=task_dag, priorities=priorities) @@ -113,10 +111,6 @@ def from_dag_and_sorter( @staticmethod def check_dag(dag: DAG) -> None: - if not dag.is_directed(): - msg = "Only directed graphs have a topological order." - raise ValueError(msg) - try: find_cycle(dag) except NoCycleError: diff --git a/tests/test_dag_utils.py b/tests/test_dag_utils.py index d6d18f66..32effef0 100644 --- a/tests/test_dag_utils.py +++ b/tests/test_dag_utils.py @@ -1,7 +1,6 @@ from __future__ import annotations from contextlib import ExitStack as does_not_raise # noqa: N813 -from dataclasses import dataclass from pathlib import Path import pytest @@ -138,17 +137,6 @@ def test_extract_priorities_from_tasks(tasks, expectation, expected): assert result == expected -@dataclass -class _UndirectedGraphStub: - def is_directed(self): - return False - - -def test_raise_error_for_undirected_graphs(): - with pytest.raises(ValueError, match="Only directed graphs have a"): - TopologicalSorter.from_dag(_UndirectedGraphStub()) # type: ignore[arg-type] - - def test_raise_error_for_cycle_in_graph(dag): dag.add_edge( "115f685b0af2aef0c7317a0b48562f34cfb7a622549562bd3d34d4d948b4fdab", From fcae596cd8dc1823fde465c8e307ff99ddd7d337 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Fri, 27 Mar 2026 23:32:40 +0100 Subject: [PATCH 10/13] fix: limit diff --- src/_pytask/execute.py | 10 ++-------- src/_pytask/persist.py | 6 ------ tests/test_persist.py | 2 +- 3 files changed, 3 insertions(+), 15 deletions(-) diff --git a/src/_pytask/execute.py b/src/_pytask/execute.py index 16d2efad..7f9c43d3 100644 --- a/src/_pytask/execute.py +++ b/src/_pytask/execute.py @@ -205,10 +205,7 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C # Check if node changed and collect detailed info if in explain mode if session.config["explain"]: has_changed, reason, details = get_node_change_info( - session=session, - task=task, - node=node, - state=node_state, + session=session, task=task, node=node, state=node_state ) if has_changed: needs_to_be_executed = True @@ -235,10 +232,7 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C ) else: has_changed = has_node_changed( - session=session, - task=task, - node=node, - state=node_state, + session=session, task=task, node=node, state=node_state ) if has_changed: needs_to_be_executed = True diff --git a/src/_pytask/persist.py b/src/_pytask/persist.py index 0ab1ee56..8f1c6b71 100644 --- a/src/_pytask/persist.py +++ b/src/_pytask/persist.py @@ -6,7 +6,6 @@ from typing import Any from _pytask.dag_utils import node_and_neighbors -from _pytask.database_utils import update_states_in_database as _db_update_states from _pytask.mark_utils import has_mark from _pytask.node_protocols import PProvisionalNode from _pytask.outcomes import Persisted @@ -23,11 +22,6 @@ from _pytask.session import Session -def update_states_in_database(session: Session, task_signature: str) -> None: - """Compatibility wrapper for older callers/tests.""" - _db_update_states(session, task_signature) - - @hookimpl def pytask_parse_config(config: dict[str, Any]) -> None: """Add the marker to the configuration.""" diff --git a/tests/test_persist.py b/tests/test_persist.py index a1ebbf52..622d4e48 100644 --- a/tests/test_persist.py +++ b/tests/test_persist.py @@ -123,7 +123,7 @@ def task_dummy(depends_on=Path("in.txt"), produces=Path("out.txt")): ) def test_pytask_execute_task_process_report(monkeypatch, exc_info, expected): monkeypatch.setattr( - "_pytask.persist.update_states_in_database", + "_pytask.persist.update_states", lambda *x: None, # noqa: ARG005 ) From ca37ead9123cbdc44240b6efc0c5b89482a021bc Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Fri, 27 Mar 2026 23:42:07 +0100 Subject: [PATCH 11/13] Tighten DAG state update invariants --- src/_pytask/database_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/_pytask/database_utils.py b/src/_pytask/database_utils.py index ab5bd71c..09dcff2b 100644 --- a/src/_pytask/database_utils.py +++ b/src/_pytask/database_utils.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import TYPE_CHECKING from typing import Literal +from typing import cast from sqlalchemy import create_engine from sqlalchemy import inspect @@ -118,7 +119,12 @@ def update_states_in_database(session: Session, task_signature: str) -> None: for name in node_and_neighbors(session.dag, task_signature): node = session.dag.nodes[name] if isinstance(node, PProvisionalNode): - continue + msg = ( + f"Task {task_signature!r} still references provisional node " + f"{node.name!r} when updating database states." + ) + raise TypeError(msg) + node = cast("PTask | PNode", node) hash_ = node.state() if hash_ is None: continue From e6301844adc364143d1ae2b8d7b4a8aa1fd3b5c5 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Fri, 27 Mar 2026 23:44:48 +0100 Subject: [PATCH 12/13] Remove redundant database state cast --- src/_pytask/database_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/_pytask/database_utils.py b/src/_pytask/database_utils.py index 09dcff2b..ece04652 100644 --- a/src/_pytask/database_utils.py +++ b/src/_pytask/database_utils.py @@ -5,7 +5,6 @@ from pathlib import Path from typing import TYPE_CHECKING from typing import Literal -from typing import cast from sqlalchemy import create_engine from sqlalchemy import inspect @@ -124,7 +123,6 @@ def update_states_in_database(session: Session, task_signature: str) -> None: f"{node.name!r} when updating database states." ) raise TypeError(msg) - node = cast("PTask | PNode", node) hash_ = node.state() if hash_ is None: continue From ec1f8e4cdbef3167786da83293bf702ccf643c9f Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Fri, 27 Mar 2026 23:48:54 +0100 Subject: [PATCH 13/13] fix: cleanup --- src/_pytask/state.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/src/_pytask/state.py b/src/_pytask/state.py index 5b05c173..7968cf1f 100644 --- a/src/_pytask/state.py +++ b/src/_pytask/state.py @@ -4,9 +4,9 @@ from typing import TYPE_CHECKING -from _pytask.database_utils import get_node_change_info as _db_get_node_change_info -from _pytask.database_utils import has_node_changed as _db_has_node_changed -from _pytask.database_utils import update_states_in_database as _db_update_states +from _pytask.database_utils import get_node_change_info as db_get_node_change_info +from _pytask.database_utils import has_node_changed as db_has_node_changed +from _pytask.database_utils import update_states_in_database from _pytask.lockfile import LockfileState from _pytask.lockfile import build_portable_node_id from _pytask.lockfile import build_portable_task_id @@ -17,14 +17,10 @@ from _pytask.session import Session -def _get_lockfile_state(session: Session) -> LockfileState | None: - return session.config.get("lockfile_state") - - def has_node_changed( session: Session, task: PTask, node: PTask | PNode, state: str | None ) -> bool: - lockfile_state = _get_lockfile_state(session) + lockfile_state: LockfileState | None = session.config.get("lockfile_state") if lockfile_state and lockfile_state.use_lockfile_for_skip: if state is None: return True @@ -45,15 +41,15 @@ def has_node_changed( if stored_state is None: return True return state != stored_state - return _db_has_node_changed(task=task, node=node, state=state) + return db_has_node_changed(task=task, node=node, state=state) def get_node_change_info( session: Session, task: PTask, node: PTask | PNode, state: str | None ) -> tuple[bool, str, dict[str, str]]: - lockfile_state = _get_lockfile_state(session) + lockfile_state: LockfileState | None = session.config.get("lockfile_state") if not (lockfile_state and lockfile_state.use_lockfile_for_skip): - return _db_get_node_change_info(task=task, node=node, state=state) + return db_get_node_change_info(task=task, node=node, state=state) details: dict[str, str] = {} if state is None: @@ -88,7 +84,7 @@ def get_node_change_info( def update_states(session: Session, task: PTask) -> None: if session.dag is None: return - lockfile_state = _get_lockfile_state(session) + lockfile_state: LockfileState | None = session.config.get("lockfile_state") if lockfile_state is not None: lockfile_state.update_task(session, task) - _db_update_states(session, task.signature) + update_states_in_database(session, task.signature)