Skip to content

Commit 7c7fef8

Browse files
committed
Make internal DAG generic and typed
1 parent 6028419 commit 7c7fef8

14 files changed

Lines changed: 229 additions & 155 deletions

src/_pytask/dag.py

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from _pytask.console import format_node_name
1717
from _pytask.console import format_task_name
1818
from _pytask.console import render_to_string
19+
from _pytask.dag_graph import DagNode
1920
from _pytask.dag_graph import DiGraph
2021
from _pytask.dag_graph import NoCycleError
2122
from _pytask.dag_graph import find_cycle
@@ -39,7 +40,7 @@
3940
__all__ = ["create_dag", "create_dag_from_session"]
4041

4142

42-
def create_dag(session: Session) -> DiGraph:
43+
def create_dag(session: Session) -> DiGraph[str, DagNode]:
4344
"""Create a directed acyclic graph (DAG) for the workflow."""
4445
try:
4546
dag = create_dag_from_session(session)
@@ -52,7 +53,7 @@ def create_dag(session: Session) -> DiGraph:
5253
return dag
5354

5455

55-
def create_dag_from_session(session: Session) -> DiGraph:
56+
def create_dag_from_session(session: Session) -> DiGraph[str, DagNode]:
5657
"""Create a DAG from a session."""
5758
dag = _create_dag_from_tasks(tasks=session.tasks)
5859
_check_if_dag_has_cycles(dag)
@@ -62,14 +63,20 @@ def create_dag_from_session(session: Session) -> DiGraph:
6263
return dag
6364

6465

65-
def _create_dag_from_tasks(tasks: list[PTask]) -> DiGraph:
66+
def _create_dag_from_tasks(tasks: list[PTask]) -> DiGraph[str, DagNode]:
6667
"""Create the DAG from tasks, dependencies and products."""
6768

69+
def _add_node_data(
70+
dag: DiGraph[str, DagNode], node: PNode | PProvisionalNode
71+
) -> None:
72+
dag.add_node(node.signature, DagNode.from_node(node))
73+
if isinstance(node, PythonNode) and isinstance(node.value, PythonNode):
74+
_add_node_data(dag, node.value)
75+
6876
def _add_dependency(
69-
dag: DiGraph, task: PTask, node: PNode | PProvisionalNode
77+
dag: DiGraph[str, DagNode], task: PTask, node: PNode | PProvisionalNode
7078
) -> None:
7179
"""Add a dependency to the DAG."""
72-
dag.add_node(node.signature, node=node)
7380
dag.add_edge(node.signature, task.signature)
7481

