Skip to content

Commit 6812a12

Browse files
committed
Prefetch tasks for pending status
1 parent f8c3fab commit 6812a12

1 file changed

Lines changed: 109 additions & 25 deletions

File tree

src/pytask_parallel/execute.py

Lines changed: 109 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import queue
77
import sys
88
import time
9+
from collections import deque
910
from contextlib import ExitStack
1011
from multiprocessing import Manager
1112
from typing import TYPE_CHECKING
@@ -61,6 +62,9 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
6162
__tracebackhide__ = True
6263
reports = session.execution_reports
6364
running_tasks: dict[str, Future[Any]] = {}
65+
running_try_last: set[str] = set()
66+
queued_tasks: deque[str] = deque()
67+
queued_try_last_tasks: deque[str] = deque()
6468
sleeper = _Sleeper()
6569
debug_status = _is_debug_status_enabled()
6670

@@ -97,10 +101,26 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
97101
elif status_queue_factory == "simple":
98102
session.config["_status_queue"] = queue.SimpleQueue()
99103

104+
if live_execution:
105+
live_execution.initial_status = start_execution_state
106+
100107
i = 0
108+
prefetch_factor = (
109+
2
110+
if session.config["parallel_backend"]
111+
in (
112+
ParallelBackend.PROCESSES,
113+
ParallelBackend.LOKY,
114+
ParallelBackend.THREADS,
115+
)
116+
else 1
117+
)
118+
use_prefetch_queue = prefetch_factor > 1
101119
while session.scheduler.is_active():
102120
try:
103121
newly_collected_reports = []
122+
did_enqueue = False
123+
did_submit = False
104124

105125
# If there is any coiled function, the user probably wants to exploit
106126
# adaptive scaling. Thus, we need to submit all ready tasks.
@@ -110,42 +130,104 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
110130
if any_coiled_task:
111131
n_new_tasks = 10_000
112132
else:
113-
n_new_tasks = session.config["n_workers"] - len(running_tasks)
133+
if use_prefetch_queue:
134+
n_new_tasks = (
135+
session.config["n_workers"] * prefetch_factor
136+
) - (
137+
len(running_tasks)
138+
+ len(queued_tasks)
139+
+ len(queued_try_last_tasks)
140+
)
141+
else:
142+
n_new_tasks = session.config["n_workers"] - len(running_tasks)
114143

115144
ready_tasks = (
116145
list(session.scheduler.get_ready(n_new_tasks))
117146
if n_new_tasks >= 1
118147
else []
119148
)
120149

121-
for task_signature in ready_tasks:
122-
task = session.dag.nodes[task_signature]["task"]
123-
if debug_status:
124-
_log_status(
125-
"PENDING"
126-
if start_execution_state == TaskExecutionStatus.PENDING
127-
else "RUNNING",
128-
task_signature,
150+
if use_prefetch_queue:
151+
for task_signature in ready_tasks:
152+
task = session.dag.nodes[task_signature]["task"]
153+
if debug_status:
154+
_log_status("PENDING", task_signature)
155+
session.hook.pytask_execute_task_log_start(
156+
session=session,
157+
task=task,
158+
status=start_execution_state,
129159
)
130-
session.hook.pytask_execute_task_log_start(
131-
session=session, task=task, status=start_execution_state
132-
)
133-
try:
134-
session.hook.pytask_execute_task_setup(
135-
session=session, task=task
136-
)
137-
running_tasks[task_signature] = (
138-
session.hook.pytask_execute_task(session=session, task=task)
160+
if get_marks(task, "try_last"):
161+
queued_try_last_tasks.append(task_signature)
162+
else:
163+
queued_tasks.append(task_signature)
164+
did_enqueue = True
165+
166+
def _can_run_try_last() -> bool:
167+
return not (
168+
queued_tasks
169+
or (len(running_tasks) > len(running_try_last))
139170
)
140-
sleeper.reset()
141-
except Exception: # noqa: BLE001
142-
report = ExecutionReport.from_task_and_exception(
143-
task, sys.exc_info()
171+
172+
while len(running_tasks) < session.config["n_workers"]:
173+
if queued_tasks:
174+
task_signature = queued_tasks.popleft()
175+
elif queued_try_last_tasks and _can_run_try_last():
176+
task_signature = queued_try_last_tasks.popleft()
177+
else:
178+
break
179+
task = session.dag.nodes[task_signature]["task"]
180+
try:
181+
session.hook.pytask_execute_task_setup(
182+
session=session, task=task
183+
)
184+
running_tasks[task_signature] = (
185+
session.hook.pytask_execute_task(
186+
session=session, task=task
187+
)
188+
)
189+
if get_marks(task, "try_last"):
190+
running_try_last.add(task_signature)
191+
sleeper.reset()
192+
did_submit = True
193+
except Exception: # noqa: BLE001
194+
report = ExecutionReport.from_task_and_exception(
195+
task, sys.exc_info()
196+
)
197+
newly_collected_reports.append(report)
198+
session.scheduler.done(task_signature)
199+
else:
200+
for task_signature in ready_tasks:
201+
task = session.dag.nodes[task_signature]["task"]
202+
if debug_status:
203+
_log_status(
204+
"PENDING"
205+
if start_execution_state == TaskExecutionStatus.PENDING
206+
else "RUNNING",
207+
task_signature,
208+
)
209+
session.hook.pytask_execute_task_log_start(
210+
session=session, task=task, status=start_execution_state
144211
)
145-
newly_collected_reports.append(report)
146-
session.scheduler.done(task_signature)
212+
try:
213+
session.hook.pytask_execute_task_setup(
214+
session=session, task=task
215+
)
216+
running_tasks[task_signature] = (
217+
session.hook.pytask_execute_task(
218+
session=session, task=task
219+
)
220+
)
221+
sleeper.reset()
222+
did_submit = True
223+
except Exception: # noqa: BLE001
224+
report = ExecutionReport.from_task_and_exception(
225+
task, sys.exc_info()
226+
)
227+
newly_collected_reports.append(report)
228+
session.scheduler.done(task_signature)
147229

148-
if not ready_tasks:
230+
if not ready_tasks and not did_enqueue and not did_submit:
149231
sleeper.increment()
150232

151233
for task_signature in list(running_tasks):
@@ -173,6 +255,7 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
173255
)
174256
)
175257
running_tasks.pop(task_signature)
258+
running_try_last.discard(task_signature)
176259
session.scheduler.done(task_signature)
177260
else:
178261
task = session.dag.nodes[task_signature]["task"]
@@ -192,6 +275,7 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
192275
report = ExecutionReport.from_task(task)
193276

194277
running_tasks.pop(task_signature)
278+
running_try_last.discard(task_signature)
195279
newly_collected_reports.append(report)
196280
session.scheduler.done(task_signature)
197281

0 commit comments

Comments
 (0)