Skip to content

Commit fb103b6

Browse files
committed
Store DAG payloads directly
1 parent 1695d96 commit fb103b6

11 files changed

Lines changed: 67 additions & 100 deletions

File tree

src/_pytask/dag.py

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,6 @@
1717
from _pytask.console import format_task_name
1818
from _pytask.console import render_to_string
1919
from _pytask.dag_graph import DAG
20-
from _pytask.dag_graph import DAGNode
2120
from _pytask.dag_graph import NoCycleError
2221
from _pytask.dag_graph import find_cycle
2322
from _pytask.exceptions import ResolvingDependenciesError
@@ -67,7 +66,7 @@ def _create_dag_from_tasks(tasks: list[PTask]) -> DAG:
6766
"""Create the DAG from tasks, dependencies and products."""
6867

6968
def _add_node_data(dag: DAG, node: PNode | PProvisionalNode) -> None:
70-
dag.add_node(node.signature, DAGNode.from_node(node))
69+
dag.add_node(node.signature, node)
7170
if isinstance(node, PythonNode) and isinstance(node.value, PythonNode):
7271
_add_node_data(dag, node.value)
7372

@@ -88,7 +87,7 @@ def _add_product(dag: DAG, task: PTask, node: PNode | PProvisionalNode) -> None:
8887
dag = DAG()
8988

9089
for task in tasks:
91-
dag.add_node(task.signature, DAGNode.from_task(task))
90+
dag.add_node(task.signature, task)
9291
tree_map(lambda x: _add_node_data(dag, x), task.depends_on)
9392
tree_map(lambda x: _add_node_data(dag, x), task.produces)
9493

@@ -147,7 +146,7 @@ def _format_cycles(dag: DAG, cycles: list[tuple[str, str]]) -> str:
147146

148147
lines: list[str] = []
149148
for x in chain:
150-
node = dag.nodes[x].value
149+
node = dag.nodes[x]
151150
if isinstance(node, PTask):
152151
short_name = format_task_name(node, editor_url_scheme="no_link").plain
153152
elif isinstance(node, (PNode, PProvisionalNode)):
@@ -173,17 +172,19 @@ def _check_if_tasks_have_the_same_products(dag: DAG, paths: list[Path]) -> None:
173172
nodes_created_by_multiple_tasks = []
174173

175174
for node in dag.nodes:
176-
if dag.nodes[node].node is not None:
175+
if isinstance(dag.nodes[node], (PNode, PProvisionalNode)):
177176
parents = list(dag.predecessors(node))
178177
if len(parents) > 1:
179178
nodes_created_by_multiple_tasks.append(node)
180179

181180
if nodes_created_by_multiple_tasks:
182181
dictionary = {}
183182
for node in nodes_created_by_multiple_tasks:
184-
short_node_name = format_node_name(
185-
dag.nodes[node].node_or_raise(), paths
186-
).plain
183+
payload = dag.nodes[node]
184+
if not isinstance(payload, (PNode, PProvisionalNode)):
185+
msg = f"Expected product node for signature {node!r}."
186+
raise TypeError(msg)
187+
short_node_name = format_node_name(payload, paths).plain
187188
short_predecessors = reduce_names_of_multiple_nodes(
188189
dag.predecessors(node), dag, paths
189190
)

src/_pytask/dag_graph.py

Lines changed: 8 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -10,79 +10,37 @@
1010
from typing import cast
1111

1212
from _pytask.compat import import_optional_dependency
13+
from _pytask.node_protocols import PNode
14+
from _pytask.node_protocols import PProvisionalNode
15+
from _pytask.node_protocols import PTask
1316

1417
if TYPE_CHECKING:
1518
from collections.abc import Callable
1619
from collections.abc import Iterable
1720
from collections.abc import Iterator
1821
from collections.abc import Mapping
1922

20-
from _pytask.node_protocols import PNode
21-
from _pytask.node_protocols import PProvisionalNode
22-
from _pytask.node_protocols import PTask
23+
24+
DAGEntry = PTask | PNode | PProvisionalNode
2325

2426

2527
class NoCycleError(Exception):
2628
"""Raised when no cycle is found in a graph."""
2729

2830

29-
@dataclass(slots=True)
30-
class DAGNode:
31-
"""Payload stored for nodes in pytask's internal DAG."""
32-
33-
task: PTask | None = None
34-
node: PNode | PProvisionalNode | None = None
35-
36-
def __post_init__(self) -> None:
37-
if (self.task is None) == (self.node is None):
38-
msg = "A DAG node must store exactly one of 'task' or 'node'."
39-
raise ValueError(msg)
40-
41-
@classmethod
42-
def from_task(cls, task: PTask) -> DAGNode:
43-
"""Create a DAG node from a task."""
44-
return cls(task=task)
45-
46-
@classmethod
47-
def from_node(cls, node: PNode | PProvisionalNode) -> DAGNode:
48-
"""Create a DAG node from a dependency or product node."""
49-
return cls(node=node)
50-
51-
@property
52-
def value(self) -> PTask | PNode | PProvisionalNode:
53-
"""Return the wrapped task or node."""
54-
if self.task is not None:
55-
return self.task
56-
return cast("PNode | PProvisionalNode", self.node)
57-
58-
def task_or_raise(self) -> PTask:
59-
"""Return the wrapped task."""
60-
if self.task is None:
61-
msg = "Expected DAG payload to contain a task."
62-
raise TypeError(msg)
63-
return self.task
64-
65-
def node_or_raise(self) -> PNode | PProvisionalNode:
66-
"""Return the wrapped dependency or product node."""
67-
if self.node is None:
68-
msg = "Expected DAG payload to contain a node."
69-
raise TypeError(msg)
70-
return self.node
71-
72-
7331
@dataclass
7432
class DAG:
7533
"""A minimal directed graph tailored to pytask's needs."""
7634

