Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and

## Unreleased

- [#830](https://github.com/pytask-dev/pytask/pull/830) replaces the internal
`networkx` dependency with a pytask-owned DAG implementation, lazy-loads
`networkx` only for DAG export and visualization, and makes the `networkx`
dependency optional for core builds.
- [#822](https://github.com/pytask-dev/pytask/pull/822) fixes unstable signatures
for remote `UPath`-backed `PathNode`s and `PickleNode`s so unchanged remote inputs
are no longer reported as missing from the state database on subsequent runs.
Expand Down
20 changes: 15 additions & 5 deletions docs/source/tutorials/visualizing_the_dag.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
# Visualizing the DAG

To visualize the [DAG](../glossary.md#dag) of the project, first, install
[pygraphviz](https://github.com/pygraphviz/pygraphviz) and
[graphviz](https://graphviz.org/). For example, you can both install with pixi
[networkx](https://networkx.org/),
[pygraphviz](https://github.com/pygraphviz/pygraphviz), and
[graphviz](https://graphviz.org/).

```console
$ pixi add pygraphviz graphviz
```
=== "uv"

```console
$ uv add networkx
$ uv add --optional dag pygraphviz
```

=== "pixi"

```console
$ pixi add networkx pygraphviz graphviz
```

After that, pytask offers two interfaces to visualize your project's `DAG`.

Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ dependencies = [
"click>=8.1.8,!=8.2.0",
"click-default-group>=1.2.4",
"msgspec>=0.18.6",
"networkx>=2.4.0",
"optree>=0.9.0",
"packaging>=23.0.0",
"pluggy>=1.3.0",
Expand All @@ -36,6 +35,9 @@ dependencies = [
"universal-pathlib>=0.2.2",
]

[project.optional-dependencies]
dag = ["networkx>=2.4.0"]

[project.readme]
file = "README.md"
content-type = "text/markdown"
Expand All @@ -54,6 +56,7 @@ docs = [
"ipywidgets>=8.1.6",
"matplotlib>=3.5.0",
"mkdocstrings[python]>=0.30.0",
"networkx>=2.4.0",
"zensical>=0.0.23",
]
docs-live = ["sphinx-autobuild>=2024.10.3"]
Expand All @@ -71,6 +74,7 @@ test = [
"syrupy>=4.5.0",
"aiohttp>=3.11.0", # For HTTPPath tests.
"coiled>=1.42.0; python_version < '3.14'",
"networkx>=2.4.0",
"pygraphviz>=1.12;platform_system=='Linux'",
]
typing = ["ty>=0.0.8"]
Expand Down
69 changes: 32 additions & 37 deletions src/_pytask/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import sys
from typing import TYPE_CHECKING

import networkx as nx
from rich.text import Text
from rich.tree import Tree

Expand All @@ -17,6 +16,9 @@
from _pytask.console import format_node_name
from _pytask.console import format_task_name
from _pytask.console import render_to_string
from _pytask.dag_graph import DAG
from _pytask.dag_graph import NoCycleError
from _pytask.dag_graph import find_cycle
from _pytask.exceptions import ResolvingDependenciesError
from _pytask.mark import select_by_after_keyword
from _pytask.mark import select_tasks_by_marks_and_expressions
Expand All @@ -37,7 +39,7 @@
__all__ = ["create_dag", "create_dag_from_session"]


def create_dag(session: Session) -> nx.DiGraph:
def create_dag(session: Session) -> DAG:
"""Create a directed acyclic graph (DAG) for the workflow."""
try:
dag = create_dag_from_session(session)
Expand All @@ -50,7 +52,7 @@ def create_dag(session: Session) -> nx.DiGraph:
return dag


def create_dag_from_session(session: Session) -> nx.DiGraph:
def create_dag_from_session(session: Session) -> DAG:
"""Create a DAG from a session."""
dag = _create_dag_from_tasks(tasks=session.tasks)
_check_if_dag_has_cycles(dag)
Expand All @@ -60,14 +62,16 @@ def create_dag_from_session(session: Session) -> nx.DiGraph:
return dag


def _create_dag_from_tasks(tasks: list[PTask]) -> nx.DiGraph:
def _create_dag_from_tasks(tasks: list[PTask]) -> DAG:
"""Create the DAG from tasks, dependencies and products."""

def _add_dependency(
dag: nx.DiGraph, task: PTask, node: PNode | PProvisionalNode
) -> None:
def _add_node_data(dag: DAG, node: PNode | PProvisionalNode) -> None:
dag.add_node(node.signature, node)
if isinstance(node, PythonNode) and isinstance(node.value, PythonNode):
_add_node_data(dag, node.value)

def _add_dependency(dag: DAG, task: PTask, node: PNode | PProvisionalNode) -> None:
"""Add a dependency to the DAG."""
dag.add_node(node.signature, node=node)
dag.add_edge(node.signature, task.signature)

# If a node is a PythonNode wrapped in another PythonNode, it is a product from
Expand All @@ -76,36 +80,24 @@ def _add_dependency(
if isinstance(node, PythonNode) and isinstance(node.value, PythonNode):
dag.add_edge(node.value.signature, node.signature)

def _add_product(
dag: nx.DiGraph, task: PTask, node: PNode | PProvisionalNode
) -> None:
def _add_product(dag: DAG, task: PTask, node: PNode | PProvisionalNode) -> None:
"""Add a product to the DAG."""
dag.add_node(node.signature, node=node)
dag.add_edge(task.signature, node.signature)

dag = nx.DiGraph()
dag = DAG()

for task in tasks:
dag.add_node(task.signature, task=task)
dag.add_node(task.signature, task)
tree_map(lambda x: _add_node_data(dag, x), task.depends_on)
tree_map(lambda x: _add_node_data(dag, x), task.produces)

for task in tasks:
tree_map(lambda x: _add_dependency(dag, task, x), task.depends_on)
tree_map(lambda x: _add_product(dag, task, x), task.produces)

# If a node is a PythonNode wrapped in another PythonNode, it is a product from
# another task that is a dependency in the current task. Thus, draw an edge
# connecting the two nodes.
tree_map(
lambda x: (
dag.add_edge(x.value.signature, x.signature)
if isinstance(x, PythonNode) and isinstance(x.value, PythonNode)
else None
),
task.depends_on,
)
return dag


def _modify_dag(session: Session, dag: nx.DiGraph) -> nx.DiGraph:
def _modify_dag(session: Session, dag: DAG) -> DAG:
"""Create dependencies between tasks when using ``@task(after=...)``."""
temporary_id_to_task = {
task.attributes["collection_id"]: task
Expand All @@ -129,11 +121,11 @@ def _modify_dag(session: Session, dag: nx.DiGraph) -> nx.DiGraph:
return dag


def _check_if_dag_has_cycles(dag: nx.DiGraph) -> None:
def _check_if_dag_has_cycles(dag: DAG) -> None:
"""Check if DAG has cycles."""
try:
cycles = nx.algorithms.cycles.find_cycle(dag)
except nx.NetworkXNoCycle:
cycles = find_cycle(dag)
except NoCycleError:
pass
else:
msg = (
Expand All @@ -145,7 +137,7 @@ def _check_if_dag_has_cycles(dag: nx.DiGraph) -> None:
raise ResolvingDependenciesError(msg)


def _format_cycles(dag: nx.DiGraph, cycles: list[tuple[str, ...]]) -> str:
def _format_cycles(dag: DAG, cycles: list[tuple[str, str]]) -> str:
"""Format cycles as a paths connected by arrows."""
chain = [
x for i, x in enumerate(itertools.chain.from_iterable(cycles)) if i % 2 == 0
Expand All @@ -154,7 +146,7 @@ def _format_cycles(dag: nx.DiGraph, cycles: list[tuple[str, ...]]) -> str:

lines: list[str] = []
for x in chain:
node = dag.nodes[x].get("task") or dag.nodes[x].get("node")
node = dag.nodes[x]
if isinstance(node, PTask):
short_name = format_task_name(node, editor_url_scheme="no_link").plain
elif isinstance(node, (PNode, PProvisionalNode)):
Expand All @@ -176,24 +168,27 @@ def _format_dictionary_to_tree(dict_: dict[str, list[str]], title: str) -> str:
return render_to_string(tree, console=console, strip_styles=True)


def _check_if_tasks_have_the_same_products(dag: nx.DiGraph, paths: list[Path]) -> None:
def _check_if_tasks_have_the_same_products(dag: DAG, paths: list[Path]) -> None:
nodes_created_by_multiple_tasks = []

for node in dag.nodes:
is_node = "node" in dag.nodes[node]
if is_node:
if isinstance(dag.nodes[node], (PNode, PProvisionalNode)):
parents = list(dag.predecessors(node))
if len(parents) > 1:
nodes_created_by_multiple_tasks.append(node)

if nodes_created_by_multiple_tasks:
dictionary = {}
for node in nodes_created_by_multiple_tasks:
short_node_name = format_node_name(dag.nodes[node]["node"], paths).plain
payload = dag.nodes[node]
if not isinstance(payload, (PNode, PProvisionalNode)):
msg = f"Expected product node for signature {node!r}."
raise TypeError(msg)
short_node_name = format_node_name(payload, paths).plain
short_predecessors = reduce_names_of_multiple_nodes(
dag.predecessors(node), dag, paths
)
dictionary[short_node_name] = short_predecessors
dictionary[short_node_name] = sorted(short_predecessors)
text = _format_dictionary_to_tree(dictionary, "Products from multiple tasks:")
msg = (
f"There are some tasks which produce the same output. See the following "
Expand Down
43 changes: 24 additions & 19 deletions src/_pytask/dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import enum
import sys
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from typing import cast

import click
import networkx as nx
from rich.text import Text

from _pytask.click import ColoredCommand
Expand All @@ -30,6 +30,11 @@
from _pytask.shared import reduce_names_of_multiple_nodes
from _pytask.traceback import Traceback

if TYPE_CHECKING:
import networkx as nx

from _pytask.dag_graph import DAG


class _RankDirection(enum.Enum):
TB = "TB"
Expand Down Expand Up @@ -92,6 +97,7 @@ def dag(**raw_config: Any) -> int:
else:
try:
session.hook.pytask_log_session_header(session=session)
import_optional_dependency("networkx")
import_optional_dependency("pygraphviz")
check_for_optional_program(
session.config["layout"],
Expand All @@ -100,7 +106,7 @@ def dag(**raw_config: Any) -> int:
)
session.hook.pytask_collect(session=session)
session.dag = create_dag(session=session)
dag = _refine_dag(session)
dag = _to_visualization_graph(session)
_write_graph(dag, session.config["output_path"], session.config["layout"])

except CollectionError: # pragma: no cover
Expand Down Expand Up @@ -163,6 +169,7 @@ def build_dag(raw_config: dict[str, Any]) -> nx.DiGraph:

else:
session.hook.pytask_log_session_header(session=session)
import_optional_dependency("networkx")
import_optional_dependency("pygraphviz")
check_for_optional_program(
session.config["layout"],
Expand All @@ -172,44 +179,42 @@ def build_dag(raw_config: dict[str, Any]) -> nx.DiGraph:
session.hook.pytask_collect(session=session)
session.dag = create_dag(session=session)
session.hook.pytask_unconfigure(session=session)
return _refine_dag(session)
return _to_visualization_graph(session)


def _refine_dag(session: Session) -> nx.DiGraph:
def _refine_dag(session: Session) -> DAG:
"""Refine the dag for plotting."""
dag = _shorten_node_labels(session.dag, session.config["paths"])
dag = _clean_dag(dag)
dag = _style_dag(dag)
dag.graph["graph"] = {"rankdir": session.config["rank_direction"].name}
return _clean_dag(dag)


def _to_visualization_graph(session: Session) -> nx.DiGraph:
"""Convert the internal DAG to a styled networkx graph for visualization."""
nx = cast("Any", import_optional_dependency("networkx"))
dag = _refine_dag(session).to_networkx()
dag.graph["graph"] = {"rankdir": session.config["rank_direction"].name}
shapes = {name: "hexagon" if "::task_" in name else "box" for name in dag.nodes}
nx.set_node_attributes(dag, shapes, "shape")
return dag


def _shorten_node_labels(dag: nx.DiGraph, paths: list[Path]) -> nx.DiGraph:
def _shorten_node_labels(dag: DAG, paths: list[Path]) -> DAG:
"""Shorten the node labels in the graph for a better experience."""
node_names = dag.nodes
short_names = reduce_names_of_multiple_nodes(node_names, dag, paths)
short_names = [i.plain if isinstance(i, Text) else i for i in short_names]
old_to_new = dict(zip(node_names, short_names, strict=False))
return nx.relabel_nodes(dag, old_to_new)
return dag.relabel_nodes(old_to_new)


def _clean_dag(dag: nx.DiGraph) -> nx.DiGraph:
def _clean_dag(dag: DAG) -> DAG:
"""Clean the DAG."""
for node in dag.nodes:
dag.nodes[node].clear()
return dag


def _style_dag(dag: nx.DiGraph) -> nx.DiGraph:
"""Style the DAG."""
shapes = {name: "hexagon" if "::task_" in name else "box" for name in dag.nodes}
nx.set_node_attributes(dag, shapes, "shape")
return dag


def _write_graph(dag: nx.DiGraph, path: Path, layout: str) -> None:
"""Write the graph to disk."""
nx = cast("Any", import_optional_dependency("networkx"))
path.parent.mkdir(exist_ok=True, parents=True)
graph = nx.nx_agraph.to_agraph(dag)
graph.draw(path, prog=layout)
Expand Down
Loading
Loading