Skip to content

Commit 5099f82

Browse files
committed
Simplify internal DAG types
1 parent 4b73493 commit 5099f82

2 files changed

Lines changed: 20 additions & 60 deletions

File tree

src/_pytask/dag_graph.py

Lines changed: 11 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from __future__ import annotations
44

55
import itertools
6+
from dataclasses import dataclass
7+
from dataclasses import field
68
from typing import TYPE_CHECKING
79
from typing import Any
810
from typing import cast
@@ -17,56 +19,18 @@ class NoCycleError(Exception):
1719
"""Raised when no cycle is found in a graph."""
1820

1921

20-
class NodeView:
21-
"""A minimal mapping-like view over node attributes."""
22-
23-
def __init__(self, node_attributes: dict[str, dict[str, Any]]) -> None:
24-
self._node_attributes = node_attributes
25-
26-
def __getitem__(self, node: str) -> dict[str, Any]:
27-
return self._node_attributes[node]
28-
29-
def __iter__(self) -> Iterator[str]:
30-
return iter(self._node_attributes)
31-
32-
def __len__(self) -> int:
33-
return len(self._node_attributes)
34-
35-
def __contains__(self, node: object) -> bool:
36-
return node in self._node_attributes
37-
38-
39-
class UndirectedGraph:
40-
"""A minimal undirected graph used for validation tests."""
41-
42-
def __init__(
43-
self,
44-
node_attributes: dict[str, dict[str, Any]],
45-
adjacency: dict[str, dict[str, None]],
46-
graph_attributes: dict[str, Any],
47-
) -> None:
48-
self._node_attributes = {
49-
node: attributes.copy() for node, attributes in node_attributes.items()
50-
}
51-
self._adjacency = {
52-
node: neighbors.copy() for node, neighbors in adjacency.items()
53-
}
54-
self.graph = graph_attributes.copy()
55-
self.nodes = NodeView(self._node_attributes)
56-
57-
def is_directed(self) -> bool:
58-
return False
59-
60-
22+
@dataclass
6123
class DiGraph:
6224
"""A minimal directed graph tailored to pytask's needs."""
6325

64-
def __init__(self) -> None:
65-
self._node_attributes: dict[str, dict[str, Any]] = {}
66-
self._successors: dict[str, dict[str, None]] = {}
67-
self._predecessors: dict[str, dict[str, None]] = {}
68-
self.graph: dict[str, Any] = {}
69-
self.nodes = NodeView(self._node_attributes)
26+
_node_attributes: dict[str, dict[str, Any]] = field(default_factory=dict)
27+
_successors: dict[str, dict[str, None]] = field(default_factory=dict)
28+
_predecessors: dict[str, dict[str, None]] = field(default_factory=dict)
29+
graph: dict[str, Any] = field(default_factory=dict)
30+
31+
@property
32+
def nodes(self) -> dict[str, dict[str, Any]]:
33+
return self._node_attributes
7034

7135
def add_node(self, node_name: str, **attributes: Any) -> None:
7236
if node_name not in self._node_attributes:
@@ -138,16 +102,6 @@ def set_node_attributes(self, values: dict[str, Any], name: str) -> None:
138102
if node in self._node_attributes:
139103
self._node_attributes[node][name] = value
140104

141-
def to_undirected(self) -> UndirectedGraph:
142-
adjacency = {
143-
node: {
144-
**self._predecessors[node],
145-
**self._successors[node],
146-
}
147-
for node in self._node_attributes
148-
}
149-
return UndirectedGraph(self._node_attributes, adjacency, self.graph)
150-
151105
def to_networkx(self) -> Any:
152106
nx = cast("Any", import_optional_dependency("networkx"))
153107
graph = nx.DiGraph()

tests/test_dag_utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from contextlib import ExitStack as does_not_raise # noqa: N813
4+
from dataclasses import dataclass
45
from pathlib import Path
56

67
import pytest
@@ -143,10 +144,15 @@ def test_extract_priorities_from_tasks(tasks, expectation, expected):
143144
assert result == expected
144145

145146

146-
def test_raise_error_for_undirected_graphs(dag):
147-
undirected_graph = dag.to_undirected()
147+
@dataclass
148+
class _UndirectedGraphStub:
149+
def is_directed(self):
150+
return False
151+
152+
153+
def test_raise_error_for_undirected_graphs():
148154
with pytest.raises(ValueError, match="Only directed graphs have a"):
149-
TopologicalSorter.from_dag(undirected_graph)
155+
TopologicalSorter.from_dag(_UndirectedGraphStub()) # type: ignore[arg-type]
150156

151157

152158
def test_raise_error_for_cycle_in_graph(dag):

0 commit comments

Comments
 (0)