77-
_node_data: dict[str, DAGNode] = field(default_factory=dict)
35+
_node_data: dict[str, DAGEntry] = field(default_factory=dict)
7836
_successors: dict[str, set[str]] = field(default_factory=dict)
7937
_predecessors: dict[str, set[str]] = field(default_factory=dict)
8038

8139
@property
82-
def nodes(self) -> dict[str, DAGNode]:
40+
def nodes(self) -> dict[str, DAGEntry]:
8341
return self._node_data
8442

85-
def add_node(self, node_name: str, data: DAGNode) -> None:
43+
def add_node(self, node_name: str, data: DAGEntry) -> None:
8644
if node_name not in self._node_data:
8745
self._successors[node_name] = set()
8846
self._predecessors[node_name] = set()

src/_pytask/dag_utils.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,17 @@
1313
from _pytask.dag_graph import descendants
1414
from _pytask.dag_graph import find_cycle
1515
from _pytask.mark_utils import has_mark
16+
from _pytask.node_protocols import PTask
1617

1718
if TYPE_CHECKING:
1819
from collections.abc import Generator
1920
from collections.abc import Iterable
2021

21-
from _pytask.node_protocols import PTask
22-
2322

2423
def descending_tasks(task_name: str, dag: DAG) -> Generator[str, None, None]:
2524
"""Yield only descending tasks."""
2625
for descendant in descendants(dag, task_name):
27-
if dag.nodes[descendant].task is not None:
26+
if isinstance(dag.nodes[descendant], PTask):
2827
yield descendant
2928

3029

