Skip to content

Commit 8a86c8a

Browse files
authored
Refactor DAG internals and lazy-load networkx (#830)
1 parent 1945471 commit 8a86c8a

21 files changed

Lines changed: 456 additions & 184 deletions

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
77

88
## Unreleased
99

10+
- [#830](https://github.com/pytask-dev/pytask/pull/830) replaces the internal
11+
`networkx` dependency with a pytask-owned DAG implementation, lazy-loads
12+
`networkx` only for DAG export and visualization, and makes the `networkx`
13+
dependency optional for core builds.
1014
- [#822](https://github.com/pytask-dev/pytask/pull/822) fixes unstable signatures
1115
for remote `UPath`-backed `PathNode`s and `PickleNode`s so unchanged remote inputs
1216
are no longer reported as missing from the state database on subsequent runs.

docs/source/tutorials/visualizing_the_dag.md

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
11
# Visualizing the DAG
22

33
To visualize the [DAG](../glossary.md#dag) of the project, first, install
4-
[pygraphviz](https://github.com/pygraphviz/pygraphviz) and
5-
[graphviz](https://graphviz.org/). For example, you can both install with pixi
4+
[networkx](https://networkx.org/),
5+
[pygraphviz](https://github.com/pygraphviz/pygraphviz), and
6+
[graphviz](https://graphviz.org/).
67

7-
```console
8-
$ pixi add pygraphviz graphviz
9-
```
8+
=== "uv"
9+
10+
```console
11+
$ uv add networkx
12+
$ uv add --optional dag pygraphviz
13+
```
14+
15+
=== "pixi"
16+
17+
```console
18+
$ pixi add networkx pygraphviz graphviz
19+
```
1020

1121
After that, pytask offers two interfaces to visualize your project's `DAG`.
1222

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ dependencies = [
2424
"click>=8.1.8,!=8.2.0",
2525
"click-default-group>=1.2.4",
2626
"msgspec>=0.18.6",
27-
"networkx>=2.4.0",
2827
"optree>=0.9.0",
2928
"packaging>=23.0.0",
3029
"pluggy>=1.3.0",
@@ -36,6 +35,9 @@ dependencies = [
3635
"universal-pathlib>=0.2.2",
3736
]
3837

38+
[project.optional-dependencies]
39+
dag = ["networkx>=2.4.0"]
40+
3941
[project.readme]
4042
file = "README.md"
4143
content-type = "text/markdown"
@@ -54,6 +56,7 @@ docs = [
5456
"ipywidgets>=8.1.6",
5557
"matplotlib>=3.5.0",
5658
"mkdocstrings[python]>=0.30.0",
59+
"networkx>=2.4.0",
5760
"zensical>=0.0.23",
5861
]
5962
docs-live = ["sphinx-autobuild>=2024.10.3"]
@@ -71,6 +74,7 @@ test = [
7174
"syrupy>=4.5.0",
7275
"aiohttp>=3.11.0", # For HTTPPath tests.
7376
"coiled>=1.42.0; python_version < '3.14'",
77+
"networkx>=2.4.0",
7478
"pygraphviz>=1.12;platform_system=='Linux'",
7579
]
7680
typing = ["ty>=0.0.8"]

src/_pytask/dag.py

Lines changed: 32 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import sys
77
from typing import TYPE_CHECKING
88

9-
import networkx as nx
109
from rich.text import Text
1110
from rich.tree import Tree
1211

@@ -17,6 +16,9 @@
1716
from _pytask.console import format_node_name
1817
from _pytask.console import format_task_name
1918
from _pytask.console import render_to_string
19+
from _pytask.dag_graph import DAG
20+
from _pytask.dag_graph import NoCycleError
21+
from _pytask.dag_graph import find_cycle
2022
from _pytask.exceptions import ResolvingDependenciesError
2123
from _pytask.mark import select_by_after_keyword
2224
from _pytask.mark import select_tasks_by_marks_and_expressions
@@ -37,7 +39,7 @@
3739
__all__ = ["create_dag", "create_dag_from_session"]
3840

3941

40-
def create_dag(session: Session) -> nx.DiGraph:
42+
def create_dag(session: Session) -> DAG:
4143
"""Create a directed acyclic graph (DAG) for the workflow."""
4244
try:
4345
dag = create_dag_from_session(session)
@@ -50,7 +52,7 @@ def create_dag(session: Session) -> nx.DiGraph:
5052
return dag
5153

5254

53-
def create_dag_from_session(session: Session) -> nx.DiGraph:
55+
def create_dag_from_session(session: Session) -> DAG:
5456
"""Create a DAG from a session."""
5557
dag = _create_dag_from_tasks(tasks=session.tasks)
5658
_check_if_dag_has_cycles(dag)
@@ -60,14 +62,16 @@ def create_dag_from_session(session: Session) -> nx.DiGraph:
6062
return dag
6163

6264

63-
def _create_dag_from_tasks(tasks: list[PTask]) -> nx.DiGraph:
65+
def _create_dag_from_tasks(tasks: list[PTask]) -> DAG:
6466
"""Create the DAG from tasks, dependencies and products."""
6567

66-
def _add_dependency(
67-
dag: nx.DiGraph, task: PTask, node: PNode | PProvisionalNode
68-
) -> None:
68+
def _add_node_data(dag: DAG, node: PNode | PProvisionalNode) -> None:
69+
dag.add_node(node.signature, node)
70+
if isinstance(node, PythonNode) and isinstance(node.value, PythonNode):
71+
_add_node_data(dag, node.value)
72+
73+
def _add_dependency(dag: DAG, task: PTask, node: PNode | PProvisionalNode) -> None:
6974
"""Add a dependency to the DAG."""
70-
dag.add_node(node.signature, node=node)
7175
dag.add_edge(node.signature, task.signature)
7276

7377
# If a node is a PythonNode wrapped in another PythonNode, it is a product from
@@ -76,36 +80,24 @@ def _add_dependency(
7680
if isinstance(node, PythonNode) and isinstance(node.value, PythonNode):
7781
dag.add_edge(node.value.signature, node.signature)
7882

79-
def _add_product(
80-
dag: nx.DiGraph, task: PTask, node: PNode | PProvisionalNode
81-
) -> None:
83+
def _add_product(dag: DAG, task: PTask, node: PNode | PProvisionalNode) -> None:
8284
"""Add a product to the DAG."""
83-
dag.add_node(node.signature, node=node)
8485
dag.add_edge(task.signature, node.signature)
8586

86-
dag = nx.DiGraph()
87+
dag = DAG()
8788

8889
for task in tasks:
89-
dag.add_node(task.signature, task=task)
90+
dag.add_node(task.signature, task)
91+
tree_map(lambda x: _add_node_data(dag, x), task.depends_on)
92+
tree_map(lambda x: _add_node_data(dag, x), task.produces)
9093

94+
for task in tasks:
9195
tree_map(lambda x: _add_dependency(dag, task, x), task.depends_on)
9296
tree_map(lambda x: _add_product(dag, task, x), task.produces)
93-
94-
# If a node is a PythonNode wrapped in another PythonNode, it is a product from
95-
# another task that is a dependency in the current task. Thus, draw an edge
96-
# connecting the two nodes.
97-
tree_map(
98-
lambda x: (
99-
dag.add_edge(x.value.signature, x.signature)
100-
if isinstance(x, PythonNode) and isinstance(x.value, PythonNode)
101-
else None
102-
),
103-
task.depends_on,
104-
)
10597
return dag
10698

10799

108-
def _modify_dag(session: Session, dag: nx.DiGraph) -> nx.DiGraph:
100+
def _modify_dag(session: Session, dag: DAG) -> DAG:
109101
"""Create dependencies between tasks when using ``@task(after=...)``."""
110102
temporary_id_to_task = {
111103
task.attributes["collection_id"]: task
@@ -129,11 +121,11 @@ def _modify_dag(session: Session, dag: nx.DiGraph) -> nx.DiGraph:
129121
return dag
130122

131123

132-
def _check_if_dag_has_cycles(dag: nx.DiGraph) -> None:
124+
def _check_if_dag_has_cycles(dag: DAG) -> None:
133125
"""Check if DAG has cycles."""
134126
try:
135-
cycles = nx.algorithms.cycles.find_cycle(dag)
136-
except nx.NetworkXNoCycle:
127+
cycles = find_cycle(dag)
128+
except NoCycleError:
137129
pass
138130
else:
139131
msg = (
@@ -145,7 +137,7 @@ def _check_if_dag_has_cycles(dag: nx.DiGraph) -> None:
145137
raise ResolvingDependenciesError(msg)
146138

147139

148-
def _format_cycles(dag: nx.DiGraph, cycles: list[tuple[str, ...]]) -> str:
140+
def _format_cycles(dag: DAG, cycles: list[tuple[str, str]]) -> str:
149141
"""Format cycles as a paths connected by arrows."""
150142
chain = [
151143
x for i, x in enumerate(itertools.chain.from_iterable(cycles)) if i % 2 == 0
@@ -154,7 +146,7 @@ def _format_cycles(dag: nx.DiGraph, cycles: list[tuple[str, ...]]) -> str:
154146

155147
lines: list[str] = []
156148
for x in chain:
157-
node = dag.nodes[x].get("task") or dag.nodes[x].get("node")
149+
node = dag.nodes[x]
158150
if isinstance(node, PTask):
159151
short_name = format_task_name(node, editor_url_scheme="no_link").plain
160152
elif isinstance(node, (PNode, PProvisionalNode)):
@@ -176,24 +168,27 @@ def _format_dictionary_to_tree(dict_: dict[str, list[str]], title: str) -> str:
176168
return render_to_string(tree, console=console, strip_styles=True)
177169

178170

179-
def _check_if_tasks_have_the_same_products(dag: nx.DiGraph, paths: list[Path]) -> None:
171+
def _check_if_tasks_have_the_same_products(dag: DAG, paths: list[Path]) -> None:
180172
nodes_created_by_multiple_tasks = []
181173

182174
for node in dag.nodes:
183-
is_node = "node" in dag.nodes[node]
184-
if is_node:
175+
if isinstance(dag.nodes[node], (PNode, PProvisionalNode)):
185176
parents = list(dag.predecessors(node))
186177
if len(parents) > 1:
187178
nodes_created_by_multiple_tasks.append(node)
188179

189180
if nodes_created_by_multiple_tasks:
190181
dictionary = {}
191182
for node in nodes_created_by_multiple_tasks:
192-
short_node_name = format_node_name(dag.nodes[node]["node"], paths).plain
183+
payload = dag.nodes[node]
184+
if not isinstance(payload, (PNode, PProvisionalNode)):
185+
msg = f"Expected product node for signature {node!r}."
186+
raise TypeError(msg)
187+
short_node_name = format_node_name(payload, paths).plain
193188
short_predecessors = reduce_names_of_multiple_nodes(
194189
dag.predecessors(node), dag, paths
195190
)
196-
dictionary[short_node_name] = short_predecessors
191+
dictionary[short_node_name] = sorted(short_predecessors)
197192
text = _format_dictionary_to_tree(dictionary, "Products from multiple tasks:")
198193
msg = (
199194
f"There are some tasks which produce the same output. See the following "

src/_pytask/dag_command.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
import enum
66
import sys
77
from pathlib import Path
8+
from typing import TYPE_CHECKING
89
from typing import Any
910
from typing import cast
1011

1112
import click
12-
import networkx as nx
1313
from rich.text import Text
1414

1515
from _pytask.click import ColoredCommand
@@ -30,6 +30,11 @@
3030
from _pytask.shared import reduce_names_of_multiple_nodes
3131
from _pytask.traceback import Traceback
3232

33+
if TYPE_CHECKING:
34+
import networkx as nx
35+
36+
from _pytask.dag_graph import DAG
37+
3338

3439
class _RankDirection(enum.Enum):
3540
TB = "TB"
@@ -92,6 +97,7 @@ def dag(**raw_config: Any) -> int:
9297
else:
9398
try:
9499
session.hook.pytask_log_session_header(session=session)
100+
import_optional_dependency("networkx")
95101
import_optional_dependency("pygraphviz")
96102
check_for_optional_program(
97103
session.config["layout"],
@@ -100,7 +106,7 @@ def dag(**raw_config: Any) -> int:
100106
)
101107
session.hook.pytask_collect(session=session)
102108
session.dag = create_dag(session=session)
103-
dag = _refine_dag(session)
109+
dag = _to_visualization_graph(session)
104110
_write_graph(dag, session.config["output_path"], session.config["layout"])
105111

106112
except CollectionError: # pragma: no cover
@@ -163,6 +169,7 @@ def build_dag(raw_config: dict[str, Any]) -> nx.DiGraph:
163169

164170
else:
165171
session.hook.pytask_log_session_header(session=session)
172+
import_optional_dependency("networkx")
166173
import_optional_dependency("pygraphviz")
167174
check_for_optional_program(
168175
session.config["layout"],
@@ -172,44 +179,42 @@ def build_dag(raw_config: dict[str, Any]) -> nx.DiGraph:
172179
session.hook.pytask_collect(session=session)
173180
session.dag = create_dag(session=session)
174181
session.hook.pytask_unconfigure(session=session)
175-
return _refine_dag(session)
182+
return _to_visualization_graph(session)
176183

177184

178-
def _refine_dag(session: Session) -> nx.DiGraph:
185+
def _refine_dag(session: Session) -> DAG:
179186
"""Refine the dag for plotting."""
180187
dag = _shorten_node_labels(session.dag, session.config["paths"])
181-
dag = _clean_dag(dag)
182-
dag = _style_dag(dag)
183-
dag.graph["graph"] = {"rankdir": session.config["rank_direction"].name}
188+
return _clean_dag(dag)
184189

190+
191+
def _to_visualization_graph(session: Session) -> nx.DiGraph:
192+
"""Convert the internal DAG to a styled networkx graph for visualization."""
193+
nx = cast("Any", import_optional_dependency("networkx"))
194+
dag = _refine_dag(session).to_networkx()
195+
dag.graph["graph"] = {"rankdir": session.config["rank_direction"].name}
196+
shapes = {name: "hexagon" if "::task_" in name else "box" for name in dag.nodes}
197+
nx.set_node_attributes(dag, shapes, "shape")
185198
return dag
186199

187200

188-
def _shorten_node_labels(dag: nx.DiGraph, paths: list[Path]) -> nx.DiGraph:
201+
def _shorten_node_labels(dag: DAG, paths: list[Path]) -> DAG:
189202
"""Shorten the node labels in the graph for a better experience."""
190203
node_names = dag.nodes
191204
short_names = reduce_names_of_multiple_nodes(node_names, dag, paths)
192205
short_names = [i.plain if isinstance(i, Text) else i for i in short_names]
193206
old_to_new = dict(zip(node_names, short_names, strict=False))
194-
return nx.relabel_nodes(dag, old_to_new)
207+
return dag.relabel_nodes(old_to_new)
195208

196209

197-
def _clean_dag(dag: nx.DiGraph) -> nx.DiGraph:
210+
def _clean_dag(dag: DAG) -> DAG:
198211
"""Clean the DAG."""
199-
for node in dag.nodes:
200-
dag.nodes[node].clear()
201-
return dag
202-
203-
204-
def _style_dag(dag: nx.DiGraph) -> nx.DiGraph:
205-
"""Style the DAG."""
206-
shapes = {name: "hexagon" if "::task_" in name else "box" for name in dag.nodes}
207-
nx.set_node_attributes(dag, shapes, "shape")
208212
return dag
209213

210214

211215
def _write_graph(dag: nx.DiGraph, path: Path, layout: str) -> None:
212216
"""Write the graph to disk."""
217+
nx = cast("Any", import_optional_dependency("networkx"))
213218
path.parent.mkdir(exist_ok=True, parents=True)
214219
graph = nx.nx_agraph.to_agraph(dag)
215220
graph.draw(path, prog=layout)

0 commit comments

Comments
 (0)