From fff7ca5962c88fc4a24d5c640cb1699142fba0e1 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Thu, 26 Mar 2026 10:43:27 +0100 Subject: [PATCH 01/10] Refactor DAG internals and lazy-load networkx --- docs/source/tutorials/visualizing_the_dag.md | 7 +- pyproject.toml | 6 +- src/_pytask/dag.py | 30 +-- src/_pytask/dag_command.py | 26 ++- src/_pytask/dag_graph.py | 223 +++++++++++++++++++ src/_pytask/dag_utils.py | 41 ++-- src/_pytask/execute.py | 23 +- src/_pytask/mark/__init__.py | 9 +- src/_pytask/session.py | 4 +- src/_pytask/shared.py | 4 +- tests/test_dag_command.py | 72 +++++- tests/test_dag_utils.py | 4 +- uv.lock | 25 ++- 13 files changed, 406 insertions(+), 68 deletions(-) create mode 100644 src/_pytask/dag_graph.py diff --git a/docs/source/tutorials/visualizing_the_dag.md b/docs/source/tutorials/visualizing_the_dag.md index f3554f3f..297db87e 100644 --- a/docs/source/tutorials/visualizing_the_dag.md +++ b/docs/source/tutorials/visualizing_the_dag.md @@ -1,11 +1,12 @@ # Visualizing the DAG To visualize the [DAG](../glossary.md#dag) of the project, first, install -[pygraphviz](https://github.com/pygraphviz/pygraphviz) and -[graphviz](https://graphviz.org/). For example, you can both install with pixi +[networkx](https://networkx.org/), +[pygraphviz](https://github.com/pygraphviz/pygraphviz), and +[graphviz](https://graphviz.org/). For example, you can install them with pixi ```console -$ pixi add pygraphviz graphviz +$ pixi add networkx pygraphviz graphviz ``` After that, pytask offers two interfaces to visualize your project's `DAG`. diff --git a/pyproject.toml b/pyproject.toml index 88c9f812..4906bc23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,6 @@ dependencies = [ "click>=8.1.8,!=8.2.0", "click-default-group>=1.2.4", "msgspec>=0.18.6", - "networkx>=2.4.0", "optree>=0.9.0", "packaging>=23.0.0", "pluggy>=1.3.0", @@ -36,6 +35,9 @@ dependencies = [ "universal-pathlib>=0.2.2", ] +[project.optional-dependencies] +dag = ["networkx>=2.4.0"] + [project.readme] file = "README.md" content-type = "text/markdown" @@ -54,6 +56,7 @@ docs = [ "ipywidgets>=8.1.6", "matplotlib>=3.5.0", "mkdocstrings[python]>=0.30.0", + "networkx>=2.4.0", "zensical>=0.0.23", ] docs-live = ["sphinx-autobuild>=2024.10.3"] @@ -71,6 +74,7 @@ test = [ "syrupy>=4.5.0", "aiohttp>=3.11.0", # For HTTPPath tests. "coiled>=1.42.0; python_version < '3.14'", + "networkx>=2.4.0", "pygraphviz>=1.12;platform_system=='Linux'", ] typing = ["ty>=0.0.8"] diff --git a/src/_pytask/dag.py b/src/_pytask/dag.py index 538462d8..7d6f19b9 100644 --- a/src/_pytask/dag.py +++ b/src/_pytask/dag.py @@ -6,7 +6,6 @@ import sys from typing import TYPE_CHECKING -import networkx as nx from rich.text import Text from rich.tree import Tree @@ -17,6 +16,9 @@ from _pytask.console import format_node_name from _pytask.console import format_task_name from _pytask.console import render_to_string +from _pytask.dag_graph import DiGraph +from _pytask.dag_graph import NoCycleError +from _pytask.dag_graph import find_cycle from _pytask.exceptions import ResolvingDependenciesError from _pytask.mark import select_by_after_keyword from _pytask.mark import select_tasks_by_marks_and_expressions @@ -37,7 +39,7 @@ __all__ = ["create_dag", "create_dag_from_session"] -def create_dag(session: Session) -> nx.DiGraph: +def create_dag(session: Session) -> DiGraph: """Create a directed acyclic graph (DAG) for the workflow.""" try: dag = create_dag_from_session(session) @@ -50,7 +52,7 @@ def create_dag(session: Session) -> nx.DiGraph: return dag -def create_dag_from_session(session: Session) -> nx.DiGraph: +def create_dag_from_session(session: Session) -> DiGraph: """Create a DAG from a session.""" dag = _create_dag_from_tasks(tasks=session.tasks) _check_if_dag_has_cycles(dag) @@ -60,11 +62,11 @@ def create_dag_from_session(session: Session) -> nx.DiGraph: return dag -def _create_dag_from_tasks(tasks: list[PTask]) -> nx.DiGraph: +def _create_dag_from_tasks(tasks: list[PTask]) -> DiGraph: """Create the DAG from tasks, dependencies and products.""" def _add_dependency( - dag: nx.DiGraph, task: PTask, node: PNode | PProvisionalNode + dag: DiGraph, task: PTask, node: PNode | PProvisionalNode ) -> None: """Add a dependency to the DAG.""" dag.add_node(node.signature, node=node) @@ -76,14 +78,12 @@ def _add_dependency( if isinstance(node, PythonNode) and isinstance(node.value, PythonNode): dag.add_edge(node.value.signature, node.signature) - def _add_product( - dag: nx.DiGraph, task: PTask, node: PNode | PProvisionalNode - ) -> None: + def _add_product(dag: DiGraph, task: PTask, node: PNode | PProvisionalNode) -> None: """Add a product to the DAG.""" dag.add_node(node.signature, node=node) dag.add_edge(task.signature, node.signature) - dag = nx.DiGraph() + dag = DiGraph() for task in tasks: dag.add_node(task.signature, task=task) @@ -105,7 +105,7 @@ def _add_product( return dag -def _modify_dag(session: Session, dag: nx.DiGraph) -> nx.DiGraph: +def _modify_dag(session: Session, dag: DiGraph) -> DiGraph: """Create dependencies between tasks when using ``@task(after=...)``.""" temporary_id_to_task = { task.attributes["collection_id"]: task @@ -129,11 +129,11 @@ def _modify_dag(session: Session, dag: nx.DiGraph) -> nx.DiGraph: return dag -def _check_if_dag_has_cycles(dag: nx.DiGraph) -> None: +def _check_if_dag_has_cycles(dag: DiGraph) -> None: """Check if DAG has cycles.""" try: - cycles = nx.algorithms.cycles.find_cycle(dag) - except nx.NetworkXNoCycle: + cycles = find_cycle(dag) + except NoCycleError: pass else: msg = ( @@ -145,7 +145,7 @@ def _check_if_dag_has_cycles(dag: nx.DiGraph) -> None: raise ResolvingDependenciesError(msg) -def _format_cycles(dag: nx.DiGraph, cycles: list[tuple[str, ...]]) -> str: +def _format_cycles(dag: DiGraph, cycles: list[tuple[str, str]]) -> str: """Format cycles as a paths connected by arrows.""" chain = [ x for i, x in enumerate(itertools.chain.from_iterable(cycles)) if i % 2 == 0 @@ -176,7 +176,7 @@ def _format_dictionary_to_tree(dict_: dict[str, list[str]], title: str) -> str: return render_to_string(tree, console=console, strip_styles=True) -def _check_if_tasks_have_the_same_products(dag: nx.DiGraph, paths: list[Path]) -> None: +def _check_if_tasks_have_the_same_products(dag: DiGraph, paths: list[Path]) -> None: nodes_created_by_multiple_tasks = [] for node in dag.nodes: diff --git a/src/_pytask/dag_command.py b/src/_pytask/dag_command.py index f7ed3c6f..e0497536 100644 --- a/src/_pytask/dag_command.py +++ b/src/_pytask/dag_command.py @@ -5,11 +5,11 @@ import enum import sys from pathlib import Path +from typing import TYPE_CHECKING from typing import Any from typing import cast import click -import networkx as nx from rich.text import Text from _pytask.click import ColoredCommand @@ -30,6 +30,11 @@ from _pytask.shared import reduce_names_of_multiple_nodes from _pytask.traceback import Traceback +if TYPE_CHECKING: + import networkx as nx + + from _pytask.dag_graph import DiGraph + class _RankDirection(enum.Enum): TB = "TB" @@ -92,6 +97,7 @@ def dag(**raw_config: Any) -> int: else: try: session.hook.pytask_log_session_header(session=session) + import_optional_dependency("networkx") import_optional_dependency("pygraphviz") check_for_optional_program( session.config["layout"], @@ -100,7 +106,7 @@ def dag(**raw_config: Any) -> int: ) session.hook.pytask_collect(session=session) session.dag = create_dag(session=session) - dag = _refine_dag(session) + dag = _refine_dag(session).to_networkx() _write_graph(dag, session.config["output_path"], session.config["layout"]) except CollectionError: # pragma: no cover @@ -163,6 +169,7 @@ def build_dag(raw_config: dict[str, Any]) -> nx.DiGraph: else: session.hook.pytask_log_session_header(session=session) + import_optional_dependency("networkx") import_optional_dependency("pygraphviz") check_for_optional_program( session.config["layout"], @@ -172,10 +179,10 @@ def build_dag(raw_config: dict[str, Any]) -> nx.DiGraph: session.hook.pytask_collect(session=session) session.dag = create_dag(session=session) session.hook.pytask_unconfigure(session=session) - return _refine_dag(session) + return _refine_dag(session).to_networkx() -def _refine_dag(session: Session) -> nx.DiGraph: +def _refine_dag(session: Session) -> DiGraph: """Refine the dag for plotting.""" dag = _shorten_node_labels(session.dag, session.config["paths"]) dag = _clean_dag(dag) @@ -185,31 +192,32 @@ def _refine_dag(session: Session) -> nx.DiGraph: return dag -def _shorten_node_labels(dag: nx.DiGraph, paths: list[Path]) -> nx.DiGraph: +def _shorten_node_labels(dag: DiGraph, paths: list[Path]) -> DiGraph: """Shorten the node labels in the graph for a better experience.""" node_names = dag.nodes short_names = reduce_names_of_multiple_nodes(node_names, dag, paths) short_names = [i.plain if isinstance(i, Text) else i for i in short_names] old_to_new = dict(zip(node_names, short_names, strict=False)) - return nx.relabel_nodes(dag, old_to_new) + return dag.relabel_nodes(old_to_new) -def _clean_dag(dag: nx.DiGraph) -> nx.DiGraph: +def _clean_dag(dag: DiGraph) -> DiGraph: """Clean the DAG.""" for node in dag.nodes: dag.nodes[node].clear() return dag -def _style_dag(dag: nx.DiGraph) -> nx.DiGraph: +def _style_dag(dag: DiGraph) -> DiGraph: """Style the DAG.""" shapes = {name: "hexagon" if "::task_" in name else "box" for name in dag.nodes} - nx.set_node_attributes(dag, shapes, "shape") + dag.set_node_attributes(shapes, "shape") return dag def _write_graph(dag: nx.DiGraph, path: Path, layout: str) -> None: """Write the graph to disk.""" + nx = cast("Any", import_optional_dependency("networkx")) path.parent.mkdir(exist_ok=True, parents=True) graph = nx.nx_agraph.to_agraph(dag) graph.draw(path, prog=layout) diff --git a/src/_pytask/dag_graph.py b/src/_pytask/dag_graph.py new file mode 100644 index 00000000..8fa3d8c6 --- /dev/null +++ b/src/_pytask/dag_graph.py @@ -0,0 +1,223 @@ +"""Internal DAG implementation used by pytask.""" + +from __future__ import annotations + +import itertools +from typing import TYPE_CHECKING +from typing import Any +from typing import cast + +from _pytask.compat import import_optional_dependency + +if TYPE_CHECKING: + from collections.abc import Iterator + + +class NoCycleError(Exception): + """Raised when no cycle is found in a graph.""" + + +class NodeView: + """A minimal mapping-like view over node attributes.""" + + def __init__(self, node_attributes: dict[str, dict[str, Any]]) -> None: + self._node_attributes = node_attributes + + def __getitem__(self, node: str) -> dict[str, Any]: + return self._node_attributes[node] + + def __iter__(self) -> Iterator[str]: + return iter(self._node_attributes) + + def __len__(self) -> int: + return len(self._node_attributes) + + def __contains__(self, node: object) -> bool: + return node in self._node_attributes + + +class UndirectedGraph: + """A minimal undirected graph used for validation tests.""" + + def __init__( + self, + node_attributes: dict[str, dict[str, Any]], + adjacency: dict[str, dict[str, None]], + graph_attributes: dict[str, Any], + ) -> None: + self._node_attributes = { + node: attributes.copy() for node, attributes in node_attributes.items() + } + self._adjacency = { + node: neighbors.copy() for node, neighbors in adjacency.items() + } + self.graph = graph_attributes.copy() + self.nodes = NodeView(self._node_attributes) + + def is_directed(self) -> bool: + return False + + +class DiGraph: + """A minimal directed graph tailored to pytask's needs.""" + + def __init__(self) -> None: + self._node_attributes: dict[str, dict[str, Any]] = {} + self._successors: dict[str, dict[str, None]] = {} + self._predecessors: dict[str, dict[str, None]] = {} + self.graph: dict[str, Any] = {} + self.nodes = NodeView(self._node_attributes) + + def add_node(self, node_name: str, **attributes: Any) -> None: + if node_name not in self._node_attributes: + self._node_attributes[node_name] = {} + self._successors[node_name] = {} + self._predecessors[node_name] = {} + self._node_attributes[node_name].update(attributes) + + def add_edge(self, source: str, target: str) -> None: + self.add_node(source) + self.add_node(target) + self._successors[source][target] = None + self._predecessors[target][source] = None + + def successors(self, node: str) -> Iterator[str]: + return iter(self._successors[node]) + + def predecessors(self, node: str) -> Iterator[str]: + return iter(self._predecessors[node]) + + def in_degree(self) -> Iterator[tuple[str, int]]: + for node, predecessors_ in self._predecessors.items(): + yield node, len(predecessors_) + + def remove_nodes_from(self, nodes: list[str] | set[str] | tuple[str, ...]) -> None: + for node in nodes: + if node not in self._node_attributes: + continue + for predecessor in tuple(self._predecessors[node]): + self._successors[predecessor].pop(node, None) + for successor in tuple(self._successors[node]): + self._predecessors[successor].pop(node, None) + del self._node_attributes[node] + del self._successors[node] + del self._predecessors[node] + + def is_directed(self) -> bool: + return True + + def reverse(self) -> DiGraph: + graph = DiGraph() + graph.graph = self.graph.copy() + for node, attributes in self._node_attributes.items(): + graph.add_node(node, **attributes.copy()) + for source, successors in self._successors.items(): + for target in successors: + graph.add_edge(target, source) + return graph + + def relabel_nodes(self, mapping: dict[str, str]) -> DiGraph: + graph = DiGraph() + graph.graph = self.graph.copy() + + new_labels = [mapping.get(node, node) for node in self._node_attributes] + if len(new_labels) != len(set(new_labels)): + msg = "Relabeling nodes requires unique target labels." + raise ValueError(msg) + + for node, attributes in self._node_attributes.items(): + graph.add_node(mapping.get(node, node), **attributes.copy()) + for source, successors in self._successors.items(): + new_source = mapping.get(source, source) + for target in successors: + graph.add_edge(new_source, mapping.get(target, target)) + return graph + + def set_node_attributes(self, values: dict[str, Any], name: str) -> None: + for node, value in values.items(): + if node in self._node_attributes: + self._node_attributes[node][name] = value + + def to_undirected(self) -> UndirectedGraph: + adjacency = { + node: { + **self._predecessors[node], + **self._successors[node], + } + for node in self._node_attributes + } + return UndirectedGraph(self._node_attributes, adjacency, self.graph) + + def to_networkx(self) -> Any: + nx = cast("Any", import_optional_dependency("networkx")) + graph = nx.DiGraph() + graph.graph = self.graph.copy() + for node, attributes in self._node_attributes.items(): + graph.add_node(node, **attributes.copy()) + for source, successors in self._successors.items(): + for target in successors: + graph.add_edge(source, target) + return graph + + +def descendants(dag: DiGraph, node: str) -> set[str]: + """Return all descendants of a node.""" + return _traverse(dag, node, dag.successors) + + +def ancestors(dag: DiGraph, node: str) -> set[str]: + """Return all ancestors of a node.""" + return _traverse(dag, node, dag.predecessors) + + +def _traverse( + _dag: DiGraph, + node: str, + adjacency: Any, +) -> set[str]: + visited: set[str] = set() + stack = list(adjacency(node)) + + while stack: + current = stack.pop() + if current in visited: + continue + visited.add(current) + stack.extend(adjacency(current)) + + return visited + + +def find_cycle(dag: DiGraph) -> list[tuple[str, str]]: + """Find one cycle in the graph.""" + visited: set[str] = set() + active: set[str] = set() + path: list[str] = [] + + def _visit(node: str) -> list[tuple[str, str]] | None: + visited.add(node) + active.add(node) + path.append(node) + + for successor in dag.successors(node): + if successor not in visited: + cycle = _visit(successor) + if cycle is not None: + return cycle + elif successor in active: + start = path.index(successor) + cycle_nodes = [*path[start:], successor] + return list(itertools.pairwise(cycle_nodes)) + + active.remove(node) + path.pop() + return None + + for node in dag.nodes: + if node in visited: + continue + cycle = _visit(node) + if cycle is not None: + return cycle + + raise NoCycleError diff --git a/src/_pytask/dag_utils.py b/src/_pytask/dag_utils.py index fb7bbfe9..47b248f7 100644 --- a/src/_pytask/dag_utils.py +++ b/src/_pytask/dag_utils.py @@ -7,8 +7,11 @@ from dataclasses import field from typing import TYPE_CHECKING -import networkx as nx - +from _pytask.dag_graph import DiGraph +from _pytask.dag_graph import NoCycleError +from _pytask.dag_graph import ancestors +from _pytask.dag_graph import descendants +from _pytask.dag_graph import find_cycle from _pytask.mark_utils import has_mark if TYPE_CHECKING: @@ -18,37 +21,37 @@ from _pytask.node_protocols import PTask -def descending_tasks(task_name: str, dag: nx.DiGraph) -> Generator[str, None, None]: +def descending_tasks(task_name: str, dag: DiGraph) -> Generator[str, None, None]: """Yield only descending tasks.""" - for descendant in nx.descendants(dag, task_name): + for descendant in descendants(dag, task_name): if "task" in dag.nodes[descendant]: yield descendant def task_and_descending_tasks( - task_name: str, dag: nx.DiGraph + task_name: str, dag: DiGraph ) -> Generator[str, None, None]: """Yield task and descending tasks.""" yield task_name yield from descending_tasks(task_name, dag) -def preceding_tasks(task_name: str, dag: nx.DiGraph) -> Generator[str, None, None]: +def preceding_tasks(task_name: str, dag: DiGraph) -> Generator[str, None, None]: """Yield only preceding tasks.""" - for ancestor in nx.ancestors(dag, task_name): + for ancestor in ancestors(dag, task_name): if "task" in dag.nodes[ancestor]: yield ancestor def task_and_preceding_tasks( - task_name: str, dag: nx.DiGraph + task_name: str, dag: DiGraph ) -> Generator[str, None, None]: """Yield task and preceding tasks.""" yield task_name yield from preceding_tasks(task_name, dag) -def node_and_neighbors(dag: nx.DiGraph, node: str) -> Iterable[str]: +def node_and_neighbors(dag: DiGraph, node: str) -> Iterable[str]: """Yield node and neighbors which are first degree predecessors and successors. We cannot use ``dag.neighbors`` as it only considers successors as neighbors in a @@ -77,13 +80,13 @@ class TopologicalSorter: """ - dag: nx.DiGraph + dag: DiGraph priorities: dict[str, int] = field(default_factory=dict) _nodes_processing: set[str] = field(default_factory=set) _nodes_done: set[str] = field(default_factory=set) @classmethod - def from_dag(cls, dag: nx.DiGraph) -> TopologicalSorter: + def from_dag(cls, dag: DiGraph) -> TopologicalSorter: """Instantiate from a DAG.""" cls.check_dag(dag) @@ -93,14 +96,18 @@ def from_dag(cls, dag: nx.DiGraph) -> TopologicalSorter: priorities = _extract_priorities_from_tasks(tasks) task_signatures = {task.signature for task in tasks} - task_dict = {s: nx.ancestors(dag, s) & task_signatures for s in task_signatures} - task_dag = nx.DiGraph(task_dict).reverse() + task_dag = DiGraph() + for signature in task_signatures: + task_dag.add_node(signature) + for signature in task_signatures: + for ancestor_ in ancestors(dag, signature) & task_signatures: + task_dag.add_edge(ancestor_, signature) return cls(dag=task_dag, priorities=priorities) @classmethod def from_dag_and_sorter( - cls, dag: nx.DiGraph, sorter: TopologicalSorter + cls, dag: DiGraph, sorter: TopologicalSorter ) -> TopologicalSorter: """Instantiate a sorter from another sorter and a DAG.""" new_sorter = cls.from_dag(dag) @@ -109,14 +116,14 @@ def from_dag_and_sorter( return new_sorter @staticmethod - def check_dag(dag: nx.DiGraph) -> None: + def check_dag(dag: DiGraph) -> None: if not dag.is_directed(): msg = "Only directed graphs have a topological order." raise ValueError(msg) try: - nx.algorithms.cycles.find_cycle(dag) - except nx.NetworkXNoCycle: + find_cycle(dag) + except NoCycleError: pass else: msg = "The DAG contains cycles." diff --git a/src/_pytask/execute.py b/src/_pytask/execute.py index 9b3fd8c3..8dd8a1e5 100644 --- a/src/_pytask/execute.py +++ b/src/_pytask/execute.py @@ -7,6 +7,7 @@ import time from typing import TYPE_CHECKING from typing import Any +from typing import cast from rich.text import Text @@ -172,9 +173,11 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C if not needs_to_be_executed: predecessors = set(dag.predecessors(task.signature)) | {task.signature} for node_signature in node_and_neighbors(dag, task.signature): - node = dag.nodes[node_signature].get("task") or dag.nodes[ - node_signature - ].get("node") + node = cast( + "PTask | PNode | PProvisionalNode", + dag.nodes[node_signature].get("task") + or dag.nodes[node_signature].get("node"), + ) # Skip provisional nodes that are products since they do not have a state. if node_signature not in predecessors and isinstance( @@ -182,7 +185,7 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C ): continue - node_state = node.state() + node_state = cast("Any", node).state() if node_signature in predecessors and not node_state: msg = f"{task.name!r} requires missing node {node.name!r}." @@ -196,7 +199,10 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C # Check if node changed and collect detailed info if in explain mode if session.config["explain"]: has_changed, reason, details = get_node_change_info( - session=session, task=task, node=node, state=node_state + session=session, + task=task, + node=cast("PTask | PNode", node), + state=node_state, ) if has_changed: needs_to_be_executed = True @@ -214,7 +220,7 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C reason_typed: ReasonType = reason # type: ignore[assignment] change_reasons.append( create_change_reason( - node=node, + node=cast("PTask | PNode", node), node_type=node_type, reason=reason_typed, old_hash=details.get("old_hash"), @@ -223,7 +229,10 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C ) else: has_changed = has_node_changed( - session=session, task=task, node=node, state=node_state + session=session, + task=task, + node=cast("PTask | PNode", node), + state=node_state, ) if has_changed: needs_to_be_executed = True diff --git a/src/_pytask/mark/__init__.py b/src/_pytask/mark/__init__.py index 9cee0c2b..2f6f0727 100644 --- a/src/_pytask/mark/__init__.py +++ b/src/_pytask/mark/__init__.py @@ -30,8 +30,7 @@ from collections.abc import Set as AbstractSet from typing import NoReturn - import networkx as nx - + from _pytask.dag_graph import DiGraph from _pytask.node_protocols import PTask @@ -153,7 +152,7 @@ def __call__(self, subname: str) -> bool: return any(subname in name for name in names) -def select_by_keyword(session: Session, dag: nx.DiGraph) -> set[str] | None: +def select_by_keyword(session: Session, dag: DiGraph) -> set[str] | None: """Deselect tests by keywords.""" keywordexpr = session.config["expression"] if not keywordexpr: @@ -208,7 +207,7 @@ def __call__(self, name: str) -> bool: return name in self.own_mark_names -def select_by_mark(session: Session, dag: nx.DiGraph) -> set[str] | None: +def select_by_mark(session: Session, dag: DiGraph) -> set[str] | None: """Deselect tests by marks.""" matchexpr = session.config["marker_expression"] if not matchexpr: @@ -237,7 +236,7 @@ def _deselect_others_with_mark( task.markers.append(mark) -def select_tasks_by_marks_and_expressions(session: Session, dag: nx.DiGraph) -> None: +def select_tasks_by_marks_and_expressions(session: Session, dag: DiGraph) -> None: """Modify the tasks which are executed with expressions and markers.""" remaining = select_by_keyword(session, dag) if remaining is not None: diff --git a/src/_pytask/session.py b/src/_pytask/session.py index 79f7f06c..09008742 100644 --- a/src/_pytask/session.py +++ b/src/_pytask/session.py @@ -7,9 +7,9 @@ from typing import TYPE_CHECKING from typing import Any -import networkx as nx from pluggy import HookRelay +from _pytask.dag_graph import DiGraph from _pytask.outcomes import ExitCode if TYPE_CHECKING: @@ -51,7 +51,7 @@ class Session: config: dict[str, Any] = field(default_factory=dict) collection_reports: list[CollectionReport] = field(default_factory=list) - dag: nx.DiGraph = field(default_factory=nx.DiGraph) + dag: DiGraph = field(default_factory=DiGraph) hook: HookRelay = field(default_factory=HookRelay) tasks: list[PTask] = field(default_factory=list) dag_report: DagReport | None = None diff --git a/src/_pytask/shared.py b/src/_pytask/shared.py index 770a6b8c..ccd11988 100644 --- a/src/_pytask/shared.py +++ b/src/_pytask/shared.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from enum import Enum - import networkx as nx + from _pytask.dag_graph import DiGraph __all__ = [ @@ -79,7 +79,7 @@ def parse_paths(x: Path | list[Path]) -> list[Path]: def reduce_names_of_multiple_nodes( - names: list[str], dag: nx.DiGraph, paths: Sequence[Path] + names: Iterable[str], dag: DiGraph, paths: Sequence[Path] ) -> list[str]: """Reduce the names of multiple nodes in the DAG.""" short_names = [] diff --git a/tests/test_dag_command.py b/tests/test_dag_command.py index c9573c83..9b40748b 100644 --- a/tests/test_dag_command.py +++ b/tests/test_dag_command.py @@ -120,7 +120,11 @@ def task_example(path=Path("input.txt")): ... monkeypatch.setattr( "_pytask.compat.import_module", - lambda x: _raise_exc(ImportError("pygraphviz not found")), # noqa: ARG005 + lambda x: ( + _raise_exc(ImportError("pygraphviz not found")) + if x == "pygraphviz" + else importlib.import_module(x) + ), ) result = runner.invoke( @@ -136,6 +140,33 @@ def task_example(path=Path("input.txt")): ... assert not tmp_path.joinpath("dag.png").exists() +def test_raise_error_with_graph_via_cli_missing_networkx(monkeypatch, tmp_path, runner): + source = """ + from pathlib import Path + + def task_example(path=Path("input.txt")): ... + """ + tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source)) + tmp_path.joinpath("input.txt").touch() + + monkeypatch.setattr( + "_pytask.compat.import_module", + lambda x: ( + _raise_exc(ImportError("networkx not found")) if x == "networkx" else None + ), + ) + + result = runner.invoke( + cli, + ["dag", tmp_path.as_posix(), "-o", tmp_path.joinpath("dag.png"), "-l", "dot"], + ) + + assert result.exit_code == ExitCode.FAILED + assert "pytask requires the optional dependency 'networkx'." in result.output + assert "Traceback" not in result.output + assert not tmp_path.joinpath("dag.png").exists() + + def test_raise_error_with_graph_via_task_missing_optional_dependency( monkeypatch, tmp_path, runner ): @@ -154,7 +185,11 @@ def task_create_graph(): monkeypatch.setattr( "_pytask.compat.import_module", - lambda x: _raise_exc(ImportError("pygraphviz not found")), # noqa: ARG005 + lambda x: ( + _raise_exc(ImportError("pygraphviz not found")) + if x == "pygraphviz" + else importlib.import_module(x) + ), ) result = runner.invoke(cli, [tmp_path.as_posix()]) @@ -167,6 +202,39 @@ def task_create_graph(): assert not tmp_path.joinpath("dag.png").exists() +def test_raise_error_with_graph_via_task_missing_networkx( + monkeypatch, tmp_path, runner +): + source = """ + import pytask + from pathlib import Path + import networkx as nx + + def task_create_graph(): + dag = pytask.build_dag({"paths": Path(__file__).parent}) + graph = nx.nx_agraph.to_agraph(dag) + path = Path(__file__).parent.joinpath("dag.png") + graph.draw(path, prog="dot") + """ + tmp_path.joinpath("task_example.py").write_text(textwrap.dedent(source)) + + monkeypatch.setattr( + "_pytask.compat.import_module", + lambda x: ( + _raise_exc(ImportError("networkx not found")) + if x == "networkx" + else importlib.import_module(x) + ), + ) + + result = runner.invoke(cli, [tmp_path.as_posix()]) + + assert result.exit_code == ExitCode.FAILED + assert "pytask requires the optional dependency 'networkx'." in result.output + assert "Traceback" in result.output + assert not tmp_path.joinpath("dag.png").exists() + + def test_raise_error_with_graph_via_cli_missing_optional_program( monkeypatch, tmp_path, runner ): diff --git a/tests/test_dag_utils.py b/tests/test_dag_utils.py index d11f2af2..db6be94f 100644 --- a/tests/test_dag_utils.py +++ b/tests/test_dag_utils.py @@ -3,9 +3,9 @@ from contextlib import ExitStack as does_not_raise # noqa: N813 from pathlib import Path -import networkx as nx import pytest +from _pytask.dag_graph import DiGraph from _pytask.dag_utils import TopologicalSorter from _pytask.dag_utils import _extract_priorities_from_tasks from _pytask.dag_utils import descending_tasks @@ -19,7 +19,7 @@ @pytest.fixture def dag(): """Create a dag with five nodes in a line.""" - dag = nx.DiGraph() + dag = DiGraph() for i in range(4): task = Task(base_name=str(i), path=Path(), function=noop) next_task = Task(base_name=str(i + 1), path=Path(), function=noop) diff --git a/uv.lock b/uv.lock index 1df51b41..f9dae641 100644 --- a/uv.lock +++ b/uv.lock @@ -1279,6 +1279,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/92/db/b4c12cff13ebac2786f4f217f06588bccd8b53d260453404ef22b121fc3a/greenlet-3.2.3-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:1afd685acd5597349ee6d7a88a8bec83ce13c106ac78c196ee9dde7c04fe87be", size = 268977, upload-time = "2025-06-05T16:10:24.001Z" }, { url = "https://files.pythonhosted.org/packages/52/61/75b4abd8147f13f70986df2801bf93735c1bd87ea780d70e3b3ecda8c165/greenlet-3.2.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:761917cac215c61e9dc7324b2606107b3b292a8349bdebb31503ab4de3f559ac", size = 627351, upload-time = "2025-06-05T16:38:50.685Z" }, { url = "https://files.pythonhosted.org/packages/35/aa/6894ae299d059d26254779a5088632874b80ee8cf89a88bca00b0709d22f/greenlet-3.2.3-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:a433dbc54e4a37e4fff90ef34f25a8c00aed99b06856f0119dcf09fbafa16392", size = 638599, upload-time = "2025-06-05T16:41:34.057Z" }, + { url = "https://files.pythonhosted.org/packages/30/64/e01a8261d13c47f3c082519a5e9dbf9e143cc0498ed20c911d04e54d526c/greenlet-3.2.3-cp310-cp310-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:72e77ed69312bab0434d7292316d5afd6896192ac4327d44f3d613ecb85b037c", size = 634482, upload-time = "2025-06-05T16:48:16.26Z" }, { url = "https://files.pythonhosted.org/packages/47/48/ff9ca8ba9772d083a4f5221f7b4f0ebe8978131a9ae0909cf202f94cd879/greenlet-3.2.3-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:68671180e3849b963649254a882cd544a3c75bfcd2c527346ad8bb53494444db", size = 633284, upload-time = "2025-06-05T16:13:01.599Z" }, { url = "https://files.pythonhosted.org/packages/e9/45/626e974948713bc15775b696adb3eb0bd708bec267d6d2d5c47bb47a6119/greenlet-3.2.3-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:49c8cfb18fb419b3d08e011228ef8a25882397f3a859b9fe1436946140b6756b", size = 582206, upload-time = "2025-06-05T16:12:48.51Z" }, { url = "https://files.pythonhosted.org/packages/b1/8e/8b6f42c67d5df7db35b8c55c9a850ea045219741bb14416255616808c690/greenlet-3.2.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:efc6dc8a792243c31f2f5674b670b3a95d46fa1c6a912b8e310d6f542e7b0712", size = 1111412, upload-time = "2025-06-05T16:36:45.479Z" }, @@ -1287,6 +1288,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fc/2e/d4fcb2978f826358b673f779f78fa8a32ee37df11920dc2bb5589cbeecef/greenlet-3.2.3-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:784ae58bba89fa1fa5733d170d42486580cab9decda3484779f4759345b29822", size = 270219, upload-time = "2025-06-05T16:10:10.414Z" }, { url = "https://files.pythonhosted.org/packages/16/24/929f853e0202130e4fe163bc1d05a671ce8dcd604f790e14896adac43a52/greenlet-3.2.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0921ac4ea42a5315d3446120ad48f90c3a6b9bb93dd9b3cf4e4d84a66e42de83", size = 630383, upload-time = "2025-06-05T16:38:51.785Z" }, { url = "https://files.pythonhosted.org/packages/d1/b2/0320715eb61ae70c25ceca2f1d5ae620477d246692d9cc284c13242ec31c/greenlet-3.2.3-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:d2971d93bb99e05f8c2c0c2f4aa9484a18d98c4c3bd3c62b65b7e6ae33dfcfaf", size = 642422, upload-time = "2025-06-05T16:41:35.259Z" }, + { url = "https://files.pythonhosted.org/packages/bd/49/445fd1a210f4747fedf77615d941444349c6a3a4a1135bba9701337cd966/greenlet-3.2.3-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:c667c0bf9d406b77a15c924ef3285e1e05250948001220368e039b6aa5b5034b", size = 638375, upload-time = "2025-06-05T16:48:18.235Z" }, { url = "https://files.pythonhosted.org/packages/7e/c8/ca19760cf6eae75fa8dc32b487e963d863b3ee04a7637da77b616703bc37/greenlet-3.2.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:592c12fb1165be74592f5de0d70f82bc5ba552ac44800d632214b76089945147", size = 637627, upload-time = "2025-06-05T16:13:02.858Z" }, { url = "https://files.pythonhosted.org/packages/65/89/77acf9e3da38e9bcfca881e43b02ed467c1dedc387021fc4d9bd9928afb8/greenlet-3.2.3-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:29e184536ba333003540790ba29829ac14bb645514fbd7e32af331e8202a62a5", size = 585502, upload-time = "2025-06-05T16:12:49.642Z" }, { url = "https://files.pythonhosted.org/packages/97/c6/ae244d7c95b23b7130136e07a9cc5aadd60d59b5951180dc7dc7e8edaba7/greenlet-3.2.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:93c0bb79844a367782ec4f429d07589417052e621aa39a5ac1fb99c5aa308edc", size = 1114498, upload-time = "2025-06-05T16:36:46.598Z" }, @@ -1295,6 +1297,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f3/94/ad0d435f7c48debe960c53b8f60fb41c2026b1d0fa4a99a1cb17c3461e09/greenlet-3.2.3-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:25ad29caed5783d4bd7a85c9251c651696164622494c00802a139c00d639242d", size = 271992, upload-time = "2025-06-05T16:11:23.467Z" }, { url = "https://files.pythonhosted.org/packages/93/5d/7c27cf4d003d6e77749d299c7c8f5fd50b4f251647b5c2e97e1f20da0ab5/greenlet-3.2.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:88cd97bf37fe24a6710ec6a3a7799f3f81d9cd33317dcf565ff9950c83f55e0b", size = 638820, upload-time = "2025-06-05T16:38:52.882Z" }, { url = "https://files.pythonhosted.org/packages/c6/7e/807e1e9be07a125bb4c169144937910bf59b9d2f6d931578e57f0bce0ae2/greenlet-3.2.3-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:baeedccca94880d2f5666b4fa16fc20ef50ba1ee353ee2d7092b383a243b0b0d", size = 653046, upload-time = "2025-06-05T16:41:36.343Z" }, + { url = "https://files.pythonhosted.org/packages/9d/ab/158c1a4ea1068bdbc78dba5a3de57e4c7aeb4e7fa034320ea94c688bfb61/greenlet-3.2.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:be52af4b6292baecfa0f397f3edb3c6092ce071b499dd6fe292c9ac9f2c8f264", size = 647701, upload-time = "2025-06-05T16:48:19.604Z" }, { url = "https://files.pythonhosted.org/packages/cc/0d/93729068259b550d6a0288da4ff72b86ed05626eaf1eb7c0d3466a2571de/greenlet-3.2.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0cc73378150b8b78b0c9fe2ce56e166695e67478550769536a6742dca3651688", size = 649747, upload-time = "2025-06-05T16:13:04.628Z" }, { url = "https://files.pythonhosted.org/packages/f6/f6/c82ac1851c60851302d8581680573245c8fc300253fc1ff741ae74a6c24d/greenlet-3.2.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:706d016a03e78df129f68c4c9b4c4f963f7d73534e48a24f5f5a7101ed13dbbb", size = 605461, upload-time = "2025-06-05T16:12:50.792Z" }, { url = "https://files.pythonhosted.org/packages/98/82/d022cf25ca39cf1200650fc58c52af32c90f80479c25d1cbf57980ec3065/greenlet-3.2.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:419e60f80709510c343c57b4bb5a339d8767bf9aef9b8ce43f4f143240f88b7c", size = 1121190, upload-time = "2025-06-05T16:36:48.59Z" }, @@ -1303,6 +1306,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b1/cf/f5c0b23309070ae93de75c90d29300751a5aacefc0a3ed1b1d8edb28f08b/greenlet-3.2.3-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:500b8689aa9dd1ab26872a34084503aeddefcb438e2e7317b89b11eaea1901ad", size = 270732, upload-time = "2025-06-05T16:10:08.26Z" }, { url = "https://files.pythonhosted.org/packages/48/ae/91a957ba60482d3fecf9be49bc3948f341d706b52ddb9d83a70d42abd498/greenlet-3.2.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:a07d3472c2a93117af3b0136f246b2833fdc0b542d4a9799ae5f41c28323faef", size = 639033, upload-time = "2025-06-05T16:38:53.983Z" }, { url = "https://files.pythonhosted.org/packages/6f/df/20ffa66dd5a7a7beffa6451bdb7400d66251374ab40b99981478c69a67a8/greenlet-3.2.3-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:8704b3768d2f51150626962f4b9a9e4a17d2e37c8a8d9867bbd9fa4eb938d3b3", size = 652999, upload-time = "2025-06-05T16:41:37.89Z" }, + { url = "https://files.pythonhosted.org/packages/51/b4/ebb2c8cb41e521f1d72bf0465f2f9a2fd803f674a88db228887e6847077e/greenlet-3.2.3-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:5035d77a27b7c62db6cf41cf786cfe2242644a7a337a0e155c80960598baab95", size = 647368, upload-time = "2025-06-05T16:48:21.467Z" }, { url = "https://files.pythonhosted.org/packages/8e/6a/1e1b5aa10dced4ae876a322155705257748108b7fd2e4fae3f2a091fe81a/greenlet-3.2.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2d8aa5423cd4a396792f6d4580f88bdc6efcb9205891c9d40d20f6e670992efb", size = 650037, upload-time = "2025-06-05T16:13:06.402Z" }, { url = "https://files.pythonhosted.org/packages/26/f2/ad51331a157c7015c675702e2d5230c243695c788f8f75feba1af32b3617/greenlet-3.2.3-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2c724620a101f8170065d7dded3f962a2aea7a7dae133a009cada42847e04a7b", size = 608402, upload-time = "2025-06-05T16:12:51.91Z" }, { url = "https://files.pythonhosted.org/packages/26/bc/862bd2083e6b3aff23300900a956f4ea9a4059de337f5c8734346b9b34fc/greenlet-3.2.3-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:873abe55f134c48e1f2a6f53f7d1419192a3d1a4e873bace00499a4e45ea6af0", size = 1119577, upload-time = "2025-06-05T16:36:49.787Z" }, @@ -1311,6 +1315,7 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d8/ca/accd7aa5280eb92b70ed9e8f7fd79dc50a2c21d8c73b9a0856f5b564e222/greenlet-3.2.3-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:3d04332dddb10b4a211b68111dabaee2e1a073663d117dc10247b5b1642bac86", size = 271479, upload-time = "2025-06-05T16:10:47.525Z" }, { url = "https://files.pythonhosted.org/packages/55/71/01ed9895d9eb49223280ecc98a557585edfa56b3d0e965b9fa9f7f06b6d9/greenlet-3.2.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8186162dffde068a465deab08fc72c767196895c39db26ab1c17c0b77a6d8b97", size = 683952, upload-time = "2025-06-05T16:38:55.125Z" }, { url = "https://files.pythonhosted.org/packages/ea/61/638c4bdf460c3c678a0a1ef4c200f347dff80719597e53b5edb2fb27ab54/greenlet-3.2.3-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:f4bfbaa6096b1b7a200024784217defedf46a07c2eee1a498e94a1b5f8ec5728", size = 696917, upload-time = "2025-06-05T16:41:38.959Z" }, + { url = "https://files.pythonhosted.org/packages/22/cc/0bd1a7eb759d1f3e3cc2d1bc0f0b487ad3cc9f34d74da4b80f226fde4ec3/greenlet-3.2.3-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:ed6cfa9200484d234d8394c70f5492f144b20d4533f69262d530a1a082f6ee9a", size = 692443, upload-time = "2025-06-05T16:48:23.113Z" }, { url = "https://files.pythonhosted.org/packages/67/10/b2a4b63d3f08362662e89c103f7fe28894a51ae0bc890fabf37d1d780e52/greenlet-3.2.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:02b0df6f63cd15012bed5401b47829cfd2e97052dc89da3cfaf2c779124eb892", size = 692995, upload-time = "2025-06-05T16:13:07.972Z" }, { url = "https://files.pythonhosted.org/packages/5a/c6/ad82f148a4e3ce9564056453a71529732baf5448ad53fc323e37efe34f66/greenlet-3.2.3-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:86c2d68e87107c1792e2e8d5399acec2487a4e993ab76c792408e59394d52141", size = 655320, upload-time = "2025-06-05T16:12:53.453Z" }, { url = "https://files.pythonhosted.org/packages/5c/4f/aab73ecaa6b3086a4c89863d94cf26fa84cbff63f52ce9bc4342b3087a06/greenlet-3.2.3-cp314-cp314-win_amd64.whl", hash = "sha256:8c47aae8fbbfcf82cc13327ae802ba13c9c36753b67e760023fd116bc124a62a", size = 301236, upload-time = "2025-06-05T16:15:20.111Z" }, @@ -1324,6 +1329,7 @@ dependencies = [ { name = "griffecli" }, { name = "griffelib" }, ] +sdist = { url = "https://files.pythonhosted.org/packages/04/56/28a0accac339c164b52a92c6cfc45a903acc0c174caa5c1713803467b533/griffe-2.0.0.tar.gz", hash = "sha256:c68979cd8395422083a51ea7cf02f9c119d889646d99b7b656ee43725de1b80f", size = 293906, upload-time = "2026-03-23T21:06:53.402Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/8b/94/ee21d41e7eb4f823b94603b9d40f86d3c7fde80eacc2c3c71845476dddaa/griffe-2.0.0-py3-none-any.whl", hash = "sha256:5418081135a391c3e6e757a7f3f156f1a1a746cc7b4023868ff7d5e2f9a980aa", size = 5214, upload-time = "2026-02-09T19:09:44.105Z" }, ] @@ -1336,6 +1342,7 @@ dependencies = [ { name = "colorama" }, { name = "griffelib" }, ] +sdist = { url = "https://files.pythonhosted.org/packages/a4/f8/2e129fd4a86e52e58eefe664de05e7d502decf766e7316cc9e70fdec3e18/griffecli-2.0.0.tar.gz", hash = "sha256:312fa5ebb4ce6afc786356e2d0ce85b06c1c20d45abc42d74f0cda65e159f6ef", size = 56213, upload-time = "2026-03-23T21:06:54.8Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/e6/ed/d93f7a447bbf7a935d8868e9617cbe1cadf9ee9ee6bd275d3040fbf93d60/griffecli-2.0.0-py3-none-any.whl", hash = "sha256:9f7cd9ee9b21d55e91689358978d2385ae65c22f307a63fb3269acf3f21e643d", size = 9345, upload-time = "2026-02-09T19:09:42.554Z" }, ] @@ -1344,6 +1351,7 @@ wheels = [ name = "griffelib" version = "2.0.0" source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ad/06/eccbd311c9e2b3ca45dbc063b93134c57a1ccc7607c5e545264ad092c4a9/griffelib-2.0.0.tar.gz", hash = "sha256:e504d637a089f5cab9b5daf18f7645970509bf4f53eda8d79ed71cce8bd97934", size = 166312, upload-time = "2026-03-23T21:06:55.954Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/4d/51/c936033e16d12b627ea334aaaaf42229c37620d0f15593456ab69ab48161/griffelib-2.0.0-py3-none-any.whl", hash = "sha256:01284878c966508b6d6f1dbff9b6fa607bc062d8261c5c7253cb285b06422a7f", size = 142004, upload-time = "2026-02-09T19:09:40.561Z" }, ] @@ -2876,8 +2884,6 @@ dependencies = [ { name = "click" }, { name = "click-default-group" }, { name = "msgspec", extra = ["toml"] }, - { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "optree" }, { name = "packaging" }, { name = "pluggy" }, @@ -2888,6 +2894,12 @@ dependencies = [ { name = "universal-pathlib" }, ] +[package.optional-dependencies] +dag = [ + { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, +] + [package.dev-dependencies] docs = [ { name = "furo" }, @@ -2895,6 +2907,8 @@ docs = [ { name = "ipywidgets" }, { name = "matplotlib" }, { name = "mkdocstrings", extra = ["python"] }, + { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "zensical" }, ] docs-live = [ @@ -2912,6 +2926,8 @@ test = [ { name = "coiled", marker = "python_full_version < '3.14'" }, { name = "deepdiff" }, { name = "nbmake", marker = "python_full_version < '3.14' or sys_platform != 'win32'" }, + { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "networkx", version = "3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "pexpect" }, { name = "pygments" }, { name = "pygraphviz", marker = "sys_platform == 'linux'" }, @@ -2930,7 +2946,7 @@ requires-dist = [ { name = "click-default-group", specifier = ">=1.2.4" }, { name = "msgspec", specifier = ">=0.18.6" }, { name = "msgspec", extras = ["toml"], specifier = ">=0.18.6" }, - { name = "networkx", specifier = ">=2.4.0" }, + { name = "networkx", marker = "extra == 'dag'", specifier = ">=2.4.0" }, { name = "optree", specifier = ">=0.9.0" }, { name = "packaging", specifier = ">=23.0.0" }, { name = "pluggy", specifier = ">=1.3.0" }, @@ -2940,6 +2956,7 @@ requires-dist = [ { name = "typing-extensions", marker = "python_full_version < '3.11'", specifier = ">=4.8.0" }, { name = "universal-pathlib", specifier = ">=0.2.2" }, ] +provides-extras = ["dag"] [package.metadata.requires-dev] docs = [ @@ -2948,6 +2965,7 @@ docs = [ { name = "ipywidgets", specifier = ">=8.1.6" }, { name = "matplotlib", specifier = ">=3.5.0" }, { name = "mkdocstrings", extras = ["python"], specifier = ">=0.30.0" }, + { name = "networkx", specifier = ">=2.4.0" }, { name = "zensical", specifier = ">=0.0.23" }, ] docs-live = [{ name = "sphinx-autobuild", specifier = ">=2024.10.3" }] @@ -2962,6 +2980,7 @@ test = [ { name = "coiled", marker = "python_full_version < '3.14'", specifier = ">=1.42.0" }, { name = "deepdiff", specifier = ">=7.0.0" }, { name = "nbmake", marker = "python_full_version < '3.14' or sys_platform != 'win32'", specifier = ">=1.5.5" }, + { name = "networkx", specifier = ">=2.4.0" }, { name = "pexpect", specifier = ">=4.9.0" }, { name = "pygments", specifier = ">=2.18.0" }, { name = "pygraphviz", marker = "sys_platform == 'linux'", specifier = ">=1.12" }, From 4b73493bf72e43a85972b333575aa7926f0cb1d8 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Thu, 26 Mar 2026 10:56:37 +0100 Subject: [PATCH 02/10] Clarify DAG scheduling and docs --- docs/source/tutorials/visualizing_the_dag.md | 17 ++++++++++--- src/_pytask/dag_utils.py | 3 +++ src/_pytask/execute.py | 26 +++++++++++++------- 3 files changed, 33 insertions(+), 13 deletions(-) diff --git a/docs/source/tutorials/visualizing_the_dag.md b/docs/source/tutorials/visualizing_the_dag.md index 297db87e..1089178d 100644 --- a/docs/source/tutorials/visualizing_the_dag.md +++ b/docs/source/tutorials/visualizing_the_dag.md @@ -3,11 +3,20 @@ To visualize the [DAG](../glossary.md#dag) of the project, first, install [networkx](https://networkx.org/), [pygraphviz](https://github.com/pygraphviz/pygraphviz), and -[graphviz](https://graphviz.org/). For example, you can install them with pixi +[graphviz](https://graphviz.org/). -```console -$ pixi add networkx pygraphviz graphviz -``` +=== "uv" + + ```console + $ uv add networkx + $ uv add --optional dag pygraphviz + ``` + +=== "pixi" + + ```console + $ pixi add networkx pygraphviz graphviz + ``` After that, pytask offers two interfaces to visualize your project's `DAG`. diff --git a/src/_pytask/dag_utils.py b/src/_pytask/dag_utils.py index 47b248f7..130db30e 100644 --- a/src/_pytask/dag_utils.py +++ b/src/_pytask/dag_utils.py @@ -100,6 +100,9 @@ def from_dag(cls, dag: DiGraph) -> TopologicalSorter: for signature in task_signatures: task_dag.add_node(signature) for signature in task_signatures: + # The scheduler graph uses edges from predecessor -> successor so that + # zero in-degree means "ready to run". This is the same orientation the + # previous networkx-based implementation reached after calling reverse(). for ancestor_ in ancestors(dag, signature) & task_signatures: task_dag.add_edge(ancestor_, signature) diff --git a/src/_pytask/execute.py b/src/_pytask/execute.py index 8dd8a1e5..8f0e94a2 100644 --- a/src/_pytask/execute.py +++ b/src/_pytask/execute.py @@ -173,11 +173,9 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C if not needs_to_be_executed: predecessors = set(dag.predecessors(task.signature)) | {task.signature} for node_signature in node_and_neighbors(dag, task.signature): - node = cast( - "PTask | PNode | PProvisionalNode", - dag.nodes[node_signature].get("task") - or dag.nodes[node_signature].get("node"), - ) + node = dag.nodes[node_signature].get("task") or dag.nodes[ + node_signature + ].get("node") # Skip provisional nodes that are products since they do not have a state. if node_signature not in predecessors and isinstance( @@ -185,7 +183,17 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C ): continue - node_state = cast("Any", node).state() + # Provisional dependencies should have been resolved before task setup. + if isinstance(node, PProvisionalNode): + msg = ( + f"Task {task.name!r} still references provisional node " + f"{node.name!r} during execution setup." + ) + raise ExecutionError(msg) + + node = cast("PTask | PNode", node) + + node_state = node.state() if node_signature in predecessors and not node_state: msg = f"{task.name!r} requires missing node {node.name!r}." @@ -201,7 +209,7 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C has_changed, reason, details = get_node_change_info( session=session, task=task, - node=cast("PTask | PNode", node), + node=node, state=node_state, ) if has_changed: @@ -220,7 +228,7 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C reason_typed: ReasonType = reason # type: ignore[assignment] change_reasons.append( create_change_reason( - node=cast("PTask | PNode", node), + node=node, node_type=node_type, reason=reason_typed, old_hash=details.get("old_hash"), @@ -231,7 +239,7 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C has_changed = has_node_changed( session=session, task=task, - node=cast("PTask | PNode", node), + node=node, state=node_state, ) if has_changed: From 5099f820b7c0d4a16930ea833c0a93afed41685e Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Thu, 26 Mar 2026 12:01:31 +0100 Subject: [PATCH 03/10] Simplify internal DAG types --- src/_pytask/dag_graph.py | 68 +++++++--------------------------------- tests/test_dag_utils.py | 12 +++++-- 2 files changed, 20 insertions(+), 60 deletions(-) diff --git a/src/_pytask/dag_graph.py b/src/_pytask/dag_graph.py index 8fa3d8c6..95f7cb6e 100644 --- a/src/_pytask/dag_graph.py +++ b/src/_pytask/dag_graph.py @@ -3,6 +3,8 @@ from __future__ import annotations import itertools +from dataclasses import dataclass +from dataclasses import field from typing import TYPE_CHECKING from typing import Any from typing import cast @@ -17,56 +19,18 @@ class NoCycleError(Exception): """Raised when no cycle is found in a graph.""" -class NodeView: - """A minimal mapping-like view over node attributes.""" - - def __init__(self, node_attributes: dict[str, dict[str, Any]]) -> None: - self._node_attributes = node_attributes - - def __getitem__(self, node: str) -> dict[str, Any]: - return self._node_attributes[node] - - def __iter__(self) -> Iterator[str]: - return iter(self._node_attributes) - - def __len__(self) -> int: - return len(self._node_attributes) - - def __contains__(self, node: object) -> bool: - return node in self._node_attributes - - -class UndirectedGraph: - """A minimal undirected graph used for validation tests.""" - - def __init__( - self, - node_attributes: dict[str, dict[str, Any]], - adjacency: dict[str, dict[str, None]], - graph_attributes: dict[str, Any], - ) -> None: - self._node_attributes = { - node: attributes.copy() for node, attributes in node_attributes.items() - } - self._adjacency = { - node: neighbors.copy() for node, neighbors in adjacency.items() - } - self.graph = graph_attributes.copy() - self.nodes = NodeView(self._node_attributes) - - def is_directed(self) -> bool: - return False - - +@dataclass class DiGraph: """A minimal directed graph tailored to pytask's needs.""" - def __init__(self) -> None: - self._node_attributes: dict[str, dict[str, Any]] = {} - self._successors: dict[str, dict[str, None]] = {} - self._predecessors: dict[str, dict[str, None]] = {} - self.graph: dict[str, Any] = {} - self.nodes = NodeView(self._node_attributes) + _node_attributes: dict[str, dict[str, Any]] = field(default_factory=dict) + _successors: dict[str, dict[str, None]] = field(default_factory=dict) + _predecessors: dict[str, dict[str, None]] = field(default_factory=dict) + graph: dict[str, Any] = field(default_factory=dict) + + @property + def nodes(self) -> dict[str, dict[str, Any]]: + return self._node_attributes def add_node(self, node_name: str, **attributes: Any) -> None: if node_name not in self._node_attributes: @@ -138,16 +102,6 @@ def set_node_attributes(self, values: dict[str, Any], name: str) -> None: if node in self._node_attributes: self._node_attributes[node][name] = value - def to_undirected(self) -> UndirectedGraph: - adjacency = { - node: { - **self._predecessors[node], - **self._successors[node], - } - for node in self._node_attributes - } - return UndirectedGraph(self._node_attributes, adjacency, self.graph) - def to_networkx(self) -> Any: nx = cast("Any", import_optional_dependency("networkx")) graph = nx.DiGraph() diff --git a/tests/test_dag_utils.py b/tests/test_dag_utils.py index db6be94f..5f02aeda 100644 --- a/tests/test_dag_utils.py +++ b/tests/test_dag_utils.py @@ -1,6 +1,7 @@ from __future__ import annotations from contextlib import ExitStack as does_not_raise # noqa: N813 +from dataclasses import dataclass from pathlib import Path import pytest @@ -143,10 +144,15 @@ def test_extract_priorities_from_tasks(tasks, expectation, expected): assert result == expected -def test_raise_error_for_undirected_graphs(dag): - undirected_graph = dag.to_undirected() +@dataclass +class _UndirectedGraphStub: + def is_directed(self): + return False + + +def test_raise_error_for_undirected_graphs(): with pytest.raises(ValueError, match="Only directed graphs have a"): - TopologicalSorter.from_dag(undirected_graph) + TopologicalSorter.from_dag(_UndirectedGraphStub()) # type: ignore[arg-type] def test_raise_error_for_cycle_in_graph(dag): From 13a3466b0c26ca7e7e0f61e8c2d70e1e3a7be17a Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Thu, 26 Mar 2026 15:43:35 +0100 Subject: [PATCH 04/10] Add changelog note for internal DAG refactor --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4de47bb9..32f0f762 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and ## Unreleased +- [#830](https://github.com/pytask-dev/pytask/pull/830) replaces the internal + `networkx` dependency with a pytask-owned DAG implementation, lazy-loads + `networkx` only for DAG export and visualization, and makes the `networkx` + dependency optional for core builds. - [#822](https://github.com/pytask-dev/pytask/pull/822) fixes unstable signatures for remote `UPath`-backed `PathNode`s and `PickleNode`s so unchanged remote inputs are no longer reported as missing from the state database on subsequent runs. From b806a4105ca6fcb2eefcacba16c7155b269b5340 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Thu, 26 Mar 2026 15:44:00 +0100 Subject: [PATCH 05/10] Introduce scheduler protocol and simple scheduler --- src/_pytask/dag_utils.py | 142 +++-------------------------- src/_pytask/execute.py | 13 +-- src/_pytask/provisional_utils.py | 6 +- src/_pytask/scheduler.py | 149 +++++++++++++++++++++++++++++++ src/_pytask/session.py | 3 +- tests/test_dag_utils.py | 16 ++-- 6 files changed, 182 insertions(+), 147 deletions(-) create mode 100644 src/_pytask/scheduler.py diff --git a/src/_pytask/dag_utils.py b/src/_pytask/dag_utils.py index 130db30e..053f416b 100644 --- a/src/_pytask/dag_utils.py +++ b/src/_pytask/dag_utils.py @@ -3,22 +3,30 @@ from __future__ import annotations import itertools -from dataclasses import dataclass -from dataclasses import field from typing import TYPE_CHECKING from _pytask.dag_graph import DiGraph -from _pytask.dag_graph import NoCycleError from _pytask.dag_graph import ancestors from _pytask.dag_graph import descendants -from _pytask.dag_graph import find_cycle -from _pytask.mark_utils import has_mark +from _pytask.scheduler import SimpleScheduler +from _pytask.scheduler import TopologicalSorter +from _pytask.scheduler import _extract_priorities_from_tasks if TYPE_CHECKING: from collections.abc import Generator from collections.abc import Iterable - from _pytask.node_protocols import PTask + +__all__ = [ + "SimpleScheduler", + "TopologicalSorter", + "_extract_priorities_from_tasks", + "descending_tasks", + "node_and_neighbors", + "preceding_tasks", + "task_and_descending_tasks", + "task_and_preceding_tasks", +] def descending_tasks(task_name: str, dag: DiGraph) -> Generator[str, None, None]: @@ -62,125 +70,3 @@ def node_and_neighbors(dag: DiGraph, node: str) -> Iterable[str]: """ return itertools.chain(dag.predecessors(node), [node], dag.successors(node)) - - -@dataclass -class TopologicalSorter: - """The topological sorter class. - - This class allows to perform a topological sort# - - Attributes - ---------- - dag - Not the full DAG, but a reduced version that only considers tasks. - priorities - A dictionary of task names to a priority value. 1 for try first, 0 for the - default priority and, -1 for try last. - - """ - - dag: DiGraph - priorities: dict[str, int] = field(default_factory=dict) - _nodes_processing: set[str] = field(default_factory=set) - _nodes_done: set[str] = field(default_factory=set) - - @classmethod - def from_dag(cls, dag: DiGraph) -> TopologicalSorter: - """Instantiate from a DAG.""" - cls.check_dag(dag) - - tasks = [ - dag.nodes[node]["task"] for node in dag.nodes if "task" in dag.nodes[node] - ] - priorities = _extract_priorities_from_tasks(tasks) - - task_signatures = {task.signature for task in tasks} - task_dag = DiGraph() - for signature in task_signatures: - task_dag.add_node(signature) - for signature in task_signatures: - # The scheduler graph uses edges from predecessor -> successor so that - # zero in-degree means "ready to run". This is the same orientation the - # previous networkx-based implementation reached after calling reverse(). - for ancestor_ in ancestors(dag, signature) & task_signatures: - task_dag.add_edge(ancestor_, signature) - - return cls(dag=task_dag, priorities=priorities) - - @classmethod - def from_dag_and_sorter( - cls, dag: DiGraph, sorter: TopologicalSorter - ) -> TopologicalSorter: - """Instantiate a sorter from another sorter and a DAG.""" - new_sorter = cls.from_dag(dag) - new_sorter.done(*sorter._nodes_done) - new_sorter._nodes_processing = sorter._nodes_processing - return new_sorter - - @staticmethod - def check_dag(dag: DiGraph) -> None: - if not dag.is_directed(): - msg = "Only directed graphs have a topological order." - raise ValueError(msg) - - try: - find_cycle(dag) - except NoCycleError: - pass - else: - msg = "The DAG contains cycles." - raise ValueError(msg) - - def get_ready(self, n: int = 1) -> list[str]: - """Get up to ``n`` tasks which are ready.""" - if not isinstance(n, int) or n < 1: - msg = "'n' must be an integer greater or equal than 1." - raise ValueError(msg) - - ready_nodes = { - v for v, d in self.dag.in_degree() if d == 0 - } - self._nodes_processing - prioritized_nodes = sorted( - ready_nodes, key=lambda x: self.priorities.get(x, 0) - )[-n:] - - self._nodes_processing.update(prioritized_nodes) - - return prioritized_nodes - - def is_active(self) -> bool: - """Indicate whether there are still tasks left.""" - return bool(self.dag.nodes) - - def done(self, *nodes: str) -> None: - """Mark some tasks as done.""" - self._nodes_processing = self._nodes_processing - set(nodes) - self.dag.remove_nodes_from(nodes) - self._nodes_done.update(nodes) - - -def _extract_priorities_from_tasks(tasks: list[PTask]) -> dict[str, int]: - """Extract priorities from tasks. - - Priorities are set via the [pytask.mark.try_first][] and [pytask.mark.try_last][] - markers. We recode these markers to numeric values to sort all available by - priorities. ``try_first`` is assigned the highest value such that it has the - rightmost position in the list. Then, we can simply call `list.pop` on the - list which is far more efficient than ``list.pop(0)``. - - """ - priorities = { - task.signature: { - "try_first": has_mark(task, "try_first"), - "try_last": has_mark(task, "try_last"), - } - for task in tasks - } - - # Recode to numeric values for sorting. - numeric_mapping = {(True, False): 1, (False, False): 0, (False, True): -1} - return { - name: numeric_mapping[(p["try_first"], p["try_last"])] - for name, p in priorities.items() - } diff --git a/src/_pytask/execute.py b/src/_pytask/execute.py index 8f0e94a2..0c0a9c91 100644 --- a/src/_pytask/execute.py +++ b/src/_pytask/execute.py @@ -18,7 +18,6 @@ from _pytask.console import format_node_name from _pytask.console import format_strings_as_flat_tree from _pytask.console import unify_styles -from _pytask.dag_utils import TopologicalSorter from _pytask.dag_utils import descending_tasks from _pytask.dag_utils import node_and_neighbors from _pytask.exceptions import ExecutionError @@ -44,6 +43,7 @@ from _pytask.pluginmanager import hookimpl from _pytask.provisional_utils import collect_provisional_products from _pytask.reports import ExecutionReport +from _pytask.scheduler import SimpleScheduler from _pytask.state import get_node_change_info from _pytask.state import has_node_changed from _pytask.state import update_states @@ -68,7 +68,7 @@ def pytask_post_parse(config: dict[str, Any]) -> None: def pytask_execute(session: Session) -> None: """Execute tasks.""" session.hook.pytask_execute_log_start(session=session) - session.scheduler = TopologicalSorter.from_dag(session.dag) + session.scheduler = SimpleScheduler.from_dag(session.dag) session.hook.pytask_execute_build(session=session) session.hook.pytask_execute_log_end( session=session, reports=session.execution_reports @@ -87,15 +87,16 @@ def pytask_execute_log_start(session: Session) -> None: @hookimpl def pytask_execute_build(session: Session) -> bool | None: """Execute tasks.""" - if isinstance(session.scheduler, TopologicalSorter): - while session.scheduler.is_active(): - task_name = session.scheduler.get_ready()[0] + scheduler = session.scheduler + if scheduler is not None: + while scheduler.is_active(): + task_name = scheduler.get_ready()[0] task = session.dag.nodes[task_name]["task"] report = session.hook.pytask_execute_task_protocol( session=session, task=task ) session.execution_reports.append(report) - session.scheduler.done(task_name) + scheduler.done(task_name) if session.should_stop: return True diff --git a/src/_pytask/provisional_utils.py b/src/_pytask/provisional_utils.py index ad01e790..604da3a8 100644 --- a/src/_pytask/provisional_utils.py +++ b/src/_pytask/provisional_utils.py @@ -9,7 +9,6 @@ from _pytask.collect_utils import collect_dependency from _pytask.dag import create_dag_from_session -from _pytask.dag_utils import TopologicalSorter from _pytask.models import NodeInfo from _pytask.node_protocols import PNode from _pytask.node_protocols import PProvisionalNode @@ -78,9 +77,8 @@ def recreate_dag(session: Session, task: PTask) -> None: """ try: session.dag = create_dag_from_session(session) - session.scheduler = TopologicalSorter.from_dag_and_sorter( - session.dag, session.scheduler - ) + if session.scheduler is not None: + session.scheduler = session.scheduler.rebuild(session.dag) except Exception: # noqa: BLE001 report = ExecutionReport.from_task_and_exception(task, sys.exc_info()) diff --git a/src/_pytask/scheduler.py b/src/_pytask/scheduler.py new file mode 100644 index 00000000..ceaa6bfa --- /dev/null +++ b/src/_pytask/scheduler.py @@ -0,0 +1,149 @@ +"""Contains scheduler protocols and implementations.""" + +from __future__ import annotations + +from dataclasses import dataclass +from dataclasses import field +from typing import TYPE_CHECKING +from typing import Protocol + +from _pytask.dag_graph import DiGraph +from _pytask.dag_graph import NoCycleError +from _pytask.dag_graph import ancestors +from _pytask.dag_graph import find_cycle +from _pytask.mark_utils import has_mark + +if TYPE_CHECKING: + from _pytask.node_protocols import PTask + + +class PScheduler(Protocol): + """Protocol for schedulers that dispatch ready tasks.""" + + def get_ready(self, n: int = 1) -> list[str]: + """Get up to ``n`` tasks which are ready.""" + + def is_active(self) -> bool: + """Indicate whether there are still tasks left.""" + + def done(self, *nodes: str) -> None: + """Mark some tasks as done.""" + + def rebuild(self, dag: DiGraph) -> PScheduler: + """Rebuild the scheduler from an updated DAG while preserving state.""" + + +@dataclass +class SimpleScheduler: + """The default scheduler based on topological sorting.""" + + dag: DiGraph + priorities: dict[str, int] = field(default_factory=dict) + _nodes_processing: set[str] = field(default_factory=set) + _nodes_done: set[str] = field(default_factory=set) + + @classmethod + def from_dag(cls, dag: DiGraph) -> SimpleScheduler: + """Instantiate from a DAG.""" + cls.check_dag(dag) + + tasks = [ + dag.nodes[node]["task"] for node in dag.nodes if "task" in dag.nodes[node] + ] + priorities = _extract_priorities_from_tasks(tasks) + + task_signatures = {task.signature for task in tasks} + task_dag = DiGraph() + for signature in task_signatures: + task_dag.add_node(signature) + for signature in task_signatures: + # The scheduler graph uses edges from predecessor -> successor so that + # zero in-degree means "ready to run". This is the same orientation the + # previous networkx-based implementation reached after calling reverse(). + for ancestor_ in ancestors(dag, signature) & task_signatures: + task_dag.add_edge(ancestor_, signature) + + return cls(dag=task_dag, priorities=priorities) + + @classmethod + def from_dag_and_sorter( + cls, dag: DiGraph, sorter: SimpleScheduler + ) -> SimpleScheduler: + """Instantiate a sorter from another sorter and a DAG.""" + new_sorter = cls.from_dag(dag) + new_sorter.done(*sorter._nodes_done) + new_sorter._nodes_processing = sorter._nodes_processing + return new_sorter + + @staticmethod + def check_dag(dag: DiGraph) -> None: + if not dag.is_directed(): + msg = "Only directed graphs have a topological order." + raise ValueError(msg) + + try: + find_cycle(dag) + except NoCycleError: + pass + else: + msg = "The DAG contains cycles." + raise ValueError(msg) + + def get_ready(self, n: int = 1) -> list[str]: + """Get up to ``n`` tasks which are ready.""" + if not isinstance(n, int) or n < 1: + msg = "'n' must be an integer greater or equal than 1." + raise ValueError(msg) + + ready_nodes = { + v for v, d in self.dag.in_degree() if d == 0 + } - self._nodes_processing + prioritized_nodes = sorted( + ready_nodes, key=lambda x: self.priorities.get(x, 0) + )[-n:] + + self._nodes_processing.update(prioritized_nodes) + + return prioritized_nodes + + def is_active(self) -> bool: + """Indicate whether there are still tasks left.""" + return bool(self.dag.nodes) + + def done(self, *nodes: str) -> None: + """Mark some tasks as done.""" + self._nodes_processing = self._nodes_processing - set(nodes) + self.dag.remove_nodes_from(nodes) + self._nodes_done.update(nodes) + + def rebuild(self, dag: DiGraph) -> SimpleScheduler: + """Rebuild the scheduler from an updated DAG while preserving state.""" + return self.from_dag_and_sorter(dag, self) + + +TopologicalSorter = SimpleScheduler + + +def _extract_priorities_from_tasks(tasks: list[PTask]) -> dict[str, int]: + """Extract priorities from tasks. + + Priorities are set via the [pytask.mark.try_first][] and [pytask.mark.try_last][] + markers. We recode these markers to numeric values to sort all available by + priorities. ``try_first`` is assigned the highest value such that it has the + rightmost position in the list. Then, we can simply call `list.pop` on the + list which is far more efficient than ``list.pop(0)``. + + """ + priorities = { + task.signature: { + "try_first": has_mark(task, "try_first"), + "try_last": has_mark(task, "try_last"), + } + for task in tasks + } + + numeric_mapping = {(True, False): 1, (False, False): 0, (False, True): -1} + return { + name: numeric_mapping[(p["try_first"], p["try_last"])] + for name, p in priorities.items() + } diff --git a/src/_pytask/session.py b/src/_pytask/session.py index 09008742..e5c07edf 100644 --- a/src/_pytask/session.py +++ b/src/_pytask/session.py @@ -17,6 +17,7 @@ from _pytask.reports import CollectionReport from _pytask.reports import DagReport from _pytask.reports import ExecutionReport + from _pytask.scheduler import PScheduler from _pytask.warnings_utils import WarningReport @@ -64,7 +65,7 @@ class Session: execution_end: float = float("inf") n_tasks_failed: int = 0 - scheduler: Any = None + scheduler: PScheduler | None = None should_stop: bool = False warnings: list[WarningReport] = field(default_factory=list) diff --git a/tests/test_dag_utils.py b/tests/test_dag_utils.py index 5f02aeda..c155895b 100644 --- a/tests/test_dag_utils.py +++ b/tests/test_dag_utils.py @@ -7,11 +7,11 @@ import pytest from _pytask.dag_graph import DiGraph -from _pytask.dag_utils import TopologicalSorter -from _pytask.dag_utils import _extract_priorities_from_tasks from _pytask.dag_utils import descending_tasks from _pytask.dag_utils import node_and_neighbors from _pytask.dag_utils import task_and_descending_tasks +from _pytask.scheduler import SimpleScheduler +from _pytask.scheduler import _extract_priorities_from_tasks from pytask import Mark from pytask import Task from tests.conftest import noop @@ -32,7 +32,7 @@ def dag(): def test_sort_tasks_topologically(dag): - sorter = TopologicalSorter.from_dag(dag) + sorter = SimpleScheduler.from_dag(dag) topo_ordering = [] while sorter.is_active(): task_name = sorter.get_ready()[0] @@ -152,7 +152,7 @@ def is_directed(self): def test_raise_error_for_undirected_graphs(): with pytest.raises(ValueError, match="Only directed graphs have a"): - TopologicalSorter.from_dag(_UndirectedGraphStub()) # type: ignore[arg-type] + SimpleScheduler.from_dag(_UndirectedGraphStub()) # type: ignore[arg-type] def test_raise_error_for_cycle_in_graph(dag): @@ -161,11 +161,11 @@ def test_raise_error_for_cycle_in_graph(dag): "55c6cef62d3e62d5f8fc65bb846e66d8d0d3ca60608c04f6f7b095ea073a7dcf", ) with pytest.raises(ValueError, match=r"The DAG contains cycles\."): - TopologicalSorter.from_dag(dag) + SimpleScheduler.from_dag(dag) def test_ask_for_invalid_number_of_ready_tasks(dag): - scheduler = TopologicalSorter.from_dag(dag) + scheduler = SimpleScheduler.from_dag(dag) with pytest.raises(ValueError, match="'n' must be"): scheduler.get_ready(0) @@ -173,7 +173,7 @@ def test_ask_for_invalid_number_of_ready_tasks(dag): def test_instantiate_sorter_from_other_sorter(dag): name_to_sig = {dag.nodes[sig]["task"].name: sig for sig in dag.nodes} - scheduler = TopologicalSorter.from_dag(dag) + scheduler = SimpleScheduler.from_dag(dag) for _ in range(2): task_name = scheduler.get_ready()[0] scheduler.done(task_name) @@ -183,7 +183,7 @@ def test_instantiate_sorter_from_other_sorter(dag): dag.add_node(task.signature, task=Task(base_name="5", path=Path(), function=noop)) dag.add_edge(name_to_sig[".::4"], task.signature) - new_scheduler = TopologicalSorter.from_dag_and_sorter(dag, scheduler) + new_scheduler = scheduler.rebuild(dag) while new_scheduler.is_active(): task_name = new_scheduler.get_ready()[0] new_scheduler.done(task_name) From b80be41381559a6ce8605c286f1c3d17135ab33d Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Thu, 26 Mar 2026 15:51:38 +0100 Subject: [PATCH 06/10] Test scheduler priorities through public behavior --- src/_pytask/dag_utils.py | 2 - tests/test_dag_utils.py | 92 ++++++++++++---------------------------- 2 files changed, 26 insertions(+), 68 deletions(-) diff --git a/src/_pytask/dag_utils.py b/src/_pytask/dag_utils.py index 053f416b..2d0e1cee 100644 --- a/src/_pytask/dag_utils.py +++ b/src/_pytask/dag_utils.py @@ -10,7 +10,6 @@ from _pytask.dag_graph import descendants from _pytask.scheduler import SimpleScheduler from _pytask.scheduler import TopologicalSorter -from _pytask.scheduler import _extract_priorities_from_tasks if TYPE_CHECKING: from collections.abc import Generator @@ -20,7 +19,6 @@ __all__ = [ "SimpleScheduler", "TopologicalSorter", - "_extract_priorities_from_tasks", "descending_tasks", "node_and_neighbors", "preceding_tasks", diff --git a/tests/test_dag_utils.py b/tests/test_dag_utils.py index c155895b..5fc66217 100644 --- a/tests/test_dag_utils.py +++ b/tests/test_dag_utils.py @@ -1,6 +1,5 @@ from __future__ import annotations -from contextlib import ExitStack as does_not_raise # noqa: N813 from dataclasses import dataclass from pathlib import Path @@ -11,7 +10,6 @@ from _pytask.dag_utils import node_and_neighbors from _pytask.dag_utils import task_and_descending_tasks from _pytask.scheduler import SimpleScheduler -from _pytask.scheduler import _extract_priorities_from_tasks from pytask import Mark from pytask import Task from tests.conftest import noop @@ -78,70 +76,32 @@ def test_node_and_neighbors(dag): assert node_names == [f".::{j}" for j in range(i - 1, i + 2)] -@pytest.mark.parametrize( - ("tasks", "expectation", "expected"), - [ - pytest.param( - [ - Task( - base_name="1", - path=Path(), - function=None, # type: ignore[arg-type] - markers=[Mark("try_last", (), {})], - ) - ], - does_not_raise(), - {"c12d8d4f7e2e3128d27878d1fb3d8e3583e90e68000a13634dfbf21f4d1456f3": -1}, - id="test try_last", - ), - pytest.param( - [ - Task( - base_name="1", - path=Path(), - function=None, # type: ignore[arg-type] - markers=[Mark("try_first", (), {})], - ) - ], - does_not_raise(), - {"c12d8d4f7e2e3128d27878d1fb3d8e3583e90e68000a13634dfbf21f4d1456f3": 1}, - id="test try_first", - ), - pytest.param( - [Task(base_name="1", path=Path(), function=None, markers=[])], # type: ignore[arg-type] - does_not_raise(), - {"c12d8d4f7e2e3128d27878d1fb3d8e3583e90e68000a13634dfbf21f4d1456f3": 0}, - id="test no priority", - ), - pytest.param( - [ - Task( - base_name="1", - path=Path(), - function=None, # type: ignore[arg-type] - markers=[Mark("try_first", (), {})], - ), - Task(base_name="2", path=Path(), function=None, markers=[]), # type: ignore[arg-type] - Task( - base_name="3", - path=Path(), - function=None, # type: ignore[arg-type] - markers=[Mark("try_last", (), {})], - ), - ], - does_not_raise(), - { - "c12d8d4f7e2e3128d27878d1fb3d8e3583e90e68000a13634dfbf21f4d1456f3": 1, - "c5f667e69824043475b1283ed8920e513cb4343ec7077f71a3d9f5972f5204b9": 0, - "dca295f815f54d282b33e8d9398cea4962d0dfbe881d2ab28fc48ff9e060203a": -1, - }, - ), - ], -) -def test_extract_priorities_from_tasks(tasks, expectation, expected): - with expectation: - result = _extract_priorities_from_tasks(tasks) - assert result == expected +def test_prioritize_try_first_and_try_last_tasks(): + dag = DiGraph() + first = Task( + base_name="first", + path=Path(), + function=noop, + markers=[Mark("try_first", (), {})], + ) + default = Task(base_name="default", path=Path(), function=noop) + last = Task( + base_name="last", + path=Path(), + function=noop, + markers=[Mark("try_last", (), {})], + ) + + for task in (first, default, last): + dag.add_node(task.signature, task=task) + + scheduler = SimpleScheduler.from_dag(dag) + + first_batch = scheduler.get_ready(3) + first_batch_names = [dag.nodes[sig]["task"].name for sig in first_batch] + + assert first_batch_names[-1] == ".::first" + assert first_batch_names[0] == ".::last" @dataclass From a95de4d2ab28cc634a873aea45505a2fa4190e71 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Thu, 26 Mar 2026 16:13:07 +0100 Subject: [PATCH 07/10] Drop obsolete topological sorter compatibility alias --- src/_pytask/dag_utils.py | 4 ---- src/_pytask/scheduler.py | 3 --- 2 files changed, 7 deletions(-) diff --git a/src/_pytask/dag_utils.py b/src/_pytask/dag_utils.py index 2d0e1cee..1e71c7d5 100644 --- a/src/_pytask/dag_utils.py +++ b/src/_pytask/dag_utils.py @@ -8,8 +8,6 @@ from _pytask.dag_graph import DiGraph from _pytask.dag_graph import ancestors from _pytask.dag_graph import descendants -from _pytask.scheduler import SimpleScheduler -from _pytask.scheduler import TopologicalSorter if TYPE_CHECKING: from collections.abc import Generator @@ -17,8 +15,6 @@ __all__ = [ - "SimpleScheduler", - "TopologicalSorter", "descending_tasks", "node_and_neighbors", "preceding_tasks", diff --git a/src/_pytask/scheduler.py b/src/_pytask/scheduler.py index ceaa6bfa..71f4a7c1 100644 --- a/src/_pytask/scheduler.py +++ b/src/_pytask/scheduler.py @@ -121,9 +121,6 @@ def rebuild(self, dag: DiGraph) -> SimpleScheduler: return self.from_dag_and_sorter(dag, self) -TopologicalSorter = SimpleScheduler - - def _extract_priorities_from_tasks(tasks: list[PTask]) -> dict[str, int]: """Extract priorities from tasks. From 570973e6749d8fd563092743393981a9abb440ae Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Thu, 26 Mar 2026 16:30:10 +0100 Subject: [PATCH 08/10] Copy scheduler processing state during rebuild --- src/_pytask/scheduler.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/src/_pytask/scheduler.py b/src/_pytask/scheduler.py index 71f4a7c1..1fe25ed6 100644 --- a/src/_pytask/scheduler.py +++ b/src/_pytask/scheduler.py @@ -65,16 +65,6 @@ def from_dag(cls, dag: DiGraph) -> SimpleScheduler: return cls(dag=task_dag, priorities=priorities) - @classmethod - def from_dag_and_sorter( - cls, dag: DiGraph, sorter: SimpleScheduler - ) -> SimpleScheduler: - """Instantiate a sorter from another sorter and a DAG.""" - new_sorter = cls.from_dag(dag) - new_sorter.done(*sorter._nodes_done) - new_sorter._nodes_processing = sorter._nodes_processing - return new_sorter - @staticmethod def check_dag(dag: DiGraph) -> None: if not dag.is_directed(): @@ -118,7 +108,10 @@ def done(self, *nodes: str) -> None: def rebuild(self, dag: DiGraph) -> SimpleScheduler: """Rebuild the scheduler from an updated DAG while preserving state.""" - return self.from_dag_and_sorter(dag, self) + new_scheduler = type(self).from_dag(dag) + new_scheduler.done(*self._nodes_done) + new_scheduler._nodes_processing = self._nodes_processing.copy() + return new_scheduler def _extract_priorities_from_tasks(tasks: list[PTask]) -> dict[str, int]: From 55e9ad8456c287d85f5b25d7d0be404f47630539 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Sat, 28 Mar 2026 00:33:56 +0100 Subject: [PATCH 09/10] fix: minimize diff --- src/_pytask/execute.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/_pytask/execute.py b/src/_pytask/execute.py index 4549c681..f8db6518 100644 --- a/src/_pytask/execute.py +++ b/src/_pytask/execute.py @@ -206,10 +206,7 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C # Check if node changed and collect detailed info if in explain mode if session.config["explain"]: has_changed, reason, details = get_node_change_info( - session=session, - task=task, - node=node, - state=node_state, + session=session, task=task, node=node, state=node_state ) if has_changed: needs_to_be_executed = True @@ -236,10 +233,7 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C ) else: has_changed = has_node_changed( - session=session, - task=task, - node=node, - state=node_state, + session=session, task=task, node=node, state=node_state ) if has_changed: needs_to_be_executed = True From a3c6f3bf613fd78660c6974eea907869cd161ea7 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Sat, 28 Mar 2026 00:37:15 +0100 Subject: [PATCH 10/10] Fix scheduler rebuild during task generation --- src/_pytask/execute.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/_pytask/execute.py b/src/_pytask/execute.py index f8db6518..923049e0 100644 --- a/src/_pytask/execute.py +++ b/src/_pytask/execute.py @@ -86,10 +86,9 @@ def pytask_execute_log_start(session: Session) -> None: @hookimpl def pytask_execute_build(session: Session) -> bool | None: """Execute tasks.""" - scheduler = session.scheduler - if scheduler is not None: - while scheduler.is_active(): - task_name = scheduler.get_ready()[0] + if session.scheduler is not None: + while session.scheduler.is_active(): + task_name = session.scheduler.get_ready()[0] task = session.dag.nodes[task_name] if not isinstance(task, PTask): msg = f"Expected task node for signature {task_name!r}." @@ -98,7 +97,7 @@ def pytask_execute_build(session: Session) -> bool | None: session=session, task=task ) session.execution_reports.append(report) - scheduler.done(task_name) + session.scheduler.done(task_name) if session.should_stop: return True