|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import itertools |
6 | | -from dataclasses import dataclass |
7 | | -from dataclasses import field |
8 | 6 | from typing import TYPE_CHECKING |
9 | 7 |
|
10 | | -from _pytask.dag_graph import DAG |
11 | | -from _pytask.dag_graph import NoCycleError |
12 | | -from _pytask.dag_graph import find_cycle |
13 | | -from _pytask.mark_utils import has_mark |
14 | 8 | from _pytask.node_protocols import PTask |
15 | 9 |
|
16 | 10 | if TYPE_CHECKING: |
17 | 11 | from collections.abc import Generator |
18 | 12 | from collections.abc import Iterable |
19 | 13 |
|
| 14 | + from _pytask.dag_graph import DAG |
| 15 | + |
| 16 | + |
| 17 | +__all__ = [ |
| 18 | + "descending_tasks", |
| 19 | + "node_and_neighbors", |
| 20 | + "preceding_tasks", |
| 21 | + "task_and_descending_tasks", |
| 22 | + "task_and_preceding_tasks", |
| 23 | +] |
| 24 | + |
20 | 25 |
|
21 | 26 | def descending_tasks(task_name: str, dag: DAG) -> Generator[str, None, None]: |
22 | 27 | """Yield only descending tasks.""" |
@@ -55,119 +60,3 @@ def node_and_neighbors(dag: DAG, node: str) -> Iterable[str]: |
55 | 60 |
|
56 | 61 | """ |
57 | 62 | return itertools.chain(dag.predecessors(node), [node], dag.successors(node)) |
58 | | - |
59 | | - |
60 | | -@dataclass |
61 | | -class TopologicalSorter: |
62 | | - """The topological sorter class. |
63 | | -
|
64 | | - This class allows to perform a topological sort# |
65 | | -
|
66 | | - Attributes |
67 | | - ---------- |
68 | | - dag |
69 | | - Not the full DAG, but a reduced version that only considers tasks. |
70 | | - priorities |
71 | | - A dictionary of task names to a priority value. 1 for try first, 0 for the |
72 | | - default priority and, -1 for try last. |
73 | | -
|
74 | | - """ |
75 | | - |
76 | | - dag: DAG |
77 | | - priorities: dict[str, int] = field(default_factory=dict) |
78 | | - _nodes_processing: set[str] = field(default_factory=set) |
79 | | - _nodes_done: set[str] = field(default_factory=set) |
80 | | - |
81 | | - @classmethod |
82 | | - def from_dag(cls, dag: DAG) -> TopologicalSorter: |
83 | | - """Instantiate from a DAG.""" |
84 | | - cls.check_dag(dag) |
85 | | - |
86 | | - tasks = [node for node in dag.nodes.values() if isinstance(node, PTask)] |
87 | | - priorities = _extract_priorities_from_tasks(tasks) |
88 | | - |
89 | | - task_signatures = {task.signature for task in tasks} |
90 | | - task_dag = DAG() |
91 | | - for signature in task_signatures: |
92 | | - task_dag.add_node(signature, dag.nodes[signature]) |
93 | | - for signature in task_signatures: |
94 | | - # The scheduler graph uses edges from predecessor -> successor so that |
95 | | - # zero in-degree means "ready to run". This is the same orientation the |
96 | | - # previous networkx-based implementation reached after calling reverse(). |
97 | | - for ancestor_ in dag.ancestors(signature) & task_signatures: |
98 | | - task_dag.add_edge(ancestor_, signature) |
99 | | - |
100 | | - return cls(dag=task_dag, priorities=priorities) |
101 | | - |
102 | | - @classmethod |
103 | | - def from_dag_and_sorter( |
104 | | - cls, dag: DAG, sorter: TopologicalSorter |
105 | | - ) -> TopologicalSorter: |
106 | | - """Instantiate a sorter from another sorter and a DAG.""" |
107 | | - new_sorter = cls.from_dag(dag) |
108 | | - new_sorter.done(*sorter._nodes_done) |
109 | | - new_sorter._nodes_processing = sorter._nodes_processing |
110 | | - return new_sorter |
111 | | - |
112 | | - @staticmethod |
113 | | - def check_dag(dag: DAG) -> None: |
114 | | - try: |
115 | | - find_cycle(dag) |
116 | | - except NoCycleError: |
117 | | - pass |
118 | | - else: |
119 | | - msg = "The DAG contains cycles." |
120 | | - raise ValueError(msg) |
121 | | - |
122 | | - def get_ready(self, n: int = 1) -> list[str]: |
123 | | - """Get up to ``n`` tasks which are ready.""" |
124 | | - if not isinstance(n, int) or n < 1: |
125 | | - msg = "'n' must be an integer greater or equal than 1." |
126 | | - raise ValueError(msg) |
127 | | - |
128 | | - ready_nodes = { |
129 | | - v for v, d in self.dag.in_degree() if d == 0 |
130 | | - } - self._nodes_processing |
131 | | - prioritized_nodes = sorted( |
132 | | - ready_nodes, key=lambda x: self.priorities.get(x, 0) |
133 | | - )[-n:] |
134 | | - |
135 | | - self._nodes_processing.update(prioritized_nodes) |
136 | | - |
137 | | - return prioritized_nodes |
138 | | - |
139 | | - def is_active(self) -> bool: |
140 | | - """Indicate whether there are still tasks left.""" |
141 | | - return bool(self.dag.nodes) |
142 | | - |
143 | | - def done(self, *nodes: str) -> None: |
144 | | - """Mark some tasks as done.""" |
145 | | - self._nodes_processing = self._nodes_processing - set(nodes) |
146 | | - self.dag.remove_nodes_from(nodes) |
147 | | - self._nodes_done.update(nodes) |
148 | | - |
149 | | - |
150 | | -def _extract_priorities_from_tasks(tasks: list[PTask]) -> dict[str, int]: |
151 | | - """Extract priorities from tasks. |
152 | | -
|
153 | | - Priorities are set via the [pytask.mark.try_first][] and [pytask.mark.try_last][] |
154 | | - markers. We recode these markers to numeric values to sort all available by |
155 | | - priorities. ``try_first`` is assigned the highest value such that it has the |
156 | | - rightmost position in the list. Then, we can simply call `list.pop` on the |
157 | | - list which is far more efficient than ``list.pop(0)``. |
158 | | -
|
159 | | - """ |
160 | | - priorities = { |
161 | | - task.signature: { |
162 | | - "try_first": has_mark(task, "try_first"), |
163 | | - "try_last": has_mark(task, "try_last"), |
164 | | - } |
165 | | - for task in tasks |
166 | | - } |
167 | | - |
168 | | - # Recode to numeric values for sorting. |
169 | | - numeric_mapping = {(True, False): 1, (False, False): 0, (False, True): -1} |
170 | | - return { |
171 | | - name: numeric_mapping[(p["try_first"], p["try_last"])] |
172 | | - for name, p in priorities.items() |
173 | | - } |
0 commit comments