@@ -37,7 +36,7 @@ def task_and_descending_tasks(task_name: str, dag: DAG) -> Generator[str, None,
3736
def preceding_tasks(task_name: str, dag: DAG) -> Generator[str, None, None]:
3837
"""Yield only preceding tasks."""
3938
for ancestor in ancestors(dag, task_name):
40-
if dag.nodes[ancestor].task is not None:
39+
if isinstance(dag.nodes[ancestor], PTask):
4140
yield ancestor
4241

4342

@@ -86,7 +85,7 @@ def from_dag(cls, dag: DAG) -> TopologicalSorter:
8685
"""Instantiate from a DAG."""
8786
cls.check_dag(dag)
8887

89-
tasks = [node.task for node in dag.nodes.values() if node.task is not None]
88+
tasks = [node for node in dag.nodes.values() if isinstance(node, PTask)]
9089
priorities = _extract_priorities_from_tasks(tasks)
9190

9291
task_signatures = {task.signature for task in tasks}

src/_pytask/database_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def update_states_in_database(session: Session, task_signature: str) -> None:
116116
if _ENGINE is None:
117117
return
118118
for name in node_and_neighbors(session.dag, task_signature):
119-
node = session.dag.nodes[name].value
119+
node = session.dag.nodes[name]
120120
if isinstance(node, PProvisionalNode):
121121
continue
122122
hash_ = node.state()

src/_pytask/execute.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,10 @@ def pytask_execute_build(session: Session) -> bool | None:
8989
if isinstance(session.scheduler, TopologicalSorter):
9090
while session.scheduler.is_active():
9191
task_name = session.scheduler.get_ready()[0]
92-
task = session.dag.nodes[task_name].task_or_raise()
92+
task = session.dag.nodes[task_name]
93+
if not isinstance(task, PTask):
94+
msg = f"Expected task node for signature {task_name!r}."
95+
raise TypeError(msg)
9396
report = session.hook.pytask_execute_task_protocol(
9497
session=session, task=task
9598
)
@@ -172,7 +175,7 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C
172175
if not needs_to_be_executed:
173176
predecessors = set(dag.predecessors(task.signature)) | {task.signature}
174177
for node_signature in node_and_neighbors(dag, task.signature):
175-
node = dag.nodes[node_signature].value
178+
node = dag.nodes[node_signature]
176179

177180
# Skip provisional nodes that are products since they do not have a state.
178181
if node_signature not in predecessors and isinstance(
@@ -253,7 +256,10 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None: # noqa: C
253256
# Create directory for product if it does not exist. Maybe this should be a `setup`
254257
# method for the node classes.
255258
for product in dag.successors(task.signature):
256-
node = dag.nodes[product].node_or_raise()
259+
node = dag.nodes[product]
260+
if not isinstance(node, (PNode, PProvisionalNode)):
261+
msg = f"Expected product node for signature {product!r}."
262+
raise TypeError(msg)
257263
if isinstance(node, PPathNode):
258264
node.path.parent.mkdir(parents=True, exist_ok=True)
259265
if isinstance(node, DirectoryNode) and node.root_dir:
@@ -347,7 +353,12 @@ def pytask_execute_task_process_report(
347353
report.outcome = TaskOutcome.WOULD_BE_EXECUTED
348354

349355
for descending_task_name in descending_tasks(task.signature, session.dag):
350-
descending_task = session.dag.nodes[descending_task_name].task_or_raise()
356+
descending_task = session.dag.nodes[descending_task_name]
357+
if not isinstance(descending_task, PTask):
358+
msg = (
359+
f"Expected descending task for signature {descending_task_name!r}."
360+
)
361+
raise TypeError(msg)
351362
descending_task.markers.append(
352363
Mark(
353364
"would_be_executed",
@@ -360,7 +371,12 @@ def pytask_execute_task_process_report(
360371
)
361372
else:
362373
for descending_task_name in descending_tasks(task.signature, session.dag):
363-
descending_task = session.dag.nodes[descending_task_name].task_or_raise()
374+
descending_task = session.dag.nodes[descending_task_name]
375+
if not isinstance(descending_task, PTask):
376+
msg = (
377+
f"Expected descending task for signature {descending_task_name!r}."
378+
)
379+
raise TypeError(msg)
364380
descending_task.markers.append(
365381
Mark(
366382
"skip_ancestor_failed",

src/_pytask/lockfile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ def _build_task_entry(session: Session, task: PTask, root: Path) -> _TaskEntry |
225225

226226
depends_on: dict[str, str] = {}
227227
for node_signature in predecessors:
228-
node = dag.nodes[node_signature].value
228+
node = dag.nodes[node_signature]
229229
if not isinstance(node, (PNode, PTask)):
230230
continue
231231
state = node.state()
@@ -240,7 +240,7 @@ def _build_task_entry(session: Session, task: PTask, root: Path) -> _TaskEntry |
240240

241241
produces: dict[str, str] = {}
242242
for node_signature in successors:
243-
node = dag.nodes[node_signature].value
243+
node = dag.nodes[node_signature]
244244
if not isinstance(node, (PNode, PTask)):
245245
continue
246246
state = node.state()

src/_pytask/persist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def pytask_execute_task_setup(session: Session, task: PTask) -> None:
5151
if has_mark(task, "persist"):
5252
stateful_nodes: list[tuple[PTask | PNode, str | None]] = []
5353
for name in node_and_neighbors(session.dag, task.signature):
54-
node = session.dag.nodes[name].value
54+
node = session.dag.nodes[name]
5555
if isinstance(node, PProvisionalNode):
5656
continue
5757
stateful_nodes.append((node, node.state()))

src/_pytask/profile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def pytask_profile_add_info_on_task(
210210
if successors:
211211
sum_bytes = 0
212212
for successor in successors:
213-
node = session.dag.nodes[successor].node_or_raise()
213+
node = session.dag.nodes[successor]
214214
if isinstance(node, PPathNode):
215215
with suppress(FileNotFoundError):
216216
sum_bytes += node.path.stat().st_size

src/_pytask/shared.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def reduce_names_of_multiple_nodes(
8484
"""Reduce the names of multiple nodes in the DAG."""
8585
short_names = []
8686
for name in names:
87-
node = dag.nodes[name].value
87+
node = dag.nodes[name]
8888

8989
if isinstance(node, PTask):
9090
short_name = format_task_name(node, editor_url_scheme="no_link").plain

src/_pytask/skipping.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from _pytask.mark import Mark
1010
from _pytask.mark_utils import get_marks
1111
from _pytask.mark_utils import has_mark
12+
from _pytask.node_protocols import PTask
1213
from _pytask.outcomes import Skipped
1314
from _pytask.outcomes import SkippedAncestorFailed
1415
from _pytask.outcomes import SkippedUnchanged
@@ -17,7 +18,6 @@
1718
from _pytask.provisional_utils import collect_provisional_products
1819

1920
if TYPE_CHECKING:
20-
from _pytask.node_protocols import PTask
2121
from _pytask.reports import ExecutionReport
2222
from _pytask.session import Session
2323

@@ -97,9 +97,13 @@ def pytask_execute_task_process_report(
9797
report.outcome = TaskOutcome.SKIP
9898

9999
for descending_task_name in descending_tasks(task.signature, session.dag):
100-
descending_task = session.dag.nodes[
101-
descending_task_name
102-
].task_or_raise()
100+
descending_task = session.dag.nodes[descending_task_name]
101+
if not isinstance(descending_task, PTask):
102+
msg = (
103+
f"Expected descending task for signature "
104+
f"{descending_task_name!r}."
105+
)
106+
raise TypeError(msg)
103107
descending_task.markers.append(
104108
Mark(
105109
"skip",

0 commit comments

Comments
 (0)