-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexecute.py
More file actions
478 lines (408 loc) · 19.1 KB
/
execute.py
File metadata and controls
478 lines (408 loc) · 19.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
"""Contains code relevant to the execution."""
from __future__ import annotations
import os
import queue
import sys
import time
from collections import deque
from contextlib import ExitStack
from multiprocessing import Manager
from typing import TYPE_CHECKING
from typing import Any
from typing import cast
import cloudpickle
from _pytask.node_protocols import PPathNode
from attrs import define
from attrs import field
from pytask import ExecutionReport
from pytask import PNode
from pytask import PTask
from pytask import PythonNode
from pytask import Session
from pytask import TaskExecutionStatus
from pytask import console
from pytask import get_marks
from pytask import hookimpl
from pytask.tree_util import PyTree
from pytask.tree_util import tree_map
from pytask.tree_util import tree_structure
from pytask_parallel.backends import ParallelBackend
from pytask_parallel.backends import WorkerType
from pytask_parallel.backends import registry
from pytask_parallel.typing import CarryOverPath
from pytask_parallel.typing import is_coiled_function
from pytask_parallel.utils import create_kwargs_for_task
from pytask_parallel.utils import get_module
from pytask_parallel.utils import parse_future_result
if TYPE_CHECKING:
from collections.abc import Callable
from concurrent.futures import Future
from multiprocessing.managers import SyncManager
from pytask_parallel.wrappers import WrapperResult
@hookimpl
def pytask_execute_build(session: Session) -> bool | None:  # noqa: C901, PLR0912, PLR0915
    """Execute tasks with a parallel backend.

    There are three phases while the scheduler has tasks which need to be executed.

    1. Take all ready tasks, set up their execution and submit them.
    2. For all tasks which are running, find those which have finished and turn them
       into a report.
    3. Process all reports and report the result on the command line.

    Returns ``True`` to signal that this hook handled the execution.
    """
    __tracebackhide__ = True
    reports = session.execution_reports
    # Futures of submitted tasks, keyed by task signature.
    running_tasks: dict[str, Future[Any]] = {}
    # Signatures of running tasks that carry the ``try_last`` marker.
    running_try_last: set[str] = set()
    # Local priority queues, used only when prefetching (see ``use_prefetch_queue``).
    queued_try_first_tasks: deque[str] = deque()
    queued_tasks: deque[str] = deque()
    queued_try_last_tasks: deque[str] = deque()
    sleeper = _Sleeper()
    debug_status = _is_debug_status_enabled()

    # Create a shared queue to differentiate between running and pending tasks for
    # some parallel backends.
    if session.config["parallel_backend"] in (
        ParallelBackend.PROCESSES,
        ParallelBackend.LOKY,
    ):
        # Process-based backends need a manager-backed queue that crosses
        # process boundaries.
        manager_cls: Callable[[], SyncManager] | type[ExitStack] = Manager
        start_execution_state = TaskExecutionStatus.PENDING
        status_queue_factory = "manager"
    elif session.config["parallel_backend"] == ParallelBackend.THREADS:
        # Threads share memory, so a plain in-process queue suffices.
        manager_cls = ExitStack
        start_execution_state = TaskExecutionStatus.PENDING
        status_queue_factory = "simple"
    else:
        # No status queue for other backends; submitted tasks are immediately
        # shown as running.
        manager_cls = ExitStack
        start_execution_state = TaskExecutionStatus.RUNNING
        status_queue_factory = None

    # Get the live execution manager from the registry if it exists.
    live_execution = session.config["pm"].get_plugin("live_execution")

    any_coiled_task = any(is_coiled_function(task) for task in session.tasks)

    # The executor can only be created after the collection to give users the
    # possibility to inject their own executors.
    session.config["_parallel_executor"] = registry.get_parallel_backend(
        session.config["parallel_backend"], n_workers=session.config["n_workers"]
    )

    with session.config["_parallel_executor"], manager_cls() as manager:
        if status_queue_factory == "manager":
            session.config["_status_queue"] = manager.Queue()  # type: ignore[union-attr]
        elif status_queue_factory == "simple":
            session.config["_status_queue"] = queue.SimpleQueue()

        if live_execution:
            live_execution.initial_status = start_execution_state

        # For local backends, prefetch more tasks than there are workers so that
        # workers do not starve while the main loop is busy.
        prefetch_factor = (
            2
            if session.config["parallel_backend"]
            in (
                ParallelBackend.PROCESSES,
                ParallelBackend.LOKY,
                ParallelBackend.THREADS,
            )
            else 1
        )
        use_prefetch_queue = prefetch_factor > 1

        def _can_run_try_last() -> bool:
            # A ``try_last`` task may only start when nothing else is queued and
            # only other ``try_last`` tasks are still running.
            return not (
                queued_try_first_tasks
                or queued_tasks
                or (len(running_tasks) > len(running_try_last))
            )

        while session.scheduler.is_active():
            try:
                newly_collected_reports = []
                did_enqueue = False
                did_submit = False

                # If there is any coiled function, the user probably wants to exploit
                # adaptive scaling. Thus, we need to submit all ready tasks.
                # Unfortunately, all submitted tasks are shown as running although some
                # are pending.
                if any_coiled_task:
                    n_new_tasks = 10_000
                elif use_prefetch_queue:
                    n_new_tasks = (session.config["n_workers"] * prefetch_factor) - (
                        len(running_tasks)
                        + len(queued_try_first_tasks)
                        + len(queued_tasks)
                        + len(queued_try_last_tasks)
                    )
                else:
                    n_new_tasks = session.config["n_workers"] - len(running_tasks)

                ready_tasks = (
                    list(session.scheduler.get_ready(n_new_tasks))
                    if n_new_tasks >= 1
                    else []
                )

                if use_prefetch_queue:
                    # Phase 1a: sort ready tasks into the priority queues.
                    for task_signature in ready_tasks:
                        task = session.dag.nodes[task_signature]["task"]
                        if debug_status:
                            _log_status("PENDING", task_signature)
                        session.hook.pytask_execute_task_log_start(
                            session=session,
                            task=task,
                            status=start_execution_state,
                        )
                        if get_marks(task, "try_first"):
                            queued_try_first_tasks.append(task_signature)
                        elif get_marks(task, "try_last"):
                            queued_try_last_tasks.append(task_signature)
                        else:
                            queued_tasks.append(task_signature)
                        did_enqueue = True

                    # Phase 1b: fill free workers from the queues in priority
                    # order: try_first, normal, then try_last (when allowed).
                    while len(running_tasks) < session.config["n_workers"]:
                        if queued_try_first_tasks:
                            task_signature = queued_try_first_tasks.popleft()
                        elif queued_tasks:
                            task_signature = queued_tasks.popleft()
                        elif queued_try_last_tasks and _can_run_try_last():
                            task_signature = queued_try_last_tasks.popleft()
                        else:
                            break

                        task = session.dag.nodes[task_signature]["task"]
                        try:
                            session.hook.pytask_execute_task_setup(
                                session=session, task=task
                            )
                            running_tasks[task_signature] = (
                                session.hook.pytask_execute_task(
                                    session=session, task=task
                                )
                            )
                            if get_marks(task, "try_last"):
                                running_try_last.add(task_signature)
                            sleeper.reset()
                            did_submit = True
                        except Exception:  # noqa: BLE001
                            # A failing setup becomes a report instead of
                            # aborting the whole execution loop.
                            report = ExecutionReport.from_task_and_exception(
                                task, sys.exc_info()
                            )
                            newly_collected_reports.append(report)
                            session.scheduler.done(task_signature)
                else:
                    # Phase 1 (no prefetching): submit every ready task directly.
                    for task_signature in ready_tasks:
                        task = session.dag.nodes[task_signature]["task"]
                        if debug_status:
                            _log_status(
                                "PENDING"
                                if start_execution_state == TaskExecutionStatus.PENDING
                                else "RUNNING",
                                task_signature,
                            )
                        session.hook.pytask_execute_task_log_start(
                            session=session, task=task, status=start_execution_state
                        )
                        try:
                            session.hook.pytask_execute_task_setup(
                                session=session, task=task
                            )
                            running_tasks[task_signature] = (
                                session.hook.pytask_execute_task(
                                    session=session, task=task
                                )
                            )
                            sleeper.reset()
                            did_submit = True
                        except Exception:  # noqa: BLE001
                            report = ExecutionReport.from_task_and_exception(
                                task, sys.exc_info()
                            )
                            newly_collected_reports.append(report)
                            session.scheduler.done(task_signature)

                if not ready_tasks and not did_enqueue and not did_submit:
                    sleeper.increment()

                # Phase 2: collect finished futures and turn them into reports.
                for task_signature in list(running_tasks):
                    future = running_tasks[task_signature]
                    if future.done():
                        # Look up the task for this signature *before* attaching
                        # captured output. Previously a stale ``task`` from the
                        # submission loop received the stdout/stderr sections.
                        task = session.dag.nodes[task_signature]["task"]
                        wrapper_result = parse_future_result(future)
                        session.warnings.extend(wrapper_result.warning_reports)

                        if wrapper_result.stdout:
                            task.report_sections.append(
                                ("call", "stdout", wrapper_result.stdout)
                            )
                        if wrapper_result.stderr:
                            task.report_sections.append(
                                ("call", "stderr", wrapper_result.stderr)
                            )

                        if wrapper_result.exc_info is not None:
                            newly_collected_reports.append(
                                ExecutionReport.from_task_and_exception(
                                    task,
                                    wrapper_result.exc_info,  # type: ignore[arg-type]
                                )
                            )
                            running_tasks.pop(task_signature)
                            running_try_last.discard(task_signature)
                            session.scheduler.done(task_signature)
                        else:
                            _update_carry_over_products(
                                task, wrapper_result.carry_over_products
                            )
                            try:
                                session.hook.pytask_execute_task_teardown(
                                    session=session, task=task
                                )
                            except Exception:  # noqa: BLE001
                                report = ExecutionReport.from_task_and_exception(
                                    task, sys.exc_info()
                                )
                            else:
                                report = ExecutionReport.from_task(task)
                            running_tasks.pop(task_signature)
                            running_try_last.discard(task_signature)
                            newly_collected_reports.append(report)
                            session.scheduler.done(task_signature)

                # Check if tasks are not pending but running and update the live
                # status.
                if (
                    live_execution or debug_status
                ) and "_status_queue" in session.config:
                    status_queue = session.config["_status_queue"]
                    while True:
                        try:
                            started_task = status_queue.get(block=False)
                        except queue.Empty:
                            break
                        if started_task in running_tasks:
                            if live_execution:
                                live_execution.update_task(
                                    started_task, status=TaskExecutionStatus.RUNNING
                                )
                            if debug_status:
                                _log_status("RUNNING", started_task)

                # Phase 3: process the freshly collected reports.
                for report in newly_collected_reports:
                    session.hook.pytask_execute_task_process_report(
                        session=session, report=report
                    )
                    # Use the report's own task; the loop variable ``task`` may
                    # refer to a different task at this point.
                    session.hook.pytask_execute_task_log_end(
                        session=session, task=report.task, report=report
                    )
                    reports.append(report)

                if session.should_stop:
                    break
                sleeper.sleep()
            except KeyboardInterrupt:
                break

    return True
