Skip to content
Merged
1 change: 0 additions & 1 deletion src/_pytask/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ def _format_cycles(dag: DAG, cycles: list[tuple[str, str]]) -> str:
elif isinstance(node, (PNode, PProvisionalNode)):
short_name = node.name
lines.extend((short_name, " " + ARROW_DOWN_ICON))
# Join while removing last arrow.
return "\n".join(lines[:-1])


Expand Down
2 changes: 0 additions & 2 deletions src/_pytask/dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,7 @@ def build_dag(raw_config: dict[str, Any]) -> nx.DiGraph:
pm = get_plugin_manager()
storage.store(pm)

# If someone called the programmatic interface, we need to do some parsing.
if "command" not in raw_config:
# Add defaults from cli.
from _pytask.cli import DEFAULTS_FROM_CLI # noqa: PLC0415

raw_config = normalize_programmatic_config(
Expand Down
4 changes: 1 addition & 3 deletions src/_pytask/dag_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,7 @@ def _traverse(
return visited


def find_cycle(
dag: DAG,
) -> list[tuple[str, str]]:
def find_cycle(dag: DAG) -> list[tuple[str, str]]:
"""Find one cycle in the graph."""
visited: set[str] = set()
active: set[str] = set()
Expand Down
133 changes: 11 additions & 122 deletions src/_pytask/dag_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,25 @@
from __future__ import annotations

import itertools
from dataclasses import dataclass
from dataclasses import field
from typing import TYPE_CHECKING

from _pytask.dag_graph import DAG
from _pytask.dag_graph import NoCycleError
from _pytask.dag_graph import find_cycle
from _pytask.mark_utils import has_mark
from _pytask.node_protocols import PTask

if TYPE_CHECKING:
from collections.abc import Generator
from collections.abc import Iterable

from _pytask.dag_graph import DAG


__all__ = [
"descending_tasks",
"node_and_neighbors",
"preceding_tasks",
"task_and_descending_tasks",
"task_and_preceding_tasks",
]


def descending_tasks(task_name: str, dag: DAG) -> Generator[str, None, None]:
"""Yield only descending tasks."""
Expand Down Expand Up @@ -55,119 +60,3 @@ def node_and_neighbors(dag: DAG, node: str) -> Iterable[str]:

"""
return itertools.chain(dag.predecessors(node), [node], dag.successors(node))


@dataclass
class TopologicalSorter:
    """The topological sorter class.

    This class allows to perform a topological sort on the tasks of a DAG and to
    hand out tasks that are ready for execution.

    Attributes
    ----------
    dag
        Not the full DAG, but a reduced version that only considers tasks.
    priorities
        A dictionary of task names to a priority value. 1 for try first, 0 for the
        default priority and, -1 for try last.
    _nodes_processing
        Names returned by :meth:`get_ready` that have not been passed to
        :meth:`done` yet.
    _nodes_done
        Names that have been passed to :meth:`done`.

    """

    dag: DAG
    priorities: dict[str, int] = field(default_factory=dict)
    _nodes_processing: set[str] = field(default_factory=set)
    _nodes_done: set[str] = field(default_factory=set)

    @classmethod
    def from_dag(cls, dag: DAG) -> TopologicalSorter:
        """Instantiate from a DAG.

        Raises
        ------
        ValueError
            If ``dag`` contains a cycle (see :meth:`check_dag`).

        """
        cls.check_dag(dag)

        tasks = [node for node in dag.nodes.values() if isinstance(node, PTask)]
        priorities = _extract_priorities_from_tasks(tasks)

        task_signatures = {task.signature for task in tasks}
        task_dag = DAG()
        for signature in task_signatures:
            task_dag.add_node(signature, dag.nodes[signature])
        for signature in task_signatures:
            # The scheduler graph uses edges from predecessor -> successor so that
            # zero in-degree means "ready to run". This is the same orientation the
            # previous networkx-based implementation reached after calling reverse().
            for ancestor_ in dag.ancestors(signature) & task_signatures:
                task_dag.add_edge(ancestor_, signature)

        return cls(dag=task_dag, priorities=priorities)

    @classmethod
    def from_dag_and_sorter(
        cls, dag: DAG, sorter: TopologicalSorter
    ) -> TopologicalSorter:
        """Instantiate a sorter from another sorter and a DAG.

        The new sorter is built from ``dag``; every node the old ``sorter`` had
        marked as done is marked as done again, and the set of nodes currently
        being processed is carried over.

        """
        new_sorter = cls.from_dag(dag)
        new_sorter.done(*sorter._nodes_done)
        # Copy instead of aliasing: ``get_ready`` mutates this set in place, so
        # sharing it would let the new sorter silently change the old one.
        new_sorter._nodes_processing = sorter._nodes_processing.copy()
        return new_sorter

    @staticmethod
    def check_dag(dag: DAG) -> None:
        """Raise a :class:`ValueError` if a cycle is found in ``dag``."""
        try:
            find_cycle(dag)
        except NoCycleError:
            # ``find_cycle`` signals the absence of a cycle with this exception.
            pass
        else:
            msg = "The DAG contains cycles."
            raise ValueError(msg)

    def get_ready(self, n: int = 1) -> list[str]:
        """Get up to ``n`` tasks which are ready.

        Ready tasks are nodes with an in-degree of zero that are not already being
        processed. The result is sorted by ascending priority, so the task with
        the highest priority is the last element. Returned tasks are recorded as
        being processed.

        Raises
        ------
        ValueError
            If ``n`` is not an integer or smaller than 1.

        """
        if not isinstance(n, int) or n < 1:
            msg = "'n' must be an integer greater or equal than 1."
            raise ValueError(msg)

        ready_nodes = {
            v for v, d in self.dag.in_degree() if d == 0
        } - self._nodes_processing
        # Keep the ``n`` entries with the highest priority (end of the sorted list).
        prioritized_nodes = sorted(
            ready_nodes, key=lambda x: self.priorities.get(x, 0)
        )[-n:]

        self._nodes_processing.update(prioritized_nodes)

        return prioritized_nodes

    def is_active(self) -> bool:
        """Indicate whether there are still tasks left."""
        return bool(self.dag.nodes)

    def done(self, *nodes: str) -> None:
        """Mark some tasks as done.

        The nodes are dropped from the processing set, removed from the graph and
        remembered as done.

        """
        self._nodes_processing = self._nodes_processing - set(nodes)
        self.dag.remove_nodes_from(nodes)
        self._nodes_done.update(nodes)


def _extract_priorities_from_tasks(tasks: list[PTask]) -> dict[str, int]:
    """Extract priorities from tasks.

    Priorities are set via the [pytask.mark.try_first][] and [pytask.mark.try_last][]
    markers. We recode these markers to numeric values to sort all available by
    priorities. ``try_first`` is assigned the highest value such that it has the
    rightmost position in the list. Then, we can simply call `list.pop` on the
    list which is far more efficient than ``list.pop(0)``.

    Parameters
    ----------
    tasks
        The tasks whose markers are inspected.

    Returns
    -------
    dict[str, int]
        A mapping of task signature to ``1`` (``try_first``), ``0`` (no marker) or
        ``-1`` (``try_last``).

    Raises
    ------
    ValueError
        If a task carries both the ``try_first`` and the ``try_last`` marker. The
        combination has no numeric value and previously surfaced as an opaque
        ``KeyError: (True, True)``.

    """
    # Recode to numeric values for sorting.
    numeric_mapping = {(True, False): 1, (False, False): 0, (False, True): -1}

    priorities: dict[str, int] = {}
    tasks_with_mixed_priorities: list[str] = []
    for task in tasks:
        try_first = has_mark(task, "try_first")
        try_last = has_mark(task, "try_last")
        if try_first and try_last:
            # Collect every offender so the user sees all of them at once.
            tasks_with_mixed_priorities.append(
                str(getattr(task, "name", task.signature))
            )
            continue
        priorities[task.signature] = numeric_mapping[(try_first, try_last)]

    if tasks_with_mixed_priorities:
        offenders = ", ".join(tasks_with_mixed_priorities)
        msg = (
            "'try_first' and 'try_last' cannot be applied on the same task. See the "
            f"following tasks for errors: {offenders}"
        )
        raise ValueError(msg)

    return priorities
6 changes: 3 additions & 3 deletions src/_pytask/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from _pytask.console import format_node_name
from _pytask.console import format_strings_as_flat_tree
from _pytask.console import unify_styles
from _pytask.dag_utils import TopologicalSorter
from _pytask.dag_utils import descending_tasks
from _pytask.dag_utils import node_and_neighbors
from _pytask.exceptions import ExecutionError
Expand All @@ -43,6 +42,7 @@
from _pytask.pluginmanager import hookimpl
from _pytask.provisional_utils import collect_provisional_products
from _pytask.reports import ExecutionReport
from _pytask.scheduler import SimpleScheduler
from _pytask.state import get_node_change_info
from _pytask.state import has_node_changed
from _pytask.state import update_states
Expand All @@ -67,7 +67,7 @@ def pytask_post_parse(config: dict[str, Any]) -> None:
def pytask_execute(session: Session) -> None:
"""Execute tasks."""
session.hook.pytask_execute_log_start(session=session)
session.scheduler = TopologicalSorter.from_dag(session.dag)
session.scheduler = SimpleScheduler.from_dag(session.dag)
session.hook.pytask_execute_build(session=session)
session.hook.pytask_execute_log_end(
session=session, reports=session.execution_reports
Expand All @@ -86,7 +86,7 @@ def pytask_execute_log_start(session: Session) -> None:
@hookimpl
def pytask_execute_build(session: Session) -> bool | None:
"""Execute tasks."""
if isinstance(session.scheduler, TopologicalSorter):
if session.scheduler is not None:
while session.scheduler.is_active():
task_name = session.scheduler.get_ready()[0]
task = session.dag.nodes[task_name]
Expand Down
6 changes: 2 additions & 4 deletions src/_pytask/provisional_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from _pytask.collect_utils import collect_dependency
from _pytask.dag import create_dag_from_session
from _pytask.dag_utils import TopologicalSorter
from _pytask.models import NodeInfo
from _pytask.node_protocols import PNode
from _pytask.node_protocols import PProvisionalNode
Expand Down Expand Up @@ -78,9 +77,8 @@ def recreate_dag(session: Session, task: PTask) -> None:
"""
try:
session.dag = create_dag_from_session(session)
session.scheduler = TopologicalSorter.from_dag_and_sorter(
session.dag, session.scheduler
)
if session.scheduler is not None:
session.scheduler = session.scheduler.rebuild(session.dag)

except Exception: # noqa: BLE001
report = ExecutionReport.from_task_and_exception(task, sys.exc_info())
Expand Down
129 changes: 129 additions & 0 deletions src/_pytask/scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""Contains scheduler protocols and implementations."""

from __future__ import annotations

from dataclasses import dataclass
from dataclasses import field
from typing import Protocol

from _pytask.dag_graph import DAG
from _pytask.dag_graph import NoCycleError
from _pytask.dag_graph import find_cycle
from _pytask.mark_utils import has_mark
from _pytask.node_protocols import PTask


class PScheduler(Protocol):
    """Structural interface every scheduler that dispatches ready tasks provides."""

    def get_ready(self, n: int = 1) -> list[str]:
        """Return at most ``n`` names of tasks that are ready."""
        ...

    def is_active(self) -> bool:
        """Tell whether tasks are still left to be scheduled."""
        ...

    def done(self, *nodes: str) -> None:
        """Record the given tasks as done."""
        ...

    def rebuild(self, dag: DAG) -> PScheduler:
        """Create a scheduler for an updated DAG that keeps the current state."""
        ...


@dataclass
class SimpleScheduler:
    """Default scheduler that hands out tasks in topological order.

    Attributes
    ----------
    dag
        Graph whose nodes with an in-degree of zero are the tasks ready to run.
    priorities
        Mapping of task name to sort weight. Names missing from the mapping are
        treated as ``0``.
    _nodes_processing
        Names returned by ``get_ready`` that were not passed to ``done`` yet.
    _nodes_done
        Names that were passed to ``done``.

    """

    dag: DAG
    priorities: dict[str, int] = field(default_factory=dict)
    _nodes_processing: set[str] = field(default_factory=set)
    _nodes_done: set[str] = field(default_factory=set)

    @classmethod
    def from_dag(cls, dag: DAG) -> SimpleScheduler:
        """Build a scheduler from the task nodes of ``dag``."""
        cls.check_dag(dag)

        task_nodes = [
            candidate for candidate in dag.nodes.values() if isinstance(candidate, PTask)
        ]
        weights = _extract_priorities_from_tasks(task_nodes)
        names = {task_node.signature for task_node in task_nodes}

        reduced = DAG()
        for name in names:
            reduced.add_node(name, dag.nodes[name])
        for name in names:
            # Edges point from predecessor to successor, hence a node without
            # incoming edges is "ready to run". The earlier networkx-based code
            # arrived at the same orientation by calling reverse().
            upstream_tasks = dag.ancestors(name) & names
            for upstream in upstream_tasks:
                reduced.add_edge(upstream, name)

        return cls(dag=reduced, priorities=weights)

    @staticmethod
    def check_dag(dag: DAG) -> None:
        """Raise a ``ValueError`` when ``find_cycle`` finds a cycle in ``dag``."""
        try:
            find_cycle(dag)
        except NoCycleError:
            # No cycle found, nothing to complain about.
            return
        msg = "The DAG contains cycles."
        raise ValueError(msg)

    def get_ready(self, n: int = 1) -> list[str]:
        """Return at most ``n`` ready tasks, sorted by ascending priority."""
        if not (isinstance(n, int) and n >= 1):
            msg = "'n' must be an integer greater or equal than 1."
            raise ValueError(msg)

        def weight(name: str) -> int:
            return self.priorities.get(name, 0)

        without_predecessors = {
            name for name, degree in self.dag.in_degree() if degree == 0
        }
        candidates = without_predecessors - self._nodes_processing
        ranked = sorted(candidates, key=weight)
        # The highest priorities sit at the end of the ranked list.
        selected = ranked[-n:]

        self._nodes_processing.update(selected)
        return selected

    def is_active(self) -> bool:
        """Tell whether the graph still holds tasks."""
        remaining = self.dag.nodes
        return bool(remaining)

    def done(self, *nodes: str) -> None:
        """Record the given tasks as done and drop them from the graph."""
        finished = set(nodes)
        self._nodes_processing = self._nodes_processing - finished
        self.dag.remove_nodes_from(nodes)
        self._nodes_done.update(nodes)

    def rebuild(self, dag: DAG) -> SimpleScheduler:
        """Create a scheduler for an updated DAG that keeps the current state."""
        rebuilt = type(self).from_dag(dag)
        rebuilt.done(*self._nodes_done)
        rebuilt._nodes_processing = self._nodes_processing.copy()
        return rebuilt


def _extract_priorities_from_tasks(tasks: list[PTask]) -> dict[str, int]:
    """Extract priorities from tasks.

    Priorities are set via the [pytask.mark.try_first][] and [pytask.mark.try_last][]
    markers. We recode these markers to numeric values to sort all available by
    priorities. ``try_first`` is assigned the highest value such that it has the
    rightmost position in the list. Then, we can simply call `list.pop` on the
    list which is far more efficient than ``list.pop(0)``.

    Parameters
    ----------
    tasks
        The tasks whose markers are inspected.

    Returns
    -------
    dict[str, int]
        A mapping of task signature to ``1`` (``try_first``), ``0`` (no marker) or
        ``-1`` (``try_last``).

    Raises
    ------
    ValueError
        If a task carries both the ``try_first`` and the ``try_last`` marker. The
        combination has no numeric value and previously surfaced as an opaque
        ``KeyError: (True, True)``.

    """
    numeric_mapping = {(True, False): 1, (False, False): 0, (False, True): -1}

    priorities: dict[str, int] = {}
    tasks_with_mixed_priorities: list[str] = []
    for task in tasks:
        try_first = has_mark(task, "try_first")
        try_last = has_mark(task, "try_last")
        if try_first and try_last:
            # Collect every offender so the user sees all of them at once.
            tasks_with_mixed_priorities.append(
                str(getattr(task, "name", task.signature))
            )
            continue
        priorities[task.signature] = numeric_mapping[(try_first, try_last)]

    if tasks_with_mixed_priorities:
        offenders = ", ".join(tasks_with_mixed_priorities)
        msg = (
            "'try_first' and 'try_last' cannot be applied on the same task. See the "
            f"following tasks for errors: {offenders}"
        )
        raise ValueError(msg)

    return priorities
3 changes: 2 additions & 1 deletion src/_pytask/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from _pytask.reports import CollectionReport
from _pytask.reports import DagReport
from _pytask.reports import ExecutionReport
from _pytask.scheduler import PScheduler
from _pytask.warnings_utils import WarningReport


Expand Down Expand Up @@ -64,7 +65,7 @@ class Session:
execution_end: float = float("inf")

n_tasks_failed: int = 0
scheduler: Any = None
scheduler: PScheduler | None = None
should_stop: bool = False
warnings: list[WarningReport] = field(default_factory=list)

Expand Down
Loading
Loading