Skip to content

Commit b806a41

Browse files
committed
Introduce scheduler protocol and simple scheduler
1 parent 13a3466 commit b806a41

6 files changed

Lines changed: 182 additions & 147 deletions

File tree

src/_pytask/dag_utils.py

Lines changed: 14 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,30 @@
33
from __future__ import annotations
44

55
import itertools
6-
from dataclasses import dataclass
7-
from dataclasses import field
86
from typing import TYPE_CHECKING
97

108
from _pytask.dag_graph import DiGraph
11-
from _pytask.dag_graph import NoCycleError
129
from _pytask.dag_graph import ancestors
1310
from _pytask.dag_graph import descendants
14-
from _pytask.dag_graph import find_cycle
15-
from _pytask.mark_utils import has_mark
11+
from _pytask.scheduler import SimpleScheduler
12+
from _pytask.scheduler import TopologicalSorter
13+
from _pytask.scheduler import _extract_priorities_from_tasks
1614

1715
if TYPE_CHECKING:
1816
from collections.abc import Generator
1917
from collections.abc import Iterable
2018

21-
from _pytask.node_protocols import PTask
19+
20+
__all__ = [
21+
"SimpleScheduler",
22+
"TopologicalSorter",
23+
"_extract_priorities_from_tasks",
24+
"descending_tasks",
25+
"node_and_neighbors",
26+
"preceding_tasks",
27+
"task_and_descending_tasks",
28+
"task_and_preceding_tasks",
29+
]
2230

2331

