66import sys
77from typing import TYPE_CHECKING
88
9- import networkx as nx
109from rich .text import Text
1110from rich .tree import Tree
1211
1716from _pytask .console import format_node_name
1817from _pytask .console import format_task_name
1918from _pytask .console import render_to_string
19+ from _pytask .dag_graph import DAG
20+ from _pytask .dag_graph import NoCycleError
21+ from _pytask .dag_graph import find_cycle
2022from _pytask .exceptions import ResolvingDependenciesError
2123from _pytask .mark import select_by_after_keyword
2224from _pytask .mark import select_tasks_by_marks_and_expressions
3739__all__ = ["create_dag" , "create_dag_from_session" ]
3840
3941
40- def create_dag (session : Session ) -> nx . DiGraph :
42+ def create_dag (session : Session ) -> DAG :
4143 """Create a directed acyclic graph (DAG) for the workflow."""
4244 try :
4345 dag = create_dag_from_session (session )
@@ -50,7 +52,7 @@ def create_dag(session: Session) -> nx.DiGraph:
5052 return dag
5153
5254
53- def create_dag_from_session (session : Session ) -> nx . DiGraph :
55+ def create_dag_from_session (session : Session ) -> DAG :
5456 """Create a DAG from a session."""
5557 dag = _create_dag_from_tasks (tasks = session .tasks )
5658 _check_if_dag_has_cycles (dag )
@@ -60,14 +62,16 @@ def create_dag_from_session(session: Session) -> nx.DiGraph:
6062 return dag
6163
6264
63- def _create_dag_from_tasks (tasks : list [PTask ]) -> nx . DiGraph :
65+ def _create_dag_from_tasks (tasks : list [PTask ]) -> DAG :
6466 """Create the DAG from tasks, dependencies and products."""
6567
66- def _add_dependency (
67- dag : nx .DiGraph , task : PTask , node : PNode | PProvisionalNode
68- ) -> None :
68+ def _add_node_data (dag : DAG , node : PNode | PProvisionalNode ) -> None :
69+ dag .add_node (node .signature , node )
70+ if isinstance (node , PythonNode ) and isinstance (node .value , PythonNode ):
71+ _add_node_data (dag , node .value )
72+
73+ def _add_dependency (dag : DAG , task : PTask , node : PNode | PProvisionalNode ) -> None :
6974 """Add a dependency to the DAG."""
70- dag .add_node (node .signature , node = node )
7175 dag .add_edge (node .signature , task .signature )
7276
7377 # If a node is a PythonNode wrapped in another PythonNode, it is a product from
@@ -76,36 +80,24 @@ def _add_dependency(
7680 if isinstance (node , PythonNode ) and isinstance (node .value , PythonNode ):
7781 dag .add_edge (node .value .signature , node .signature )
7882
79- def _add_product (
80- dag : nx .DiGraph , task : PTask , node : PNode | PProvisionalNode
81- ) -> None :
83+ def _add_product (dag : DAG , task : PTask , node : PNode | PProvisionalNode ) -> None :
8284 """Add a product to the DAG."""
83- dag .add_node (node .signature , node = node )
8485 dag .add_edge (task .signature , node .signature )
8586
86- dag = nx . DiGraph ()
87+ dag = DAG ()
8788
8889 for task in tasks :
89- dag .add_node (task .signature , task = task )
90+ dag .add_node (task .signature , task )
91+ tree_map (lambda x : _add_node_data (dag , x ), task .depends_on )
92+ tree_map (lambda x : _add_node_data (dag , x ), task .produces )
9093
94+ for task in tasks :
9195 tree_map (lambda x : _add_dependency (dag , task , x ), task .depends_on )
9296 tree_map (lambda x : _add_product (dag , task , x ), task .produces )
93-
94- # If a node is a PythonNode wrapped in another PythonNode, it is a product from
95- # another task that is a dependency in the current task. Thus, draw an edge
96- # connecting the two nodes.
97- tree_map (
98- lambda x : (
99- dag .add_edge (x .value .signature , x .signature )
100- if isinstance (x , PythonNode ) and isinstance (x .value , PythonNode )
101- else None
102- ),
103- task .depends_on ,
104- )
10597 return dag
10698
10799
108- def _modify_dag (session : Session , dag : nx . DiGraph ) -> nx . DiGraph :
100+ def _modify_dag (session : Session , dag : DAG ) -> DAG :
109101 """Create dependencies between tasks when using ``@task(after=...)``."""
110102 temporary_id_to_task = {
111103 task .attributes ["collection_id" ]: task
@@ -129,11 +121,11 @@ def _modify_dag(session: Session, dag: nx.DiGraph) -> nx.DiGraph:
129121 return dag
130122
131123
132- def _check_if_dag_has_cycles (dag : nx . DiGraph ) -> None :
124+ def _check_if_dag_has_cycles (dag : DAG ) -> None :
133125 """Check if DAG has cycles."""
134126 try :
135- cycles = nx . algorithms . cycles . find_cycle (dag )
136- except nx . NetworkXNoCycle :
127+ cycles = find_cycle (dag )
128+ except NoCycleError :
137129 pass
138130 else :
139131 msg = (
@@ -145,7 +137,7 @@ def _check_if_dag_has_cycles(dag: nx.DiGraph) -> None:
145137 raise ResolvingDependenciesError (msg )
146138
147139
148- def _format_cycles (dag : nx . DiGraph , cycles : list [tuple [str , ... ]]) -> str :
140+ def _format_cycles (dag : DAG , cycles : list [tuple [str , str ]]) -> str :
149141 """Format cycles as a paths connected by arrows."""
150142 chain = [
151143 x for i , x in enumerate (itertools .chain .from_iterable (cycles )) if i % 2 == 0
@@ -154,7 +146,7 @@ def _format_cycles(dag: nx.DiGraph, cycles: list[tuple[str, ...]]) -> str:
154146
155147 lines : list [str ] = []
156148 for x in chain :
157- node = dag .nodes [x ]. get ( "task" ) or dag . nodes [ x ]. get ( "node" )
149+ node = dag .nodes [x ]
158150 if isinstance (node , PTask ):
159151 short_name = format_task_name (node , editor_url_scheme = "no_link" ).plain
160152 elif isinstance (node , (PNode , PProvisionalNode )):
@@ -176,24 +168,27 @@ def _format_dictionary_to_tree(dict_: dict[str, list[str]], title: str) -> str:
176168 return render_to_string (tree , console = console , strip_styles = True )
177169
178170
179- def _check_if_tasks_have_the_same_products (dag : nx . DiGraph , paths : list [Path ]) -> None :
171+ def _check_if_tasks_have_the_same_products (dag : DAG , paths : list [Path ]) -> None :
180172 nodes_created_by_multiple_tasks = []
181173
182174 for node in dag .nodes :
183- is_node = "node" in dag .nodes [node ]
184- if is_node :
175+ if isinstance (dag .nodes [node ], (PNode , PProvisionalNode )):
185176 parents = list (dag .predecessors (node ))
186177 if len (parents ) > 1 :
187178 nodes_created_by_multiple_tasks .append (node )
188179
189180 if nodes_created_by_multiple_tasks :
190181 dictionary = {}
191182 for node in nodes_created_by_multiple_tasks :
192- short_node_name = format_node_name (dag .nodes [node ]["node" ], paths ).plain
183+ payload = dag .nodes [node ]
184+ if not isinstance (payload , (PNode , PProvisionalNode )):
185+ msg = f"Expected product node for signature { node !r} ."
186+ raise TypeError (msg )
187+ short_node_name = format_node_name (payload , paths ).plain
193188 short_predecessors = reduce_names_of_multiple_nodes (
194189 dag .predecessors (node ), dag , paths
195190 )
196- dictionary [short_node_name ] = short_predecessors
191+ dictionary [short_node_name ] = sorted ( short_predecessors )
197192 text = _format_dictionary_to_tree (dictionary , "Products from multiple tasks:" )
198193 msg = (
199194 f"There are some tasks which produce the same output. See the following "
0 commit comments