1717from _pytask .console import format_task_name
1818from _pytask .console import render_to_string
1919from _pytask .dag_graph import DAG
20- from _pytask .dag_graph import DAGNode
2120from _pytask .dag_graph import NoCycleError
2221from _pytask .dag_graph import find_cycle
2322from _pytask .exceptions import ResolvingDependenciesError
@@ -67,7 +66,7 @@ def _create_dag_from_tasks(tasks: list[PTask]) -> DAG:
6766 """Create the DAG from tasks, dependencies and products."""
6867
6968 def _add_node_data (dag : DAG , node : PNode | PProvisionalNode ) -> None :
70- dag .add_node (node .signature , DAGNode . from_node ( node ) )
69+ dag .add_node (node .signature , node )
7170 if isinstance (node , PythonNode ) and isinstance (node .value , PythonNode ):
7271 _add_node_data (dag , node .value )
7372
@@ -88,7 +87,7 @@ def _add_product(dag: DAG, task: PTask, node: PNode | PProvisionalNode) -> None:
8887 dag = DAG ()
8988
9089 for task in tasks :
91- dag .add_node (task .signature , DAGNode . from_task ( task ) )
90+ dag .add_node (task .signature , task )
9291 tree_map (lambda x : _add_node_data (dag , x ), task .depends_on )
9392 tree_map (lambda x : _add_node_data (dag , x ), task .produces )
9493
@@ -147,7 +146,7 @@ def _format_cycles(dag: DAG, cycles: list[tuple[str, str]]) -> str:
147146
148147 lines : list [str ] = []
149148 for x in chain :
150- node = dag .nodes [x ]. value
149+ node = dag .nodes [x ]
151150 if isinstance (node , PTask ):
152151 short_name = format_task_name (node , editor_url_scheme = "no_link" ).plain
153152 elif isinstance (node , (PNode , PProvisionalNode )):
@@ -173,17 +172,19 @@ def _check_if_tasks_have_the_same_products(dag: DAG, paths: list[Path]) -> None:
173172 nodes_created_by_multiple_tasks = []
174173
175174 for node in dag .nodes :
176- if dag .nodes [node ]. node is not None :
175+ if isinstance ( dag .nodes [node ], ( PNode , PProvisionalNode )) :
177176 parents = list (dag .predecessors (node ))
178177 if len (parents ) > 1 :
179178 nodes_created_by_multiple_tasks .append (node )
180179
181180 if nodes_created_by_multiple_tasks :
182181 dictionary = {}
183182 for node in nodes_created_by_multiple_tasks :
184- short_node_name = format_node_name (
185- dag .nodes [node ].node_or_raise (), paths
186- ).plain
183+ payload = dag .nodes [node ]
184+ if not isinstance (payload , (PNode , PProvisionalNode )):
185+ msg = f"Expected product node for signature { node !r} ."
186+ raise TypeError (msg )
187+ short_node_name = format_node_name (payload , paths ).plain
187188 short_predecessors = reduce_names_of_multiple_nodes (
188189 dag .predecessors (node ), dag , paths
189190 )
0 commit comments