Skip to content

Commit 03ca18d

Browse files
committed
Handle local UPath protocols consistently
1 parent c16d307 commit 03ca18d

File tree

6 files changed

+87
-3
lines changed

6 files changed

+87
-3
lines changed

src/_pytask/collect.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from _pytask.path import find_case_sensitive_path
4444
from _pytask.path import import_path
4545
from _pytask.path import is_non_local_path
46+
from _pytask.path import normalize_local_upath
4647
from _pytask.path import shorten_path
4748
from _pytask.pluginmanager import hookimpl
4849
from _pytask.reports import CollectionReport
@@ -456,6 +457,9 @@ def pytask_collect_node( # noqa: C901, PLR0912
456457
node.name = create_name_of_python_node(node_info)
457458
return node
458459

460+
if isinstance(node, PPathNode):
461+
node.path = normalize_local_upath(node.path)
462+
459463
if (
460464
isinstance(node, PPathNode)
461465
and not is_non_local_path(node.path)
@@ -492,6 +496,9 @@ def pytask_collect_node( # noqa: C901, PLR0912
492496
node.name = create_name_of_python_node(node_info)
493497
return node
494498

499+
if isinstance(node, UPath): # pragma: no cover
500+
node = normalize_local_upath(node)
501+
495502
if isinstance(node, UPath): # pragma: no cover
496503
if not node.protocol:
497504
node = Path(node)

src/_pytask/collect_command.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from _pytask.outcomes import ExitCode
3232
from _pytask.path import find_common_ancestor
3333
from _pytask.path import is_non_local_path
34+
from _pytask.path import normalize_local_upath
3435
from _pytask.path import relative_to
3536
from _pytask.pluginmanager import hookimpl
3637
from _pytask.pluginmanager import storage
@@ -126,12 +127,12 @@ def _find_common_ancestor_of_all_nodes(
126127
all_paths.append(task.path)
127128
if show_nodes:
128129
all_paths.extend(
129-
x.path
130+
normalize_local_upath(x.path)
130131
for x in tree_leaves(task.depends_on)
131132
if isinstance(x, PPathNode) and not is_non_local_path(x.path)
132133
)
133134
all_paths.extend(
134-
x.path
135+
normalize_local_upath(x.path)
135136
for x in tree_leaves(task.produces)
136137
if isinstance(x, PPathNode) and not is_non_local_path(x.path)
137138
)

src/_pytask/path.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,15 @@
2828
"hash_path",
2929
"import_path",
3030
"is_non_local_path",
31+
"normalize_local_upath",
3132
"relative_to",
3233
"shorten_path",
3334
]
3435

3536

37+
_LOCAL_UPATH_PROTOCOLS = frozenset(("", "file", "local"))
38+
39+
3640
def relative_to(path: Path, source: Path, *, include_source: bool = True) -> Path:
3741
"""Make a path relative to another path.
3842
@@ -61,7 +65,14 @@ def relative_to(path: Path, source: Path, *, include_source: bool = True) -> Pat
6165

6266
def is_non_local_path(path: Path) -> bool:
6367
"""Return whether a path points to a non-local `UPath` resource."""
64-
return isinstance(path, UPath) and bool(path.protocol)
68+
return isinstance(path, UPath) and path.protocol not in _LOCAL_UPATH_PROTOCOLS
69+
70+
71+
def normalize_local_upath(path: Path) -> Path:
72+
"""Convert local `UPath` variants to a stdlib `Path`."""
73+
if isinstance(path, UPath) and path.protocol in {"file", "local"}:
74+
return Path(path.path)
75+
return path
6576

6677

6778
def find_closest_ancestor(
@@ -443,6 +454,9 @@ def shorten_path(path: Path, paths: Sequence[Path]) -> str:
443454
if is_non_local_path(path):
444455
return path.as_posix()
445456

457+
path = normalize_local_upath(path)
458+
paths = [normalize_local_upath(p) for p in paths]
459+
446460
ancestor = find_closest_ancestor(path, paths)
447461
if ancestor is None:
448462
try:

tests/test_collect.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,31 @@ def test_pytask_collect_remote_path_node_keeps_uri_name():
214214
assert result.name == "s3://bucket/file.pkl"
215215

216216

217+
@pytest.mark.parametrize("protocol", ["file", "local"])
218+
def test_pytask_collect_local_upath_protocol_node_is_shortened(tmp_path, protocol):
219+
upath = pytest.importorskip("upath")
220+
221+
session = Session.from_config(
222+
{"check_casing_of_paths": False, "paths": (tmp_path,), "root": tmp_path}
223+
)
224+
225+
result = pytask_collect_node(
226+
session,
227+
tmp_path,
228+
NodeInfo(
229+
arg_name="path",
230+
path=(),
231+
value=PickleNode(path=upath.UPath(f"{protocol}://{tmp_path}/file.pkl")),
232+
task_path=tmp_path / "task_example.py",
233+
task_name="task_example",
234+
),
235+
)
236+
237+
assert isinstance(result, PPathNode)
238+
assert result.path == tmp_path / "file.pkl"
239+
assert result.name == f"{tmp_path.name}/file.pkl"
240+
241+
217242
@pytest.mark.skipif(
218243
sys.platform != "win32", reason="Only works on case-insensitive file systems."
219244
)

tests/test_collect_command.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,32 @@ def task_example(
421421
assert "s3://bucket/in.pkl" in result.output
422422

423423

424+
@pytest.mark.parametrize("protocol", ["file", "local"])
425+
def test_collect_task_with_local_upath_protocol_node(runner, tmp_path, protocol):
426+
pytest.importorskip("upath")
427+
428+
source = f"""
429+
from pathlib import Path
430+
from typing import Annotated
431+
432+
from upath import UPath
433+
434+
from pytask import PickleNode
435+
from pytask import Product
436+
437+
def task_example(
438+
data=PickleNode(path=UPath("{protocol}://{tmp_path.as_posix()}/in.pkl")),
439+
path: Annotated[Path, Product] = Path("out.txt"),
440+
): ...
441+
"""
442+
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
443+
444+
result = runner.invoke(cli, ["collect", "--nodes", tmp_path.as_posix()])
445+
446+
assert result.exit_code == ExitCode.OK
447+
assert f"{tmp_path.name}/in.pkl" in result.output
448+
449+
424450
def test_python_node_is_collected(runner, tmp_path):
425451
source = """
426452
from pytask import Product

tests/test_path.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from _pytask.path import find_case_sensitive_path
2020
from _pytask.path import find_closest_ancestor
2121
from _pytask.path import find_common_ancestor
22+
from _pytask.path import is_non_local_path
2223
from _pytask.path import relative_to
2324
from _pytask.path import shorten_path
2425
from pytask.path import import_path
@@ -119,6 +120,16 @@ def test_shorten_path_keeps_non_local_uri():
119120
assert shorten_path(path, [Path.cwd()]) == "s3://bucket/file.pkl"
120121

121122

123+
@pytest.mark.parametrize("protocol", ["file", "local"])
124+
def test_shorten_path_treats_local_upath_protocols_as_local(tmp_path, protocol):
125+
upath = pytest.importorskip("upath")
126+
127+
path = upath.UPath(f"{protocol}://{tmp_path.as_posix()}/file.pkl")
128+
129+
assert not is_non_local_path(path)
130+
assert shorten_path(path, [tmp_path]) == f"{tmp_path.name}/file.pkl"
131+
132+
122133
@pytest.mark.skipif(sys.platform != "win32", reason="Only works on Windows.")
123134
@pytest.mark.parametrize(
124135
("path", "existing_paths", "expected"),

0 commit comments

Comments
 (0)