Skip to content

Commit b993e76

Browse files
committed
Switch pending-status tracking to a start queue
1 parent 79c1137 commit b993e76

2 files changed

Lines changed: 30 additions & 25 deletions

File tree

src/pytask_parallel/execute.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import queue
56
import sys
67
import time
78
from contextlib import ExitStack
@@ -61,8 +62,8 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
6162
running_tasks: dict[str, Future[Any]] = {}
6263
sleeper = _Sleeper()
6364

64-
# Create a shared memory object to differentiate between running and pending
65-
# tasks for some parallel backends.
65+
# Create a shared queue to differentiate between running and pending tasks for
66+
# some parallel backends.
6667
if session.config["parallel_backend"] in (
6768
ParallelBackend.PROCESSES,
6869
ParallelBackend.THREADS,
@@ -89,7 +90,7 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
8990
ParallelBackend.THREADS,
9091
ParallelBackend.LOKY,
9192
):
92-
session.config["_shared_memory"] = manager.dict() # type: ignore[union-attr]
93+
session.config["_status_queue"] = manager.Queue() # type: ignore[union-attr]
9394

9495
i = 0
9596
while session.scheduler.is_active():
@@ -187,12 +188,18 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
187188
newly_collected_reports.append(report)
188189
session.scheduler.done(task_signature)
189190

190-
# Check if tasks are not pending but running and update the live
191-
# status.
192-
elif live_execution and "_shared_memory" in session.config:
193-
if task_signature in session.config["_shared_memory"]:
191+
# Check if tasks are not pending but running and update the live
192+
# status.
193+
if live_execution and "_status_queue" in session.config:
194+
status_queue = session.config["_status_queue"]
195+
while True:
196+
try:
197+
started_task = status_queue.get(block=False)
198+
except queue.Empty:
199+
break
200+
if started_task in running_tasks:
194201
live_execution.update_task(
195-
task_signature, status=TaskExecutionStatus.RUNNING
202+
started_task, status=TaskExecutionStatus.RUNNING
196203
)
197204

198205
for report in newly_collected_reports:
@@ -275,7 +282,7 @@ def pytask_execute_task(session: Session, task: PTask) -> Future[WrapperResult]:
275282
kwargs=kwargs,
276283
remote=remote,
277284
session_filterwarnings=session.config["filterwarnings"],
278-
shared_memory=session.config.get("_shared_memory"),
285+
status_queue=session.config.get("_status_queue"),
279286
show_locals=session.config["show_locals"],
280287
task_filterwarnings=get_marks(task, "filterwarnings"),
281288
)
@@ -288,7 +295,7 @@ def pytask_execute_task(session: Session, task: PTask) -> Future[WrapperResult]:
288295
wrap_task_in_thread,
289296
task=task,
290297
remote=False,
291-
shared_memory=session.config.get("_shared_memory"),
298+
status_queue=session.config.get("_status_queue"),
292299
**kwargs,
293300
)
294301

src/pytask_parallel/wrappers.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
if TYPE_CHECKING:
3939
from collections.abc import Callable
40+
from queue import Queue
4041
from types import TracebackType
4142

4243
from pytask import Mark
@@ -58,7 +59,11 @@ class WrapperResult:
5859

5960

6061
def wrap_task_in_thread(
61-
task: PTask, *, remote: bool, shared_memory: dict[str, bool] | None, **kwargs: Any
62+
task: PTask,
63+
*,
64+
remote: bool,
65+
status_queue: "Queue[str] | None" = None,
66+
**kwargs: Any,
6267
) -> WrapperResult:
6368
"""Mock execution function such that it returns the same as for processes.
6469
@@ -69,9 +74,9 @@ def wrap_task_in_thread(
6974
"""
7075
__tracebackhide__ = True
7176

72-
# Add task to shared memory to indicate that it is currently being executed.
73-
if shared_memory is not None:
74-
shared_memory[task.signature] = True
77+
# Add task to the status queue to indicate that it is currently being executed.
78+
if status_queue is not None:
79+
status_queue.put(task.signature)
7580

7681
try:
7782
out = task.function(**kwargs)
@@ -89,9 +94,6 @@ def wrap_task_in_thread(
8994
_handle_function_products(task, out, remote=remote)
9095
exc_info = None
9196

92-
# Remove task from shared memory to indicate that it is no longer being executed.
93-
if shared_memory is not None:
94-
shared_memory.pop(task.signature, None)
9597
return WrapperResult(
9698
carry_over_products=None,
9799
warning_reports=[],
@@ -108,7 +110,7 @@ def wrap_task_in_process( # noqa: PLR0913
108110
kwargs: dict[str, Any],
109111
remote: bool,
110112
session_filterwarnings: tuple[str, ...],
111-
shared_memory: dict[str, bool] | None,
113+
status_queue: "Queue[str] | None" = None,
112114
show_locals: bool,
113115
task_filterwarnings: tuple[Mark, ...],
114116
) -> WrapperResult:
@@ -121,9 +123,9 @@ def wrap_task_in_process( # noqa: PLR0913
121123
# Hide this function from tracebacks.
122124
__tracebackhide__ = True
123125

124-
# Add task to shared memory to indicate that it is currently being executed.
125-
if shared_memory is not None:
126-
shared_memory[task.signature] = True
126+
# Add task to the status queue to indicate that it is currently being executed.
127+
if status_queue is not None:
128+
status_queue.put(task.signature)
127129

128130
# Patch set_trace and breakpoint to show a better error message.
129131
_patch_set_trace_and_breakpoint()
@@ -184,10 +186,6 @@ def wrap_task_in_process( # noqa: PLR0913
184186
captured_stdout_buffer.close()
185187
captured_stderr_buffer.close()
186188

187-
# Remove task from shared memory to indicate that it is no longer being executed.
188-
if shared_memory is not None:
189-
shared_memory.pop(task.signature, None)
190-
191189
return WrapperResult(
192190
carry_over_products=products,
193191
warning_reports=warning_reports,

0 commit comments

Comments
 (0)