|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import itertools |
6 | | -from dataclasses import dataclass |
7 | | -from dataclasses import field |
8 | 6 | from typing import TYPE_CHECKING |
9 | 7 |
|
10 | 8 | from _pytask.dag_graph import DiGraph |
11 | | -from _pytask.dag_graph import NoCycleError |
12 | 9 | from _pytask.dag_graph import ancestors |
13 | 10 | from _pytask.dag_graph import descendants |
14 | | -from _pytask.dag_graph import find_cycle |
15 | | -from _pytask.mark_utils import has_mark |
| 11 | +from _pytask.scheduler import SimpleScheduler |
| 12 | +from _pytask.scheduler import TopologicalSorter |
| 13 | +from _pytask.scheduler import _extract_priorities_from_tasks |
16 | 14 |
|
17 | 15 | if TYPE_CHECKING: |
18 | 16 | from collections.abc import Generator |
19 | 17 | from collections.abc import Iterable |
20 | 18 |
|
21 | | - from _pytask.node_protocols import PTask |
| 19 | + |
| 20 | +__all__ = [ |
| 21 | + "SimpleScheduler", |
| 22 | + "TopologicalSorter", |
| 23 | + "_extract_priorities_from_tasks", |
| 24 | + "descending_tasks", |
| 25 | + "node_and_neighbors", |
| 26 | + "preceding_tasks", |
| 27 | + "task_and_descending_tasks", |
| 28 | + "task_and_preceding_tasks", |
| 29 | +] |
22 | 30 |
|
23 | 31 |
|
24 | 32 | def descending_tasks(task_name: str, dag: DiGraph) -> Generator[str, None, None]: |
@@ -62,125 +70,3 @@ def node_and_neighbors(dag: DiGraph, node: str) -> Iterable[str]: |
62 | 70 |
|
63 | 71 | """ |
64 | 72 | return itertools.chain(dag.predecessors(node), [node], dag.successors(node)) |
65 | | - |
66 | | - |
67 | | -@dataclass |
68 | | -class TopologicalSorter: |
69 | | - """The topological sorter class. |
70 | | -
|
71 | | - This class allows performing a topological sort. |
72 | | -
|
73 | | - Attributes |
74 | | - ---------- |
75 | | - dag |
76 | | - Not the full DAG, but a reduced version that only considers tasks. |
77 | | - priorities |
78 | | - A dictionary of task names to a priority value. 1 for try first, 0 for the |
79 | | - default priority, and -1 for try last. |
80 | | -
|
81 | | - """ |
82 | | - |
83 | | - dag: DiGraph |
84 | | - priorities: dict[str, int] = field(default_factory=dict) |
85 | | - _nodes_processing: set[str] = field(default_factory=set) |
86 | | - _nodes_done: set[str] = field(default_factory=set) |
87 | | - |
88 | | - @classmethod |
89 | | - def from_dag(cls, dag: DiGraph) -> TopologicalSorter: |
90 | | - """Instantiate from a DAG.""" |
91 | | - cls.check_dag(dag) |
92 | | - |
93 | | - tasks = [ |
94 | | - dag.nodes[node]["task"] for node in dag.nodes if "task" in dag.nodes[node] |
95 | | - ] |
96 | | - priorities = _extract_priorities_from_tasks(tasks) |
97 | | - |
98 | | - task_signatures = {task.signature for task in tasks} |
99 | | - task_dag = DiGraph() |
100 | | - for signature in task_signatures: |
101 | | - task_dag.add_node(signature) |
102 | | - for signature in task_signatures: |
103 | | - # The scheduler graph uses edges from predecessor -> successor so that |
104 | | - # zero in-degree means "ready to run". This is the same orientation the |
105 | | - # previous networkx-based implementation reached after calling reverse(). |
106 | | - for ancestor_ in ancestors(dag, signature) & task_signatures: |
107 | | - task_dag.add_edge(ancestor_, signature) |
108 | | - |
109 | | - return cls(dag=task_dag, priorities=priorities) |
110 | | - |
111 | | - @classmethod |
112 | | - def from_dag_and_sorter( |
113 | | - cls, dag: DiGraph, sorter: TopologicalSorter |
114 | | - ) -> TopologicalSorter: |
115 | | - """Instantiate a sorter from another sorter and a DAG.""" |
116 | | - new_sorter = cls.from_dag(dag) |
117 | | - new_sorter.done(*sorter._nodes_done) |
118 | | - new_sorter._nodes_processing = sorter._nodes_processing |
119 | | - return new_sorter |
120 | | - |
121 | | - @staticmethod |
122 | | - def check_dag(dag: DiGraph) -> None: |
123 | | - if not dag.is_directed(): |
124 | | - msg = "Only directed graphs have a topological order." |
125 | | - raise ValueError(msg) |
126 | | - |
127 | | - try: |
128 | | - find_cycle(dag) |
129 | | - except NoCycleError: |
130 | | - pass |
131 | | - else: |
132 | | - msg = "The DAG contains cycles." |
133 | | - raise ValueError(msg) |
134 | | - |
135 | | - def get_ready(self, n: int = 1) -> list[str]: |
136 | | - """Get up to ``n`` tasks which are ready.""" |
137 | | - if not isinstance(n, int) or n < 1: |
138 | | - msg = "'n' must be an integer greater or equal than 1." |
139 | | - raise ValueError(msg) |
140 | | - |
141 | | - ready_nodes = { |
142 | | - v for v, d in self.dag.in_degree() if d == 0 |
143 | | - } - self._nodes_processing |
144 | | - prioritized_nodes = sorted( |
145 | | - ready_nodes, key=lambda x: self.priorities.get(x, 0) |
146 | | - )[-n:] |
147 | | - |
148 | | - self._nodes_processing.update(prioritized_nodes) |
149 | | - |
150 | | - return prioritized_nodes |
151 | | - |
152 | | - def is_active(self) -> bool: |
153 | | - """Indicate whether there are still tasks left.""" |
154 | | - return bool(self.dag.nodes) |
155 | | - |
156 | | - def done(self, *nodes: str) -> None: |
157 | | - """Mark some tasks as done.""" |
158 | | - self._nodes_processing = self._nodes_processing - set(nodes) |
159 | | - self.dag.remove_nodes_from(nodes) |
160 | | - self._nodes_done.update(nodes) |
161 | | - |
162 | | - |
163 | | -def _extract_priorities_from_tasks(tasks: list[PTask]) -> dict[str, int]: |
164 | | - """Extract priorities from tasks. |
165 | | -
|
166 | | - Priorities are set via the [pytask.mark.try_first][] and [pytask.mark.try_last][] |
167 | | - markers. We recode these markers to numeric values to sort all available tasks by |
168 | | - priorities. ``try_first`` is assigned the highest value such that it has the |
169 | | - rightmost position in the list. Then, we can simply call `list.pop` on the |
170 | | - list which is far more efficient than ``list.pop(0)``. |
171 | | -
|
172 | | - """ |
173 | | - priorities = { |
174 | | - task.signature: { |
175 | | - "try_first": has_mark(task, "try_first"), |
176 | | - "try_last": has_mark(task, "try_last"), |
177 | | - } |
178 | | - for task in tasks |
179 | | - } |
180 | | - |
181 | | - # Recode to numeric values for sorting. |
182 | | - numeric_mapping = {(True, False): 1, (False, False): 0, (False, True): -1} |
183 | | - return { |
184 | | - name: numeric_mapping[(p["try_first"], p["try_last"])] |
185 | | - for name, p in priorities.items() |
186 | | - } |
0 commit comments