33from __future__ import annotations
44
55import itertools
6+ from dataclasses import dataclass
7+ from dataclasses import field
68from typing import TYPE_CHECKING
79from typing import Any
810from typing import cast
@@ -17,56 +19,18 @@ class NoCycleError(Exception):
1719 """Raised when no cycle is found in a graph."""
1820
1921
20- class NodeView :
21- """A minimal mapping-like view over node attributes."""
22-
23- def __init__ (self , node_attributes : dict [str , dict [str , Any ]]) -> None :
24- self ._node_attributes = node_attributes
25-
26- def __getitem__ (self , node : str ) -> dict [str , Any ]:
27- return self ._node_attributes [node ]
28-
29- def __iter__ (self ) -> Iterator [str ]:
30- return iter (self ._node_attributes )
31-
32- def __len__ (self ) -> int :
33- return len (self ._node_attributes )
34-
35- def __contains__ (self , node : object ) -> bool :
36- return node in self ._node_attributes
37-
38-
39- class UndirectedGraph :
40- """A minimal undirected graph used for validation tests."""
41-
42- def __init__ (
43- self ,
44- node_attributes : dict [str , dict [str , Any ]],
45- adjacency : dict [str , dict [str , None ]],
46- graph_attributes : dict [str , Any ],
47- ) -> None :
48- self ._node_attributes = {
49- node : attributes .copy () for node , attributes in node_attributes .items ()
50- }
51- self ._adjacency = {
52- node : neighbors .copy () for node , neighbors in adjacency .items ()
53- }
54- self .graph = graph_attributes .copy ()
55- self .nodes = NodeView (self ._node_attributes )
56-
57- def is_directed (self ) -> bool :
58- return False
59-
60-
22+ @dataclass
6123class DiGraph :
6224 """A minimal directed graph tailored to pytask's needs."""
6325
64- def __init__ (self ) -> None :
65- self ._node_attributes : dict [str , dict [str , Any ]] = {}
66- self ._successors : dict [str , dict [str , None ]] = {}
67- self ._predecessors : dict [str , dict [str , None ]] = {}
68- self .graph : dict [str , Any ] = {}
69- self .nodes = NodeView (self ._node_attributes )
26+ _node_attributes : dict [str , dict [str , Any ]] = field (default_factory = dict )
27+ _successors : dict [str , dict [str , None ]] = field (default_factory = dict )
28+ _predecessors : dict [str , dict [str , None ]] = field (default_factory = dict )
29+ graph : dict [str , Any ] = field (default_factory = dict )
30+
31+ @property
32+ def nodes (self ) -> dict [str , dict [str , Any ]]:
33+ return self ._node_attributes
7034
7135 def add_node (self , node_name : str , ** attributes : Any ) -> None :
7236 if node_name not in self ._node_attributes :
@@ -138,16 +102,6 @@ def set_node_attributes(self, values: dict[str, Any], name: str) -> None:
138102 if node in self ._node_attributes :
139103 self ._node_attributes [node ][name ] = value
140104
141- def to_undirected (self ) -> UndirectedGraph :
142- adjacency = {
143- node : {
144- ** self ._predecessors [node ],
145- ** self ._successors [node ],
146- }
147- for node in self ._node_attributes
148- }
149- return UndirectedGraph (self ._node_attributes , adjacency , self .graph )
150-
151105 def to_networkx (self ) -> Any :
152106 nx = cast ("Any" , import_optional_dependency ("networkx" ))
153107 graph = nx .DiGraph ()
0 commit comments