def _register_task_module_by_value(task: PTask) -> None:
    """Register the task's module so cloudpickle pickles it by value.

    Task modules are dynamically loaded and added to `sys.modules`. Thus,
    cloudpickle believes the module of the task function is also importable in the
    child process. We have to register the module as dynamic again, so that
    cloudpickle will pickle it with the function. See cloudpickle#417, pytask#373
    and pytask#374.
    """
    task_module = get_module(task.function, getattr(task, "path", None))
    cloudpickle.register_pickle_by_value(task_module)


@hookimpl
def pytask_execute_task(session: Session, task: PTask) -> Future[WrapperResult]:
    """Execute a task.

    The task function is wrapped according to the worker type and submitted to the
    executor.

    Raises
    ------
    ValueError
        If the backend's worker type is neither coiled, processes, nor threads.
    """
    parallel_backend = registry.registry[session.config["parallel_backend"]]
    worker_type = parallel_backend.worker_type
    remote = parallel_backend.remote

    kwargs = create_kwargs_for_task(task, remote=remote)

    if is_coiled_function(task):
        # Prevent circular import for coiled backend.
        from pytask_parallel.wrappers import (  # noqa: PLC0415
            rewrap_task_with_coiled_function,
        )

        wrapper_func = rewrap_task_with_coiled_function(task)
        _register_task_module_by_value(task)

        return cast("Any", wrapper_func).submit(
            task=task,
            console_options=console.options,
            kwargs=kwargs,
            remote=True,
            session_filterwarnings=session.config["filterwarnings"],
            show_locals=session.config["show_locals"],
            task_filterwarnings=get_marks(task, "filterwarnings"),
        )

    if worker_type == WorkerType.PROCESSES:
        # Prevent circular import for loky backend.
        from pytask_parallel.wrappers import wrap_task_in_process  # noqa: PLC0415

        _register_task_module_by_value(task)

        return session.config["_parallel_executor"].submit(
            wrap_task_in_process,
            task=task,
            console_options=console.options,
            kwargs=kwargs,
            remote=remote,
            session_filterwarnings=session.config["filterwarnings"],
            status_queue=session.config.get("_status_queue"),
            show_locals=session.config["show_locals"],
            task_filterwarnings=get_marks(task, "filterwarnings"),
        )

    if worker_type == WorkerType.THREADS:
        # Prevent circular import for loky backend.
        from pytask_parallel.wrappers import wrap_task_in_thread  # noqa: PLC0415

        return session.config["_parallel_executor"].submit(
            wrap_task_in_thread,
            task=task,
            remote=False,
            status_queue=session.config.get("_status_queue"),
            **kwargs,
        )

    msg = f"Unknown worker type {worker_type}"
    raise ValueError(msg)
