diff --git a/src/_pytask/dag.py b/src/_pytask/dag.py index a9a3f6cf..58a53aa6 100644 --- a/src/_pytask/dag.py +++ b/src/_pytask/dag.py @@ -152,7 +152,6 @@ def _format_cycles(dag: DAG, cycles: list[tuple[str, str]]) -> str: elif isinstance(node, (PNode, PProvisionalNode)): short_name = node.name lines.extend((short_name, " " + ARROW_DOWN_ICON)) - # Join while removing last arrow. return "\n".join(lines[:-1]) diff --git a/src/_pytask/dag_command.py b/src/_pytask/dag_command.py index 7ed12e32..b83c654a 100644 --- a/src/_pytask/dag_command.py +++ b/src/_pytask/dag_command.py @@ -147,9 +147,7 @@ def build_dag(raw_config: dict[str, Any]) -> nx.DiGraph: pm = get_plugin_manager() storage.store(pm) - # If someone called the programmatic interface, we need to do some parsing. if "command" not in raw_config: - # Add defaults from cli. from _pytask.cli import DEFAULTS_FROM_CLI # noqa: PLC0415 raw_config = normalize_programmatic_config( diff --git a/src/_pytask/dag_graph.py b/src/_pytask/dag_graph.py index 0e3df198..ce4dafef 100644 --- a/src/_pytask/dag_graph.py +++ b/src/_pytask/dag_graph.py @@ -127,9 +127,7 @@ def _traverse( return visited -def find_cycle( - dag: DAG, -) -> list[tuple[str, str]]: +def find_cycle(dag: DAG) -> list[tuple[str, str]]: """Find one cycle in the graph.""" visited: set[str] = set() active: set[str] = set() diff --git a/src/_pytask/dag_utils.py b/src/_pytask/dag_utils.py index aea91dbf..ca1c6664 100644 --- a/src/_pytask/dag_utils.py +++ b/src/_pytask/dag_utils.py @@ -3,20 +3,25 @@ from __future__ import annotations import itertools -from dataclasses import dataclass -from dataclasses import field from typing import TYPE_CHECKING -from _pytask.dag_graph import DAG -from _pytask.dag_graph import NoCycleError -from _pytask.dag_graph import find_cycle -from _pytask.mark_utils import has_mark from _pytask.node_protocols import PTask if TYPE_CHECKING: from collections.abc import Generator from collections.abc import Iterable + from _pytask.dag_graph import DAG + + +__all__ = [ + "descending_tasks", + "node_and_neighbors", + "preceding_tasks", + "task_and_descending_tasks", + "task_and_preceding_tasks", +] + def descending_tasks(task_name: str, dag: DAG) -> Generator[str, None, None]: """Yield only descending tasks.""" @@ -55,119 +60,3 @@ def node_and_neighbors(dag: DAG, node: str) -> Iterable[str]: """ return itertools.chain(dag.predecessors(node), [node], dag.successors(node)) - - -@dataclass -class TopologicalSorter: - """The topological sorter class. - - This class allows to perform a topological sort# - - Attributes - ---------- - dag - Not the full DAG, but a reduced version that only considers tasks. - priorities - A dictionary of task names to a priority value. 1 for try first, 0 for the - default priority and, -1 for try last. - - """ - - dag: DAG - priorities: dict[str, int] = field(default_factory=dict) - _nodes_processing: set[str] = field(default_factory=set) - _nodes_done: set[str] = field(default_factory=set) - - @classmethod - def from_dag(cls, dag: DAG) -> TopologicalSorter: - """Instantiate from a DAG.""" - cls.check_dag(dag) - - tasks = [node for node in dag.nodes.values() if isinstance(node, PTask)] - priorities = _extract_priorities_from_tasks(tasks) - - task_signatures = {task.signature for task in tasks} - task_dag = DAG() - for signature in task_signatures: - task_dag.add_node(signature, dag.nodes[signature]) - for signature in task_signatures: - # The scheduler graph uses edges from predecessor -> successor so that - # zero in-degree means "ready to run". This is the same orientation the - # previous networkx-based implementation reached after calling reverse(). - for ancestor_ in dag.ancestors(signature) & task_signatures: - task_dag.add_edge(ancestor_, signature) - - return cls(dag=task_dag, priorities=priorities) - - @classmethod - def from_dag_and_sorter( - cls, dag: DAG, sorter: TopologicalSorter - ) -> TopologicalSorter: - """Instantiate a sorter from another sorter and a DAG.""" - new_sorter = cls.from_dag(dag) - new_sorter.done(*sorter._nodes_done) - new_sorter._nodes_processing = sorter._nodes_processing - return new_sorter - - @staticmethod - def check_dag(dag: DAG) -> None: - try: - find_cycle(dag) - except NoCycleError: - pass - else: - msg = "The DAG contains cycles." - raise ValueError(msg) - - def get_ready(self, n: int = 1) -> list[str]: - """Get up to ``n`` tasks which are ready.""" - if not isinstance(n, int) or n < 1: - msg = "'n' must be an integer greater or equal than 1." - raise ValueError(msg) - - ready_nodes = { - v for v, d in self.dag.in_degree() if d == 0 - } - self._nodes_processing - prioritized_nodes = sorted( - ready_nodes, key=lambda x: self.priorities.get(x, 0) - )[-n:] - - self._nodes_processing.update(prioritized_nodes) - - return prioritized_nodes - - def is_active(self) -> bool: - """Indicate whether there are still tasks left.""" - return bool(self.dag.nodes) - - def done(self, *nodes: str) -> None: - """Mark some tasks as done.""" - self._nodes_processing = self._nodes_processing - set(nodes) - self.dag.remove_nodes_from(nodes) - self._nodes_done.update(nodes) - - -def _extract_priorities_from_tasks(tasks: list[PTask]) -> dict[str, int]: - """Extract priorities from tasks. - - Priorities are set via the [pytask.mark.try_first][] and [pytask.mark.try_last][] - markers. We recode these markers to numeric values to sort all available by - priorities. ``try_first`` is assigned the highest value such that it has the - rightmost position in the list. Then, we can simply call `list.pop` on the - list which is far more efficient than ``list.pop(0)``. - - """ - priorities = { - task.signature: { - "try_first": has_mark(task, "try_first"), - "try_last": has_mark(task, "try_last"), - } - for task in tasks - } - - # Recode to numeric values for sorting. - numeric_mapping = {(True, False): 1, (False, False): 0, (False, True): -1} - return { - name: numeric_mapping[(p["try_first"], p["try_last"])] - for name, p in priorities.items() - } diff --git a/src/_pytask/execute.py b/src/_pytask/execute.py index 7f9c43d3..923049e0 100644 --- a/src/_pytask/execute.py +++ b/src/_pytask/execute.py @@ -17,7 +17,6 @@ from _pytask.console import format_node_name from _pytask.console import format_strings_as_flat_tree from _pytask.console import unify_styles -from _pytask.dag_utils import TopologicalSorter from _pytask.dag_utils import descending_tasks from _pytask.dag_utils import node_and_neighbors from _pytask.exceptions import ExecutionError @@ -43,6 +42,7 @@ from _pytask.pluginmanager import hookimpl from _pytask.provisional_utils import collect_provisional_products from _pytask.reports import ExecutionReport +from _pytask.scheduler import SimpleScheduler from _pytask.state import get_node_change_info from _pytask.state import has_node_changed from _pytask.state import update_states @@ -67,7 +67,7 @@ def pytask_post_parse(config: dict[str, Any]) -> None: def pytask_execute(session: Session) -> None: """Execute tasks.""" session.hook.pytask_execute_log_start(session=session) - session.scheduler = TopologicalSorter.from_dag(session.dag) + session.scheduler = SimpleScheduler.from_dag(session.dag) session.hook.pytask_execute_build(session=session) session.hook.pytask_execute_log_end( session=session, reports=session.execution_reports @@ -86,7 +86,7 @@ def pytask_execute_log_start(session: Session) -> None: @hookimpl def pytask_execute_build(session: Session) -> bool | None: """Execute tasks.""" - if isinstance(session.scheduler, TopologicalSorter): + if session.scheduler is not None: while session.scheduler.is_active(): task_name = session.scheduler.get_ready()[0] task = session.dag.nodes[task_name] diff --git a/src/_pytask/provisional_utils.py b/src/_pytask/provisional_utils.py index ad01e790..604da3a8 100644 --- a/src/_pytask/provisional_utils.py +++ b/src/_pytask/provisional_utils.py @@ -9,7 +9,6 @@ from _pytask.collect_utils import collect_dependency from _pytask.dag import create_dag_from_session -from _pytask.dag_utils import TopologicalSorter from _pytask.models import NodeInfo from _pytask.node_protocols import PNode from _pytask.node_protocols import PProvisionalNode @@ -78,9 +77,8 @@ def recreate_dag(session: Session, task: PTask) -> None: """ try: session.dag = create_dag_from_session(session) - session.scheduler = TopologicalSorter.from_dag_and_sorter( - session.dag, session.scheduler - ) + if session.scheduler is not None: + session.scheduler = session.scheduler.rebuild(session.dag) except Exception: # noqa: BLE001 report = ExecutionReport.from_task_and_exception(task, sys.exc_info()) diff --git a/src/_pytask/scheduler.py b/src/_pytask/scheduler.py new file mode 100644 index 00000000..9ef81a6b --- /dev/null +++ b/src/_pytask/scheduler.py @@ -0,0 +1,129 @@ +"""Contains scheduler protocols and implementations.""" + +from __future__ import annotations + +from dataclasses import dataclass +from dataclasses import field +from typing import Protocol + +from _pytask.dag_graph import DAG +from _pytask.dag_graph import NoCycleError +from _pytask.dag_graph import find_cycle +from _pytask.mark_utils import has_mark +from _pytask.node_protocols import PTask + + +class PScheduler(Protocol): + """Protocol for schedulers that dispatch ready tasks.""" + + def get_ready(self, n: int = 1) -> list[str]: + """Get up to ``n`` tasks which are ready.""" + + def is_active(self) -> bool: + """Indicate whether there are still tasks left.""" + + def done(self, *nodes: str) -> None: + """Mark some tasks as done.""" + + def rebuild(self, dag: DAG) -> PScheduler: + """Rebuild the scheduler from an updated DAG while preserving state.""" + + +@dataclass +class SimpleScheduler: + """The default scheduler based on topological sorting.""" + + dag: DAG + priorities: dict[str, int] = field(default_factory=dict) + _nodes_processing: set[str] = field(default_factory=set) + _nodes_done: set[str] = field(default_factory=set) + + @classmethod + def from_dag(cls, dag: DAG) -> SimpleScheduler: + """Instantiate from a DAG.""" + cls.check_dag(dag) + + tasks = [node for node in dag.nodes.values() if isinstance(node, PTask)] + priorities = _extract_priorities_from_tasks(tasks) + + task_signatures = {task.signature for task in tasks} + task_dag = DAG() + for signature in task_signatures: + task_dag.add_node(signature, dag.nodes[signature]) + for signature in task_signatures: + # The scheduler graph uses edges from predecessor -> successor so that + # zero in-degree means "ready to run". This is the same orientation the + # previous networkx-based implementation reached after calling reverse(). + for ancestor_ in dag.ancestors(signature) & task_signatures: + task_dag.add_edge(ancestor_, signature) + + return cls(dag=task_dag, priorities=priorities) + + @staticmethod + def check_dag(dag: DAG) -> None: + try: + find_cycle(dag) + except NoCycleError: + pass + else: + msg = "The DAG contains cycles." + raise ValueError(msg) + + def get_ready(self, n: int = 1) -> list[str]: + """Get up to ``n`` tasks which are ready.""" + if not isinstance(n, int) or n < 1: + msg = "'n' must be an integer greater or equal than 1." + raise ValueError(msg) + + ready_nodes = { + v for v, d in self.dag.in_degree() if d == 0 + } - self._nodes_processing + prioritized_nodes = sorted( + ready_nodes, key=lambda x: self.priorities.get(x, 0) + )[-n:] + + self._nodes_processing.update(prioritized_nodes) + + return prioritized_nodes + + def is_active(self) -> bool: + """Indicate whether there are still tasks left.""" + return bool(self.dag.nodes) + + def done(self, *nodes: str) -> None: + """Mark some tasks as done.""" + self._nodes_processing = self._nodes_processing - set(nodes) + self.dag.remove_nodes_from(nodes) + self._nodes_done.update(nodes) + + def rebuild(self, dag: DAG) -> SimpleScheduler: + """Rebuild the scheduler from an updated DAG while preserving state.""" + new_scheduler = type(self).from_dag(dag) + new_scheduler.done(*self._nodes_done) + new_scheduler._nodes_processing = self._nodes_processing.copy() + return new_scheduler + + +def _extract_priorities_from_tasks(tasks: list[PTask]) -> dict[str, int]: + """Extract priorities from tasks. + + Priorities are set via the [pytask.mark.try_first][] and [pytask.mark.try_last][] + markers. We recode these markers to numeric values to sort all available by + priorities. ``try_first`` is assigned the highest value such that it has the + rightmost position in the list. Then, we can simply call `list.pop` on the + list which is far more efficient than ``list.pop(0)``. + + """ + priorities = { + task.signature: { + "try_first": has_mark(task, "try_first"), + "try_last": has_mark(task, "try_last"), + } + for task in tasks + } + + numeric_mapping = {(True, False): 1, (False, False): 0, (False, True): -1} + return { + name: numeric_mapping[(p["try_first"], p["try_last"])] + for name, p in priorities.items() + } diff --git a/src/_pytask/session.py b/src/_pytask/session.py index 5c9bbf8b..4a008cc2 100644 --- a/src/_pytask/session.py +++ b/src/_pytask/session.py @@ -17,6 +17,7 @@ from _pytask.reports import CollectionReport from _pytask.reports import DagReport from _pytask.reports import ExecutionReport + from _pytask.scheduler import PScheduler from _pytask.warnings_utils import WarningReport @@ -64,7 +65,7 @@ class Session: execution_end: float = float("inf") n_tasks_failed: int = 0 - scheduler: Any = None + scheduler: PScheduler | None = None should_stop: bool = False warnings: list[WarningReport] = field(default_factory=list) diff --git a/tests/test_dag_utils.py b/tests/test_dag_utils.py index 32effef0..dd8af9b6 100644 --- a/tests/test_dag_utils.py +++ b/tests/test_dag_utils.py @@ -1,16 +1,14 @@ from __future__ import annotations -from contextlib import ExitStack as does_not_raise # noqa: N813 from pathlib import Path import pytest from _pytask.dag_graph import DAG -from _pytask.dag_utils import TopologicalSorter -from _pytask.dag_utils import _extract_priorities_from_tasks from _pytask.dag_utils import descending_tasks from _pytask.dag_utils import node_and_neighbors from _pytask.dag_utils import task_and_descending_tasks +from _pytask.scheduler import SimpleScheduler from pytask import Mark from pytask import Task from tests.conftest import noop @@ -31,7 +29,7 @@ def dag(): def test_sort_tasks_topologically(dag): - sorter = TopologicalSorter.from_dag(dag) + sorter = SimpleScheduler.from_dag(dag) topo_ordering = [] while sorter.is_active(): task_name = sorter.get_ready()[0] @@ -71,70 +69,32 @@ def test_node_and_neighbors(dag): assert node_names == [f".::{j}" for j in range(i - 1, i + 2)] -@pytest.mark.parametrize( - ("tasks", "expectation", "expected"), - [ - pytest.param( - [ - Task( - base_name="1", - path=Path(), - function=None, # type: ignore[arg-type] - markers=[Mark("try_last", (), {})], - ) - ], - does_not_raise(), - {"c12d8d4f7e2e3128d27878d1fb3d8e3583e90e68000a13634dfbf21f4d1456f3": -1}, - id="test try_last", - ), - pytest.param( - [ - Task( - base_name="1", - path=Path(), - function=None, # type: ignore[arg-type] - markers=[Mark("try_first", (), {})], - ) - ], - does_not_raise(), - {"c12d8d4f7e2e3128d27878d1fb3d8e3583e90e68000a13634dfbf21f4d1456f3": 1}, - id="test try_first", - ), - pytest.param( - [Task(base_name="1", path=Path(), function=None, markers=[])], # type: ignore[arg-type] - does_not_raise(), - {"c12d8d4f7e2e3128d27878d1fb3d8e3583e90e68000a13634dfbf21f4d1456f3": 0}, - id="test no priority", - ), - pytest.param( - [ - Task( - base_name="1", - path=Path(), - function=None, # type: ignore[arg-type] - markers=[Mark("try_first", (), {})], - ), - Task(base_name="2", path=Path(), function=None, markers=[]), # type: ignore[arg-type] - Task( - base_name="3", - path=Path(), - function=None, # type: ignore[arg-type] - markers=[Mark("try_last", (), {})], - ), - ], - does_not_raise(), - { - "c12d8d4f7e2e3128d27878d1fb3d8e3583e90e68000a13634dfbf21f4d1456f3": 1, - "c5f667e69824043475b1283ed8920e513cb4343ec7077f71a3d9f5972f5204b9": 0, - "dca295f815f54d282b33e8d9398cea4962d0dfbe881d2ab28fc48ff9e060203a": -1, - }, - ), - ], -) -def test_extract_priorities_from_tasks(tasks, expectation, expected): - with expectation: - result = _extract_priorities_from_tasks(tasks) - assert result == expected +def test_prioritize_try_first_and_try_last_tasks(): + dag = DAG() + first = Task( + base_name="first", + path=Path(), + function=noop, + markers=[Mark("try_first", (), {})], + ) + default = Task(base_name="default", path=Path(), function=noop) + last = Task( + base_name="last", + path=Path(), + function=noop, + markers=[Mark("try_last", (), {})], + ) + + for task in (first, default, last): + dag.add_node(task.signature, task) + + scheduler = SimpleScheduler.from_dag(dag) + + first_batch = scheduler.get_ready(3) + first_batch_names = [dag.nodes[sig].name for sig in first_batch] + + assert first_batch_names[-1] == ".::first" + assert first_batch_names[0] == ".::last" def test_raise_error_for_cycle_in_graph(dag): @@ -143,11 +103,11 @@ def test_raise_error_for_cycle_in_graph(dag): "55c6cef62d3e62d5f8fc65bb846e66d8d0d3ca60608c04f6f7b095ea073a7dcf", ) with pytest.raises(ValueError, match=r"The DAG contains cycles\."): - TopologicalSorter.from_dag(dag) + SimpleScheduler.from_dag(dag) def test_ask_for_invalid_number_of_ready_tasks(dag): - scheduler = TopologicalSorter.from_dag(dag) + scheduler = SimpleScheduler.from_dag(dag) with pytest.raises(ValueError, match="'n' must be"): scheduler.get_ready(0) @@ -155,7 +115,7 @@ def test_ask_for_invalid_number_of_ready_tasks(dag): def test_instantiate_sorter_from_other_sorter(dag): name_to_sig = {dag.nodes[sig].name: sig for sig in dag.nodes} - scheduler = TopologicalSorter.from_dag(dag) + scheduler = SimpleScheduler.from_dag(dag) for _ in range(2): task_name = scheduler.get_ready()[0] scheduler.done(task_name) @@ -165,7 +125,7 @@ def test_instantiate_sorter_from_other_sorter(dag): dag.add_node(task.signature, task) dag.add_edge(name_to_sig[".::4"], task.signature) - new_scheduler = TopologicalSorter.from_dag_and_sorter(dag, scheduler) + new_scheduler = scheduler.rebuild(dag) while new_scheduler.is_active(): task_name = new_scheduler.get_ready()[0] new_scheduler.done(task_name)