Skip to content

Commit 23d7d6c

Browse files
committed
Fix pickling errors by importing task modules in workers
1 parent f168e73 commit 23d7d6c

6 files changed

Lines changed: 138 additions & 142 deletions

File tree

requirements-dev.lock

Lines changed: 0 additions & 86 deletions
This file was deleted.

requirements.lock

Lines changed: 0 additions & 51 deletions
This file was deleted.

src/pytask_parallel/backends.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,55 @@
1212
from typing import Any
1313
from typing import ClassVar
1414

15+
import os
16+
import sys
1517
import cloudpickle
1618
from attrs import define
1719
from loky import get_reusable_executor
1820

1921
if TYPE_CHECKING:
2022
from collections.abc import Callable
2123

22-
__all__ = ["ParallelBackend", "ParallelBackendRegistry", "WorkerType", "registry"]
24+
__all__ = [
25+
"ParallelBackend",
26+
"ParallelBackendRegistry",
27+
"WorkerType",
28+
"registry",
29+
"set_worker_root",
30+
]
31+
32+
_WORKER_ROOT: str | None = None
33+
34+
35+
def set_worker_root(path: os.PathLike[str] | str) -> None:
36+
"""Configure the root path for worker processes.
37+
38+
Spawned workers (notably on Windows) start with a clean interpreter and may not
39+
inherit the parent's import path. We set both ``sys.path`` and ``PYTHONPATH`` so
40+
task modules are importable by reference, which avoids pickling module globals.
41+
42+
"""
43+
global _WORKER_ROOT
44+
root = os.fspath(path)
45+
_WORKER_ROOT = root
46+
if root not in sys.path:
47+
sys.path.insert(0, root)
48+
# Ensure custom process backends can import task modules by reference.
49+
separator = os.pathsep
50+
current = os.environ.get("PYTHONPATH", "")
51+
parts = [p for p in current.split(separator) if p] if current else []
52+
if root not in parts:
53+
parts.insert(0, root)
54+
os.environ["PYTHONPATH"] = separator.join(parts)
55+
56+
57+
def _configure_worker(root: str | None) -> None:
58+
"""Set cwd and sys.path for worker processes."""
59+
if not root:
60+
return
61+
os.chdir(root)
62+
if root not in sys.path:
63+
sys.path.insert(0, root)
2364

2465

2566
def _deserialize_and_run_with_cloudpickle(fn: bytes, kwargs: bytes) -> Any:
@@ -75,12 +116,16 @@ def _get_dask_executor(n_workers: int) -> Executor:
75116

76117
def _get_loky_executor(n_workers: int) -> Executor:
    """Create a reusable loky executor whose workers know the project root."""
    # Each worker runs _configure_worker at startup so that task modules are
    # importable by reference inside the worker process.
    worker_setup = {"initializer": _configure_worker, "initargs": (_WORKER_ROOT,)}
    return get_reusable_executor(max_workers=n_workers, **worker_setup)
79122

80123

81124
def _get_process_pool_executor(n_workers: int) -> Executor:
    """Create a cloudpickle-aware process pool executor."""
    # Mirror the loky backend: initialize every worker with the project root
    # so task modules can be imported by reference.
    worker_setup = {"initializer": _configure_worker, "initargs": (_WORKER_ROOT,)}
    return _CloudpickleProcessPoolExecutor(max_workers=n_workers, **worker_setup)
84129

85130

86131
def _get_thread_pool_executor(n_workers: int) -> Executor:

src/pytask_parallel/execute.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@
2626

2727
from pytask_parallel.backends import WorkerType
2828
from pytask_parallel.backends import registry
29+
from pytask_parallel.backends import set_worker_root
2930
from pytask_parallel.typing import CarryOverPath
3031
from pytask_parallel.typing import is_coiled_function
3132
from pytask_parallel.utils import create_kwargs_for_task
3233
from pytask_parallel.utils import get_module
3334
from pytask_parallel.utils import parse_future_result
35+
from pytask_parallel.utils import should_pickle_module_by_value
3436

3537
if TYPE_CHECKING:
3638
from concurrent.futures import Future
@@ -57,6 +59,7 @@ def pytask_execute_build(session: Session) -> bool | None: # noqa: C901, PLR091
5759

5860
# The executor can only be created after the collection to give users the
5961
# possibility to inject their own executors.
62+
set_worker_root(session.config["root"])
6063
session.config["_parallel_executor"] = registry.get_parallel_backend(
6164
session.config["parallel_backend"], n_workers=session.config["n_workers"]
6265
)
@@ -208,7 +211,8 @@ def pytask_execute_task(session: Session, task: PTask) -> Future[WrapperResult]:
208211
# cloudpickle will pickle it with the function. See cloudpickle#417, pytask#373
209212
# and pytask#374.
210213
task_module = get_module(task.function, getattr(task, "path", None))
211-
cloudpickle.register_pickle_by_value(task_module)
214+
if should_pickle_module_by_value(task_module):
215+
cloudpickle.register_pickle_by_value(task_module)
212216