@hookimpl
def pytask_unconfigure() -> None:
    """Clean up the parallel executor."""
    # Reset the backend registry so that state from this run does not leak into
    # a subsequent pytask invocation in the same process.
    # NOTE(review): exact reset semantics live in ``registry.reset`` — confirm
    # there if custom backends need re-registration afterwards.
    registry.reset()
def _is_debug_status_enabled() -> bool:
"""Return whether to emit debug status updates."""
value = os.environ.get("PYTASK_PARALLEL_DEBUG_STATUS", "")
return value.strip().lower() in {"1", "true", "yes", "on"}
def _log_status(status: str, task_signature: str) -> None:
    """Print a task's status transition to the console for debugging."""
    message = f"[pytask-parallel] {status}: {task_signature}"
    console.print(message)
def _update_carry_over_products(
    task: PTask, carry_over_products: PyTree[CarryOverPath | PythonNode | None] | None
) -> None:
    """Write products carried over from another process or a remote worker.

    A carried-over :class:`PythonNode` simply forwards its value to the
    corresponding product node, whereas a :class:`CarryOverPath` holds file
    content (originating from a remote path node) that is written to disk.
    """

    def _merge(node: PNode, carried: CarryOverPath | PythonNode | None) -> PNode:
        # Nothing was carried over for this leaf; keep the node untouched.
        if carried is None:
            return node
        # File content from a remote worker is materialized on the local path.
        if isinstance(node, PPathNode) and isinstance(carried, CarryOverPath):
            node.path.write_bytes(carried.content)
            return node
        # A plain Python value is handed over to the product node.
        if isinstance(carried, PythonNode):
            node.save(carried.load())
            return node
        raise NotImplementedError

    structure_produces = tree_structure(task.produces)
    structure_carry_over_products = tree_structure(carry_over_products)
    # strict must be false when none is leaf.
    if structure_produces.is_prefix(structure_carry_over_products, strict=False):
        task.produces = tree_map(_merge, task.produces, carry_over_products)
@define(kw_only=True)
class _Sleeper:
    """A sleeper that always sleeps a bit and up to 1 second if you don't wake it up.

    This class controls when the next iteration of the execution loop starts. If new
    tasks are scheduled, the time spent sleeping is reset to a lower value.
    """

    # Quadratically growing sleep durations from 0.01s up to 1.0s. A ``factory``
    # gives every instance its own list instead of one shared mutable default.
    timings: list[float] = field(factory=lambda: [(i / 10) ** 2 for i in range(1, 11)])
    # Index into ``timings`` selecting the duration of the next ``sleep`` call.
    timing_idx: int = 0

    def reset(self) -> None:
        """Fall back to the shortest sleep, e.g. after new work arrived."""
        self.timing_idx = 0

    def increment(self) -> None:
        """Back off to the next longer sleep, capped at the final timing."""
        if self.timing_idx < len(self.timings) - 1:
            self.timing_idx += 1

    def sleep(self) -> None:
        """Sleep for the currently selected duration."""
        time.sleep(self.timings[self.timing_idx])