2432
def descending_tasks(task_name: str, dag: DiGraph) -> Generator[str, None, None]:
@@ -62,125 +70,3 @@ def node_and_neighbors(dag: DiGraph, node: str) -> Iterable[str]:
6270
6371
"""
6472
return itertools.chain(dag.predecessors(node), [node], dag.successors(node))
65-
66-
67-
@dataclass
68-
class TopologicalSorter:
69-
"""The topological sorter class.
70-
71-
This class allows to perform a topological sort#
72-
73-
Attributes
74-
----------
75-
dag
76-
Not the full DAG, but a reduced version that only considers tasks.
77-
priorities
78-
A dictionary of task names to a priority value. 1 for try first, 0 for the
79-
default priority and, -1 for try last.
80-
81-
"""
82-
83-
dag: DiGraph
84-
priorities: dict[str, int] = field(default_factory=dict)
85-
_nodes_processing: set[str] = field(default_factory=set)
86-
_nodes_done: set[str] = field(default_factory=set)
87-
88-
@classmethod
89-
def from_dag(cls, dag: DiGraph) -> TopologicalSorter:
90-
"""Instantiate from a DAG."""
91-
cls.check_dag(dag)
92-
93-
tasks = [
94-
dag.nodes[node]["task"] for node in dag.nodes if "task" in dag.nodes[node]
95-
]
96-
priorities = _extract_priorities_from_tasks(tasks)
97-
98-
task_signatures = {task.signature for task in tasks}
99-
task_dag = DiGraph()
100-
for signature in task_signatures:
101-
task_dag.add_node(signature)
102-
for signature in task_signatures:
103-
# The scheduler graph uses edges from predecessor -> successor so that
104-
# zero in-degree means "ready to run". This is the same orientation the
105-
# previous networkx-based implementation reached after calling reverse().
106-
for ancestor_ in ancestors(dag, signature) & task_signatures:
107-
task_dag.add_edge(ancestor_, signature)
108-
109-
return cls(dag=task_dag, priorities=priorities)
110-
111-
@classmethod
112-
def from_dag_and_sorter(
113-
cls, dag: DiGraph, sorter: TopologicalSorter
114-
) -> TopologicalSorter:
115-
"""Instantiate a sorter from another sorter and a DAG."""
116-
new_sorter = cls.from_dag(dag)
117-
new_sorter.done(*sorter._nodes_done)
118-
new_sorter._nodes_processing = sorter._nodes_processing
119-
return new_sorter
120-
121-
@staticmethod
122-
def check_dag(dag: DiGraph) -> None:
123-
if not dag.is_directed():
124-
msg = "Only directed graphs have a topological order."
125-
raise ValueError(msg)
126-
127-
try:
128-
find_cycle(dag)
129-
except NoCycleError:
130-
pass
131-
else:
132-
msg = "The DAG contains cycles."
133-
raise ValueError(msg)
134-
135-
def get_ready(self, n: int = 1) -> list[str]:
136-
"""Get up to ``n`` tasks which are ready."""
137-
if not isinstance(n, int) or n < 1:
138-
msg = "'n' must be an integer greater or equal than 1."
139-
raise ValueError(msg)
140-
141-
ready_nodes = {
142-
v for v, d in self.dag.in_degree() if d == 0
143-
} - self._nodes_processing
144-
prioritized_nodes = sorted(
145-
ready_nodes, key=lambda x: self.priorities.get(x, 0)
146-
)[-n:]
147-
148-
self._nodes_processing.update(prioritized_nodes)
149-
150-
return prioritized_nodes
151-
152-
def is_active(self) -> bool:
153-
"""Indicate whether there are still tasks left."""
154-
return bool(self.dag.nodes)
155-
156-
def done(self, *nodes: str) -> None:
157-
"""Mark some tasks as done."""
158-
self._nodes_processing = self._nodes_processing - set(nodes)
159-
self.dag.remove_nodes_from(nodes)
160-
self._nodes_done.update(nodes)
161-
162-
163-
def _extract_priorities_from_tasks(tasks: list[PTask]) -> dict[str, int]:
164-
"""Extract priorities from tasks.
165-
166-
Priorities are set via the [pytask.mark.try_first][] and [pytask.mark.try_last][]
167-
markers. We recode these markers to numeric values to sort all available by
168-
priorities. ``try_first`` is assigned the highest value such that it has the
169-
rightmost position in the list. Then, we can simply call `list.pop` on the
170-
list which is far more efficient than ``list.pop(0)``.
171-
172-
"""
173-
priorities = {
174-
task.signature: {
175-
"try_first": has_mark(task, "try_first"),
176-
"try_last": has_mark(task, "try_last"),
177-
}
178-
for task in tasks
179-
}
180-
181-
# Recode to numeric values for sorting.
182-
numeric_mapping = {(True, False): 1, (False, False): 0, (False, True): -1}
183-
return {
184-
name: numeric_mapping[(p["try_first"], p["try_last"])]
185-
for name, p in priorities.items()
186-
}

src/_pytask/execute.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from _pytask.console import format_node_name
1919
from _pytask.console import format_strings_as_flat_tree
2020
from _pytask.console import unify_styles
21-
from _pytask.dag_utils import TopologicalSorter
2221
from _pytask.dag_utils import descending_tasks
2322
from _pytask.dag_utils import node_and_neighbors
2423
from _pytask.exceptions import ExecutionError
@@ -44,6 +43,7 @@
4443
from _pytask.pluginmanager import hookimpl
4544
from _pytask.provisional_utils import collect_provisional_products
4645
from _pytask.reports import ExecutionReport
46+
from _pytask.scheduler import SimpleScheduler
4747
from _pytask.state import get_node_change_info
4848
from _pytask.state import has_node_changed
4949
from _pytask.state import update_states
@@ -68,7 +68,7 @@ def pytask_post_parse(config: dict[str, Any]) -> None:
6868
def pytask_execute(session: Session) -> None:
6969
"""Execute tasks."""
7070
session.hook.pytask_execute_log_start(session=session)
71-
session.scheduler = TopologicalSorter.from_dag(session.dag)
71+
session.scheduler = SimpleScheduler.from_dag(session.dag)
7272
session.hook.pytask_execute_build(session=session)
7373
session.hook.pytask_execute_log_end(
7474
session=session, reports=session.execution_reports
@@ -87,15 +87,16 @@ def pytask_execute_log_start(session: Session) -> None:
8787
@hookimpl
8888
def pytask_execute_build(session: Session) -> bool | None:
8989
"""Execute tasks."""
90-
if isinstance(session.scheduler, TopologicalSorter):
91-
while session.scheduler.is_active():
92-
task_name = session.scheduler.get_ready()[0]
90+
scheduler = session.scheduler
91+
if scheduler is not None:
92+
while scheduler.is_active():
93+
task_name = scheduler.get_ready()[0]
9394
task = session.dag.nodes[task_name]["task"]
9495
report = session.hook.pytask_execute_task_protocol(
9596
session=session, task=task
9697
)
9798
session.execution_reports.append(report)
98-
session.scheduler.done(task_name)
99+
scheduler.done(task_name)
99100

100101
if session.should_stop:
101102
return True

src/_pytask/provisional_utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
from _pytask.collect_utils import collect_dependency
1111
from _pytask.dag import create_dag_from_session
12-
from _pytask.dag_utils import TopologicalSorter
1312
from _pytask.models import NodeInfo
1413
from _pytask.node_protocols import PNode
1514
from _pytask.node_protocols import PProvisionalNode
@@ -78,9 +77,8 @@ def recreate_dag(session: Session, task: PTask) -> None:
7877
"""
7978
try:
8079
session.dag = create_dag_from_session(session)
81-
session.scheduler = TopologicalSorter.from_dag_and_sorter(
82-
session.dag, session.scheduler
83-
)
80+
if session.scheduler is not None:
81+
session.scheduler = session.scheduler.rebuild(session.dag)
8482

