Skip to content

Commit 4e7e718

Browse files
authored
Introduce scheduler protocol and simple scheduler (#831)
1 parent fd71455 commit 4e7e718

9 files changed

Lines changed: 180 additions & 208 deletions

File tree

src/_pytask/dag.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,6 @@ def _format_cycles(dag: DAG, cycles: list[tuple[str, str]]) -> str:
152152
elif isinstance(node, (PNode, PProvisionalNode)):
153153
short_name = node.name
154154
lines.extend((short_name, " " + ARROW_DOWN_ICON))
155-
# Join while removing last arrow.
156155
return "\n".join(lines[:-1])
157156

158157

src/_pytask/dag_command.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,7 @@ def build_dag(raw_config: dict[str, Any]) -> nx.DiGraph:
147147
pm = get_plugin_manager()
148148
storage.store(pm)
149149

150-
# If someone called the programmatic interface, we need to do some parsing.
151150
if "command" not in raw_config:
152-
# Add defaults from cli.
153151
from _pytask.cli import DEFAULTS_FROM_CLI # noqa: PLC0415
154152

155153
raw_config = normalize_programmatic_config(

src/_pytask/dag_graph.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,7 @@ def _traverse(
127127
return visited
128128

129129

130-
def find_cycle(
131-
dag: DAG,
132-
) -> list[tuple[str, str]]:
130+
def find_cycle(dag: DAG) -> list[tuple[str, str]]:
133131
"""Find one cycle in the graph."""
134132
visited: set[str] = set()
135133
active: set[str] = set()

src/_pytask/dag_utils.py

Lines changed: 11 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,25 @@
33
from __future__ import annotations
44

55
import itertools
6-
from dataclasses import dataclass
7-
from dataclasses import field
86
from typing import TYPE_CHECKING
97

10-
from _pytask.dag_graph import DAG
11-
from _pytask.dag_graph import NoCycleError
12-
from _pytask.dag_graph import find_cycle
13-
from _pytask.mark_utils import has_mark
148
from _pytask.node_protocols import PTask
159

1610
if TYPE_CHECKING:
1711
from collections.abc import Generator
1812
from collections.abc import Iterable
1913

14+
from _pytask.dag_graph import DAG
15+
16+
17+
__all__ = [
18+
"descending_tasks",
19+
"node_and_neighbors",
20+
"preceding_tasks",
21+
"task_and_descending_tasks",
22+
"task_and_preceding_tasks",
23+
]
24+
2025

2126
def descending_tasks(task_name: str, dag: DAG) -> Generator[str, None, None]:
2227
"""Yield only descending tasks."""
@@ -55,119 +60,3 @@ def node_and_neighbors(dag: DAG, node: str) -> Iterable[str]:
5560
5661
"""
5762
return itertools.chain(dag.predecessors(node), [node], dag.successors(node))
58-
59-
60-
@dataclass
61-
class TopologicalSorter:
62-
"""The topological sorter class.
63-
64-
This class allows to perform a topological sort#
65-
66-
Attributes
67-
----------
68-
dag
69-
Not the full DAG, but a reduced version that only considers tasks.
70-
priorities
71-
A dictionary of task names to a priority value. 1 for try first, 0 for the
72-
default priority and, -1 for try last.
73-
74-
"""
75-
76-
dag: DAG
77-
priorities: dict[str, int] = field(default_factory=dict)
78-
_nodes_processing: set[str] = field(default_factory=set)
79-
_nodes_done: set[str] = field(default_factory=set)
80-
81-
@classmethod
82-
def from_dag(cls, dag: DAG) -> TopologicalSorter:
83-
"""Instantiate from a DAG."""
84-
cls.check_dag(dag)
85-
86-
tasks = [node for node in dag.nodes.values() if isinstance(node, PTask)]
87-
priorities = _extract_priorities_from_tasks(tasks)
88-
89-
task_signatures = {task.signature for task in tasks}
90-
task_dag = DAG()
91-
for signature in task_signatures:
92-
task_dag.add_node(signature, dag.nodes[signature])
93-
for signature in task_signatures:
94-
# The scheduler graph uses edges from predecessor -> successor so that
95-
# zero in-degree means "ready to run". This is the same orientation the
96-
# previous networkx-based implementation reached after calling reverse().
97-
for ancestor_ in dag.ancestors(signature) & task_signatures:
98-
task_dag.add_edge(ancestor_, signature)
99-
100-
return cls(dag=task_dag, priorities=priorities)
101-
102-
@classmethod
103-
def from_dag_and_sorter(
104-
cls, dag: DAG, sorter: TopologicalSorter
105-
) -> TopologicalSorter:
106-
"""Instantiate a sorter from another sorter and a DAG."""
107-
new_sorter = cls.from_dag(dag)
108-
new_sorter.done(*sorter._nodes_done)
109-
new_sorter._nodes_processing = sorter._nodes_processing
110-
return new_sorter
111-
112-
@staticmethod
113-
def check_dag(dag: DAG) -> None:
114-
try:
115-
find_cycle(dag)
116-
except NoCycleError:
117-
pass
118-
else:
119-
msg = "The DAG contains cycles."
120-
raise ValueError(msg)
121-
122-
def get_ready(self, n: int = 1) -> list[str]:
123-
"""Get up to ``n`` tasks which are ready."""
124-
if not isinstance(n, int) or n < 1:
125-
msg = "'n' must be an integer greater or equal than 1."
126-
raise ValueError(msg)
127-
128-
ready_nodes = {
129-
v for v, d in self.dag.in_degree() if d == 0
130-
} - self._nodes_processing
131-
prioritized_nodes = sorted(
132-
ready_nodes, key=lambda x: self.priorities.get(x, 0)
133-
)[-n:]
134-
135-
self._nodes_processing.update(prioritized_nodes)
136-
137-
return prioritized_nodes
138-
139-
def is_active(self) -> bool:
140-
"""Indicate whether there are still tasks left."""
141-
return bool(self.dag.nodes)
142-
143-
def done(self, *nodes: str) -> None:
144-
"""Mark some tasks as done."""
145-
self._nodes_processing = self._nodes_processing - set(nodes)
146-
self.dag.remove_nodes_from(nodes)
147-
self._nodes_done.update(nodes)
148-
149-
150-
def _extract_priorities_from_tasks(tasks: list[PTask]) -> dict[str, int]:
151-
"""Extract priorities from tasks.
152-
153-
Priorities are set via the [pytask.mark.try_first][] and [pytask.mark.try_last][]
154-
markers. We recode these markers to numeric values to sort all available by
155-
priorities. ``try_first`` is assigned the highest value such that it has the
156-
rightmost position in the list. Then, we can simply call `list.pop` on the
157-
list which is far more efficient than ``list.pop(0)``.
158-
159-
"""
160-
priorities = {
161-
task.signature: {
162-
"try_first": has_mark(task, "try_first"),
163-
"try_last": has_mark(task, "try_last"),
164-
}
165-
for task in tasks
166-
}
167-
168-
# Recode to numeric values for sorting.
169-
numeric_mapping = {(True, False): 1, (False, False): 0, (False, True): -1}
170-
return {
171-
name: numeric_mapping[(p["try_first"], p["try_last"])]
172-
for name, p in priorities.items()
173-
}

src/_pytask/execute.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from _pytask.console import format_node_name
1818
from _pytask.console import format_strings_as_flat_tree
1919
from _pytask.console import unify_styles
20-
from _pytask.dag_utils import TopologicalSorter
2120
from _pytask.dag_utils import descending_tasks
2221
from _pytask.dag_utils import node_and_neighbors
2322
from _pytask.exceptions import ExecutionError
@@ -43,6 +42,7 @@
4342
from _pytask.pluginmanager import hookimpl
4443
from _pytask.provisional_utils import collect_provisional_products
4544
from _pytask.reports import ExecutionReport
45+
from _pytask.scheduler import SimpleScheduler
4646
from _pytask.state import get_node_change_info
4747
from _pytask.state import has_node_changed
4848
from _pytask.state import update_states
@@ -67,7 +67,7 @@ def pytask_post_parse(config: dict[str, Any]) -> None:
6767
def pytask_execute(session: Session) -> None:
6868
"""Execute tasks."""
6969
session.hook.pytask_execute_log_start(session=session)
70-
session.scheduler = TopologicalSorter.from_dag(session.dag)
70+
session.scheduler = SimpleScheduler.from_dag(session.dag)
7171
session.hook.pytask_execute_build(session=session)
7272
session.hook.pytask_execute_log_end(
7373
session=session, reports=session.execution_reports
@@ -86,7 +86,7 @@ def pytask_execute_log_start(session: Session) -> None:
8686
@hookimpl
8787
def pytask_execute_build(session: Session) -> bool | None:
8888
"""Execute tasks."""
89-
if isinstance(session.scheduler, TopologicalSorter):
89+
if session.scheduler is not None:
9090
while session.scheduler.is_active():
9191
task_name = session.scheduler.get_ready()[0]
9292
task = session.dag.nodes[task_name]

src/_pytask/provisional_utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
from _pytask.collect_utils import collect_dependency
1111
from _pytask.dag import create_dag_from_session
12-
from _pytask.dag_utils import TopologicalSorter
1312
from _pytask.models import NodeInfo
1413
from _pytask.node_protocols import PNode
1514
from _pytask.node_protocols import PProvisionalNode
@@ -78,9 +77,8 @@ def recreate_dag(session: Session, task: PTask) -> None:
7877
"""
7978
try:
8079
session.dag = create_dag_from_session(session)
81-
session.scheduler = TopologicalSorter.from_dag_and_sorter(
82-
session.dag, session.scheduler
83-
)
80+
if session.scheduler is not None:
81+
session.scheduler = session.scheduler.rebuild(session.dag)
8482

8583
except Exception: # noqa: BLE001
8684
report = ExecutionReport.from_task_and_exception(task, sys.exc_info())

src/_pytask/scheduler.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
"""Contains scheduler protocols and implementations."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass
6+
from dataclasses import field
7+
from typing import Protocol
8+
9+
from _pytask.dag_graph import DAG
10+
from _pytask.dag_graph import NoCycleError
11+
from _pytask.dag_graph import find_cycle
12+
from _pytask.mark_utils import has_mark
13+
from _pytask.node_protocols import PTask
14+
15+
16+
class PScheduler(Protocol):
    """Structural interface of a scheduler that hands out runnable tasks.

    Any object offering these four methods can be used as a scheduler; no
    inheritance from this class is required.

    """

    def get_ready(self, n: int = 1) -> list[str]:
        """Return the names of at most ``n`` tasks that can be started now."""
        ...

    def is_active(self) -> bool:
        """Tell whether the scheduler still holds unfinished tasks."""
        ...

    def done(self, *nodes: str) -> None:
        """Register the given tasks as finished."""
        ...

    def rebuild(self, dag: DAG) -> PScheduler:
        """Create a scheduler for an updated DAG that keeps the current state."""
        ...
30+
31+
32+
@dataclass
class SimpleScheduler:
    """The default scheduler based on topological sorting.

    Attributes
    ----------
    dag
        The graph the scheduler works on. When created via
        :meth:`from_dag` it is a reduced graph that only contains tasks, with
        edges pointing from predecessor to successor.
    priorities
        A dictionary of task names to a priority value. Tasks without an entry are
        treated as priority 0; higher values are handed out first.

    """

    dag: DAG
    priorities: dict[str, int] = field(default_factory=dict)
    # Tasks handed out by ``get_ready`` that were not yet passed to ``done``.
    _nodes_processing: set[str] = field(default_factory=set)
    # Every task ever passed to ``done``; ``rebuild`` replays them on a new graph.
    _nodes_done: set[str] = field(default_factory=set)

    @classmethod
    def from_dag(cls, dag: DAG) -> SimpleScheduler:
        """Instantiate from a DAG.

        Raises
        ------
        ValueError
            If ``dag`` contains a cycle.

        """
        cls.check_dag(dag)

        tasks = [node for node in dag.nodes.values() if isinstance(node, PTask)]
        priorities = _extract_priorities_from_tasks(tasks)

        task_signatures = {task.signature for task in tasks}
        task_dag = DAG()
        for signature in task_signatures:
            task_dag.add_node(signature, dag.nodes[signature])
        for signature in task_signatures:
            # The scheduler graph uses edges from predecessor -> successor so that
            # zero in-degree means "ready to run". This is the same orientation the
            # previous networkx-based implementation reached after calling reverse().
            for ancestor_ in dag.ancestors(signature) & task_signatures:
                task_dag.add_edge(ancestor_, signature)

        return cls(dag=task_dag, priorities=priorities)

    @staticmethod
    def check_dag(dag: DAG) -> None:
        """Raise a ``ValueError`` if a cycle is found in ``dag``."""
        try:
            find_cycle(dag)
        except NoCycleError:
            # ``find_cycle`` signals the absence of cycles with this exception.
            pass
        else:
            msg = "The DAG contains cycles."
            raise ValueError(msg)

    def get_ready(self, n: int = 1) -> list[str]:
        """Get up to ``n`` tasks which are ready.

        A task is ready when it has no incoming edge and has not been handed out
        before. The returned tasks are marked as being processed.

        Returns
        -------
        list[str]
            The selected task names in ascending priority, i.e., the task to prefer
            is the last element. Tasks with equal priority are ordered by name so
            that the result is the same on every run.

        Raises
        ------
        ValueError
            If ``n`` is not an integer greater or equal than 1.

        """
        if not isinstance(n, int) or n < 1:
            msg = "'n' must be an integer greater or equal than 1."
            raise ValueError(msg)

        ready_nodes = {
            v for v, d in self.dag.in_degree() if d == 0
        } - self._nodes_processing

        # Iterating a set of strings yields a different order in every interpreter
        # run because of hash randomization. Order the names first and rely on the
        # stable priority sort afterwards to get a reproducible tie-break. Names
        # are sorted descending so that, among equal priorities, the alphabetically
        # first task ends up rightmost, which is the preferred position.
        ordered_by_name = sorted(ready_nodes, reverse=True)
        prioritized_nodes = sorted(
            ordered_by_name, key=lambda x: self.priorities.get(x, 0)
        )[-n:]

        self._nodes_processing.update(prioritized_nodes)

        return prioritized_nodes

    def is_active(self) -> bool:
        """Indicate whether there are still tasks left."""
        return bool(self.dag.nodes)

    def done(self, *nodes: str) -> None:
        """Mark some tasks as done.

        The tasks are dropped from the set of processing tasks, removed from the
        graph, and remembered as done.

        """
        self._nodes_processing = self._nodes_processing - set(nodes)
        self.dag.remove_nodes_from(nodes)
        self._nodes_done.update(nodes)

    def rebuild(self, dag: DAG) -> SimpleScheduler:
        """Rebuild the scheduler from an updated DAG while preserving state.

        A new scheduler is created from ``dag``, all tasks that are already done are
        marked as done on it, and the processing tasks are carried over as a copy so
        that both schedulers do not share mutable state.

        """
        new_scheduler = type(self).from_dag(dag)
        new_scheduler.done(*self._nodes_done)
        new_scheduler._nodes_processing = self._nodes_processing.copy()
        return new_scheduler
105+
106+
107+
def _extract_priorities_from_tasks(tasks: list[PTask]) -> dict[str, int]:
108+
"""Extract priorities from tasks.
109+
110+
Priorities are set via the [pytask.mark.try_first][] and [pytask.mark.try_last][]
111+
markers. We recode these markers to numeric values to sort all available by
112+
priorities. ``try_first`` is assigned the highest value such that it has the
113+
rightmost position in the list. Then, we can simply call `list.pop` on the
114+
list which is far more efficient than ``list.pop(0)``.
115+
116+
"""
117+
priorities = {
118+
task.signature: {
119+
"try_first": has_mark(task, "try_first"),
120+
"try_last": has_mark(task, "try_last"),
121+
}
122+
for task in tasks
123+
}
124+
125+
numeric_mapping = {(True, False): 1, (False, False): 0, (False, True): -1}
126+
return {
127+
name: numeric_mapping[(p["try_first"], p["try_last"])]
128+
for name, p in priorities.items()
129+
}

src/_pytask/session.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from _pytask.reports import CollectionReport
1818
from _pytask.reports import DagReport
1919
from _pytask.reports import ExecutionReport
20+
from _pytask.scheduler import PScheduler
2021
from _pytask.warnings_utils import WarningReport
2122

2223

@@ -64,7 +65,7 @@ class Session:
6465
execution_end: float = float("inf")
6566

6667
n_tasks_failed: int = 0
67-
scheduler: Any = None
68+
scheduler: PScheduler | None = None
6869
should_stop: bool = False
6970
warnings: list[WarningReport] = field(default_factory=list)
7071

0 commit comments

Comments
 (0)