7582
# If a node is a PythonNode wrapped in another PythonNode, it is a product from
@@ -78,34 +85,26 @@ def _add_dependency(
7885
if isinstance(node, PythonNode) and isinstance(node.value, PythonNode):
7986
dag.add_edge(node.value.signature, node.signature)
8087

81-
def _add_product(dag: DiGraph, task: PTask, node: PNode | PProvisionalNode) -> None:
88+
def _add_product(
89+
dag: DiGraph[str, DagNode], task: PTask, node: PNode | PProvisionalNode
90+
) -> None:
8291
"""Add a product to the DAG."""
83-
dag.add_node(node.signature, node=node)
8492
dag.add_edge(task.signature, node.signature)
8593

86-
dag = DiGraph()
94+
dag = DiGraph[str, DagNode]()
8795

8896
for task in tasks:
89-
dag.add_node(task.signature, task=task)
97+
dag.add_node(task.signature, DagNode.from_task(task))
98+
tree_map(lambda x: _add_node_data(dag, x), task.depends_on)
99+
tree_map(lambda x: _add_node_data(dag, x), task.produces)
90100

101+
for task in tasks:
91102
tree_map(lambda x: _add_dependency(dag, task, x), task.depends_on)
92103
tree_map(lambda x: _add_product(dag, task, x), task.produces)
93-
94-
# If a node is a PythonNode wrapped in another PythonNode, it is a product from
95-
# another task that is a dependency in the current task. Thus, draw an edge
96-
# connecting the two nodes.
97-
tree_map(
98-
lambda x: (
99-
dag.add_edge(x.value.signature, x.signature)
100-
if isinstance(x, PythonNode) and isinstance(x.value, PythonNode)
101-
else None
102-
),
103-
task.depends_on,
104-
)
105104
return dag
106105

107106

108-
def _modify_dag(session: Session, dag: DiGraph) -> DiGraph:
107+
def _modify_dag(session: Session, dag: DiGraph[str, DagNode]) -> DiGraph[str, DagNode]:
109108
"""Create dependencies between tasks when using ``@task(after=...)``."""
110109
temporary_id_to_task = {
111110
task.attributes["collection_id"]: task
@@ -129,7 +128,7 @@ def _modify_dag(session: Session, dag: DiGraph) -> DiGraph:
129128
return dag
130129

131130

132-
def _check_if_dag_has_cycles(dag: DiGraph) -> None:
131+
def _check_if_dag_has_cycles(dag: DiGraph[str, DagNode]) -> None:
133132
"""Check if DAG has cycles."""
134133
try:
135134
cycles = find_cycle(dag)
@@ -145,7 +144,7 @@ def _check_if_dag_has_cycles(dag: DiGraph) -> None:
145144
raise ResolvingDependenciesError(msg)
146145

147146

148-
def _format_cycles(dag: DiGraph, cycles: list[tuple[str, str]]) -> str:
147+
def _format_cycles(dag: DiGraph[str, DagNode], cycles: list[tuple[str, str]]) -> str:
149148
"""Format cycles as paths connected by arrows."""
150149
chain = [
151150
x for i, x in enumerate(itertools.chain.from_iterable(cycles)) if i % 2 == 0
@@ -154,7 +153,7 @@ def _format_cycles(dag: DiGraph, cycles: list[tuple[str, str]]) -> str:
154153

155154
lines: list[str] = []
156155
for x in chain:
157-
node = dag.nodes[x].get("task") or dag.nodes[x].get("node")
156+
node = dag.nodes[x].value
158157
if isinstance(node, PTask):
159158
short_name = format_task_name(node, editor_url_scheme="no_link").plain
160159
elif isinstance(node, (PNode, PProvisionalNode)):
@@ -176,24 +175,27 @@ def _format_dictionary_to_tree(dict_: dict[str, list[str]], title: str) -> str:
176175
return render_to_string(tree, console=console, strip_styles=True)
177176

178177

179-
def _check_if_tasks_have_the_same_products(dag: DiGraph, paths: list[Path]) -> None:
178+
def _check_if_tasks_have_the_same_products(
179+
dag: DiGraph[str, DagNode], paths: list[Path]
180+
) -> None:
180181
nodes_created_by_multiple_tasks = []
181182

182183
for node in dag.nodes:
183-
is_node = "node" in dag.nodes[node]
184-
if is_node:
184+
if dag.nodes[node].node is not None:
185185
parents = list(dag.predecessors(node))
186186
if len(parents) > 1:
187187
nodes_created_by_multiple_tasks.append(node)
188188

189189
if nodes_created_by_multiple_tasks:
190190
dictionary = {}
191191
for node in nodes_created_by_multiple_tasks:
192-
short_node_name = format_node_name(dag.nodes[node]["node"], paths).plain
192+
short_node_name = format_node_name(
193+
dag.nodes[node].node_or_raise(), paths
194+
).plain
193195
short_predecessors = reduce_names_of_multiple_nodes(
194196
dag.predecessors(node), dag, paths
195197
)
196-
dictionary[short_node_name] = short_predecessors
198+
dictionary[short_node_name] = sorted(short_predecessors)
197199
text = _format_dictionary_to_tree(dictionary, "Products from multiple tasks:")
198200
msg = (
199201
f"There are some tasks which produce the same output. See the following "

src/_pytask/dag_command.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
if TYPE_CHECKING:
3434
import networkx as nx
3535

36+
from _pytask.dag_graph import DagNode
3637
from _pytask.dag_graph import DiGraph
3738

3839

@@ -182,7 +183,7 @@ def build_dag(raw_config: dict[str, Any]) -> nx.DiGraph:
182183
return _to_visualization_graph(session)
183184

184185

185-
def _refine_dag(session: Session) -> DiGraph:
186+
def _refine_dag(session: Session) -> DiGraph[str, DagNode]:
186187
"""Refine the dag for plotting."""
187188
dag = _shorten_node_labels(session.dag, session.config["paths"])
188189
return _clean_dag(dag)
@@ -198,7 +199,9 @@ def _to_visualization_graph(session: Session) -> nx.DiGraph:
198199
return dag
199200

200201

201-
def _shorten_node_labels(dag: DiGraph, paths: list[Path]) -> DiGraph:
202+
def _shorten_node_labels(
203+
dag: DiGraph[str, DagNode], paths: list[Path]
204+
) -> DiGraph[str, DagNode]:
202205
"""Shorten the node labels in the graph for a better experience."""
203206
node_names = dag.nodes
204207
short_names = reduce_names_of_multiple_nodes(node_names, dag, paths)
@@ -207,10 +210,8 @@ def _shorten_node_labels(dag: DiGraph, paths: list[Path]) -> DiGraph:
207210
return dag.relabel_nodes(old_to_new)
208211

209212

210-
def _clean_dag(dag: DiGraph) -> DiGraph:
213+
def _clean_dag(dag: DiGraph[str, DagNode]) -> DiGraph[str, DagNode]:
211214
"""Clean the DAG."""
212-
for node in dag.nodes:
213-
dag.nodes[node].clear()
214215
return dag
215216

216217

0 commit comments

Comments
 (0)