8583
except Exception: # noqa: BLE001
8684
report = ExecutionReport.from_task_and_exception(task, sys.exc_info())

src/_pytask/scheduler.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
"""Contains scheduler protocols and implementations."""
2+
3+
from __future__ import annotations
4+
5+
from dataclasses import dataclass
6+
from dataclasses import field
7+
from typing import TYPE_CHECKING
8+
from typing import Protocol
9+
10+
from _pytask.dag_graph import DiGraph
11+
from _pytask.dag_graph import NoCycleError
12+
from _pytask.dag_graph import ancestors
13+
from _pytask.dag_graph import find_cycle
14+
from _pytask.mark_utils import has_mark
15+
16+
if TYPE_CHECKING:
17+
from _pytask.node_protocols import PTask
18+
19+
20+
class PScheduler(Protocol):
21+
"""Protocol for schedulers that dispatch ready tasks."""
22+
23+
def get_ready(self, n: int = 1) -> list[str]:
24+
"""Get up to ``n`` tasks which are ready."""
25+
26+
def is_active(self) -> bool:
27+
"""Indicate whether there are still tasks left."""
28+
29+
def done(self, *nodes: str) -> None:
30+
"""Mark some tasks as done."""
31+
32+
def rebuild(self, dag: DiGraph) -> PScheduler:
33+
"""Rebuild the scheduler from an updated DAG while preserving state."""
34+
35+
36+
@dataclass
37+
class SimpleScheduler:
38+
"""The default scheduler based on topological sorting."""
39+
40+
dag: DiGraph
41+
priorities: dict[str, int] = field(default_factory=dict)
42+
_nodes_processing: set[str] = field(default_factory=set)
43+
_nodes_done: set[str] = field(default_factory=set)
44+
45+
@classmethod
46+
def from_dag(cls, dag: DiGraph) -> SimpleScheduler:
47+
"""Instantiate from a DAG."""
48+
cls.check_dag(dag)
49+
50+
tasks = [
51+
dag.nodes[node]["task"] for node in dag.nodes if "task" in dag.nodes[node]
52+
]
53+
priorities = _extract_priorities_from_tasks(tasks)
54+
55+
task_signatures = {task.signature for task in tasks}
56+
task_dag = DiGraph()
57+
for signature in task_signatures:
58+
task_dag.add_node(signature)
59+
for signature in task_signatures:
60+
# The scheduler graph uses edges from predecessor -> successor so that
61+
# zero in-degree means "ready to run". This is the same orientation the
62+
# previous networkx-based implementation reached after calling reverse().
63+
for ancestor_ in ancestors(dag, signature) & task_signatures:
64+
task_dag.add_edge(ancestor_, signature)
65+
66+
return cls(dag=task_dag, priorities=priorities)
67+
68+
@classmethod
69+
def from_dag_and_sorter(
70+
cls, dag: DiGraph, sorter: SimpleScheduler
71+
) -> SimpleScheduler:
72+
"""Instantiate a sorter from another sorter and a DAG."""
73+
new_sorter = cls.from_dag(dag)
74+
new_sorter.done(*sorter._nodes_done)
75+
new_sorter._nodes_processing = sorter._nodes_processing
76+
return new_sorter
77+
78+
@staticmethod
79+
def check_dag(dag: DiGraph) -> None:
80+
if not dag.is_directed():
81+
msg = "Only directed graphs have a topological order."
82+
raise ValueError(msg)
83+
84+
try:
85+
find_cycle(dag)
86+
except NoCycleError:
87+
pass
88+
else:
89+
msg = "The DAG contains cycles."
90+
raise ValueError(msg)
91+
92+
def get_ready(self, n: int = 1) -> list[str]:
93+
"""Get up to ``n`` tasks which are ready."""
94+
if not isinstance(n, int) or n < 1:
95+
msg = "'n' must be an integer greater or equal than 1."
96+
raise ValueError(msg)
97+
98+
ready_nodes = {
99+
v for v, d in self.dag.in_degree() if d == 0
100+
} - self._nodes_processing
101+
prioritized_nodes = sorted(
102+
ready_nodes, key=lambda x: self.priorities.get(x, 0)
103+
)[-n:]
104+
105+
self._nodes_processing.update(prioritized_nodes)
106+
107+
return prioritized_nodes
108+
109+
def is_active(self) -> bool:
110+
"""Indicate whether there are still tasks left."""
111+
return bool(self.dag.nodes)
112+
113+
def done(self, *nodes: str) -> None:
114+
"""Mark some tasks as done."""
115+
self._nodes_processing = self._nodes_processing - set(nodes)
116+
self.dag.remove_nodes_from(nodes)
117+
self._nodes_done.update(nodes)
118+
119+
def rebuild(self, dag: DiGraph) -> SimpleScheduler:
120+
"""Rebuild the scheduler from an updated DAG while preserving state."""
121+
return self.from_dag_and_sorter(dag, self)
122+
123+
124+
TopologicalSorter = SimpleScheduler
125+
126+
127+
def _extract_priorities_from_tasks(tasks: list[PTask]) -> dict[str, int]:
128+
"""Extract priorities from tasks.
129+
130+
Priorities are set via the [pytask.mark.try_first][] and [pytask.mark.try_last][]
131+
markers. We recode these markers to numeric values to sort all available by
132+
priorities. ``try_first`` is assigned the highest value such that it has the
133+
rightmost position in the list. Then, we can simply call `list.pop` on the
134+
list which is far more efficient than ``list.pop(0)``.
135+
136+
"""
137+
priorities = {
138+
task.signature: {
139+
"try_first": has_mark(task, "try_first"),
140+
"try_last": has_mark(task, "try_last"),
141+
}
142+
for task in tasks
143+
}
144+
145+
numeric_mapping = {(True, False): 1, (False, False): 0, (False, True): -1}
146+
return {
147+
name: numeric_mapping[(p["try_first"], p["try_last"])]
148+
for name, p in priorities.items()
149+
}

src/_pytask/session.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from _pytask.reports import CollectionReport
1818
from _pytask.reports import DagReport
1919
from _pytask.reports import ExecutionReport
20+
from _pytask.scheduler import PScheduler
2021
from _pytask.warnings_utils import WarningReport
2122

2223

@@ -64,7 +65,7 @@ class Session:
6465
execution_end: float = float("inf")
6566

6667
n_tasks_failed: int = 0
67-
scheduler: Any = None
68+
scheduler: PScheduler | None = None
6869
should_stop: bool = False
6970
warnings: list[WarningReport] = field(default_factory=list)
7071

0 commit comments

Comments
 (0)