1616from _pytask .console import format_node_name
1717from _pytask .console import format_task_name
1818from _pytask .console import render_to_string
19+ from _pytask .dag_graph import DagNode
1920from _pytask .dag_graph import DiGraph
2021from _pytask .dag_graph import NoCycleError
2122from _pytask .dag_graph import find_cycle
3940__all__ = ["create_dag" , "create_dag_from_session" ]
4041
4142
42- def create_dag (session : Session ) -> DiGraph :
43+ def create_dag (session : Session ) -> DiGraph [ str , DagNode ] :
4344 """Create a directed acyclic graph (DAG) for the workflow."""
4445 try :
4546 dag = create_dag_from_session (session )
@@ -52,7 +53,7 @@ def create_dag(session: Session) -> DiGraph:
5253 return dag
5354
5455
55- def create_dag_from_session (session : Session ) -> DiGraph :
56+ def create_dag_from_session (session : Session ) -> DiGraph [ str , DagNode ] :
5657 """Create a DAG from a session."""
5758 dag = _create_dag_from_tasks (tasks = session .tasks )
5859 _check_if_dag_has_cycles (dag )
@@ -62,14 +63,20 @@ def create_dag_from_session(session: Session) -> DiGraph:
6263 return dag
6364
6465
65- def _create_dag_from_tasks (tasks : list [PTask ]) -> DiGraph :
66+ def _create_dag_from_tasks (tasks : list [PTask ]) -> DiGraph [ str , DagNode ] :
6667 """Create the DAG from tasks, dependencies and products."""
6768
69+ def _add_node_data (
70+ dag : DiGraph [str , DagNode ], node : PNode | PProvisionalNode
71+ ) -> None :
72+ dag .add_node (node .signature , DagNode .from_node (node ))
73+ if isinstance (node , PythonNode ) and isinstance (node .value , PythonNode ):
74+ _add_node_data (dag , node .value )
75+
6876 def _add_dependency (
69- dag : DiGraph , task : PTask , node : PNode | PProvisionalNode
77+ dag : DiGraph [ str , DagNode ] , task : PTask , node : PNode | PProvisionalNode
7078 ) -> None :
7179 """Add a dependency to the DAG."""
72- dag .add_node (node .signature , node = node )
7380 dag .add_edge (node .signature , task .signature )
7481
7582 # If a node is a PythonNode wrapped in another PythonNode, it is a product from
@@ -78,34 +85,26 @@ def _add_dependency(
7885 if isinstance (node , PythonNode ) and isinstance (node .value , PythonNode ):
7986 dag .add_edge (node .value .signature , node .signature )
8087
81- def _add_product (dag : DiGraph , task : PTask , node : PNode | PProvisionalNode ) -> None :
88+ def _add_product (
89+ dag : DiGraph [str , DagNode ], task : PTask , node : PNode | PProvisionalNode
90+ ) -> None :
8291 """Add a product to the DAG."""
83- dag .add_node (node .signature , node = node )
8492 dag .add_edge (task .signature , node .signature )
8593
86- dag = DiGraph ()
94+ dag = DiGraph [ str , DagNode ] ()
8795
8896 for task in tasks :
89- dag .add_node (task .signature , task = task )
97+ dag .add_node (task .signature , DagNode .from_task (task ))
98+ tree_map (lambda x : _add_node_data (dag , x ), task .depends_on )
99+ tree_map (lambda x : _add_node_data (dag , x ), task .produces )
90100
101+ for task in tasks :
91102 tree_map (lambda x : _add_dependency (dag , task , x ), task .depends_on )
92103 tree_map (lambda x : _add_product (dag , task , x ), task .produces )
93-
94- # If a node is a PythonNode wrapped in another PythonNode, it is a product from
95- # another task that is a dependency in the current task. Thus, draw an edge
96- # connecting the two nodes.
97- tree_map (
98- lambda x : (
99- dag .add_edge (x .value .signature , x .signature )
100- if isinstance (x , PythonNode ) and isinstance (x .value , PythonNode )
101- else None
102- ),
103- task .depends_on ,
104- )
105104 return dag
106105
107106
108- def _modify_dag (session : Session , dag : DiGraph ) -> DiGraph :
107+ def _modify_dag (session : Session , dag : DiGraph [ str , DagNode ] ) -> DiGraph [ str , DagNode ] :
109108 """Create dependencies between tasks when using ``@task(after=...)``."""
110109 temporary_id_to_task = {
111110 task .attributes ["collection_id" ]: task
@@ -129,7 +128,7 @@ def _modify_dag(session: Session, dag: DiGraph) -> DiGraph:
129128 return dag
130129
131130
132- def _check_if_dag_has_cycles (dag : DiGraph ) -> None :
131+ def _check_if_dag_has_cycles (dag : DiGraph [ str , DagNode ] ) -> None :
133132 """Check if DAG has cycles."""
134133 try :
135134 cycles = find_cycle (dag )
@@ -145,7 +144,7 @@ def _check_if_dag_has_cycles(dag: DiGraph) -> None:
145144 raise ResolvingDependenciesError (msg )
146145
147146
148- def _format_cycles (dag : DiGraph , cycles : list [tuple [str , str ]]) -> str :
147+ def _format_cycles (dag : DiGraph [ str , DagNode ] , cycles : list [tuple [str , str ]]) -> str :
149148 """Format cycles as a paths connected by arrows."""
150149 chain = [
151150 x for i , x in enumerate (itertools .chain .from_iterable (cycles )) if i % 2 == 0
@@ -154,7 +153,7 @@ def _format_cycles(dag: DiGraph, cycles: list[tuple[str, str]]) -> str:
154153
155154 lines : list [str ] = []
156155 for x in chain :
157- node = dag .nodes [x ].get ( "task" ) or dag . nodes [ x ]. get ( "node" )
156+ node = dag .nodes [x ].value
158157 if isinstance (node , PTask ):
159158 short_name = format_task_name (node , editor_url_scheme = "no_link" ).plain
160159 elif isinstance (node , (PNode , PProvisionalNode )):
@@ -176,24 +175,27 @@ def _format_dictionary_to_tree(dict_: dict[str, list[str]], title: str) -> str:
176175 return render_to_string (tree , console = console , strip_styles = True )
177176
178177
179- def _check_if_tasks_have_the_same_products (dag : DiGraph , paths : list [Path ]) -> None :
178+ def _check_if_tasks_have_the_same_products (
179+ dag : DiGraph [str , DagNode ], paths : list [Path ]
180+ ) -> None :
180181 nodes_created_by_multiple_tasks = []
181182
182183 for node in dag .nodes :
183- is_node = "node" in dag .nodes [node ]
184- if is_node :
184+ if dag .nodes [node ].node is not None :
185185 parents = list (dag .predecessors (node ))
186186 if len (parents ) > 1 :
187187 nodes_created_by_multiple_tasks .append (node )
188188
189189 if nodes_created_by_multiple_tasks :
190190 dictionary = {}
191191 for node in nodes_created_by_multiple_tasks :
192- short_node_name = format_node_name (dag .nodes [node ]["node" ], paths ).plain
192+ short_node_name = format_node_name (
193+ dag .nodes [node ].node_or_raise (), paths
194+ ).plain
193195 short_predecessors = reduce_names_of_multiple_nodes (
194196 dag .predecessors (node ), dag , paths
195197 )
196- dictionary [short_node_name ] = short_predecessors
198+ dictionary [short_node_name ] = sorted ( short_predecessors )
197199 text = _format_dictionary_to_tree (dictionary , "Products from multiple tasks:" )
198200 msg = (
199201 f"There are some tasks which produce the same output. See the following "
0 commit comments