213217
return cast("Any", wrapper_func).submit(
214218
task=task,
@@ -230,7 +234,8 @@ def pytask_execute_task(session: Session, task: PTask) -> Future[WrapperResult]:
230234
# cloudpickle will pickle it with the function. See cloudpickle#417, pytask#373
231235
# and pytask#374.
232236
task_module = get_module(task.function, getattr(task, "path", None))
233-
cloudpickle.register_pickle_by_value(task_module)
237+
if should_pickle_module_by_value(task_module):
238+
cloudpickle.register_pickle_by_value(task_module)
234239

235240
return session.config["_parallel_executor"].submit(
236241
wrap_task_in_process,

src/pytask_parallel/utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
from __future__ import annotations
44

5+
import importlib.util
56
import inspect
7+
from pathlib import Path
68
from functools import partial
79
from typing import TYPE_CHECKING
810
from typing import Any
@@ -39,6 +41,7 @@ class CoiledFunction: ...
3941
"create_kwargs_for_task",
4042
"get_module",
4143
"parse_future_result",
44+
"should_pickle_module_by_value",
4245
]
4346

4447

@@ -150,3 +153,30 @@ def get_module(func: Callable[..., Any], path: Path | None) -> ModuleType:
150153
if path:
151154
return inspect.getmodule(func, path.as_posix()) # type: ignore[return-value]
152155
return inspect.getmodule(func) # type: ignore[return-value]
156+
157+
158+
def should_pickle_module_by_value(module: ModuleType) -> bool:
    """Decide whether *module* must be serialized by value.

    Pickling by value is only required when the worker cannot import the
    module by name. Importable modules are pickled by reference instead,
    which sidesteps serializing module globals such as locks or closed file
    handles that cannot be pickled.

    """
    name = getattr(module, "__name__", None)
    file = getattr(module, "__file__", None)

    # Nameless modules, scripts executed as __main__, and modules without a
    # backing file can never be re-imported by the worker.
    if not name or name == "__main__" or file is None:
        return True

    try:
        spec = importlib.util.find_spec(name)
    except (ImportError, ValueError, AttributeError):
        return True
    if spec is None or spec.origin is None:
        return True

    try:
        same_file = Path(spec.origin).resolve() == Path(file).resolve()
    except OSError:
        return True
    # A name that resolves to a different file would import the wrong module
    # in the worker, so fall back to pickling by value.
    return not same_file

tests/test_execute.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,3 +364,56 @@ def task_create_file(
364364
)
365365
assert result.exit_code == ExitCode.OK
366366
assert tmp_path.joinpath("file.txt").read_text() == "This is the text."
367+
368+
369+
@pytest.mark.parametrize(
    "parallel_backend",
    [
        ParallelBackend.PROCESSES,
        pytest.param(ParallelBackend.LOKY, marks=skip_if_deadlock),
    ],
)
def test_parallel_execution_with_mark_import(runner, tmp_path, parallel_backend):
    """A task module importing pytask names runs fine in parallel workers."""
    task_code = textwrap.dedent(
        """
        from pytask import mark, task

        @task
        def task_assert_math():
            assert 2 + 2 == 4
        """
    )
    tmp_path.joinpath("task_mark.py").write_text(task_code)

    args = [tmp_path.as_posix(), "-n", "2", "--parallel-backend", parallel_backend]
    result = runner.invoke(cli, args)

    assert result.exit_code == ExitCode.OK
389+
390+
391+
@pytest.mark.parametrize(
    "parallel_backend",
    [
        ParallelBackend.PROCESSES,
        pytest.param(ParallelBackend.LOKY, marks=skip_if_deadlock),
    ],
)
def test_parallel_execution_with_closed_file_handle(
    runner, tmp_path, parallel_backend
):
    """A module holding a closed file handle at module scope must not break
    parallel execution (the module is imported by reference, not pickled)."""
    task_code = textwrap.dedent(
        """
        from pathlib import Path
        from pytask import task

        data_path = Path(__file__).parent / "data.txt"
        data_path.write_text("hello", encoding="utf-8")

        with data_path.open(encoding="utf-8") as f:
            content = f.read()

        @task
        def task_assert_math():
            assert content == "hello"
        """
    )
    tmp_path.joinpath("task_file.py").write_text(task_code)

    args = [tmp_path.as_posix(), "-n", "2", "--parallel-backend", parallel_backend]
    result = runner.invoke(cli, args)

    assert result.exit_code == ExitCode.OK

0 commit comments

Comments (0)