Skip to content

Commit e2943f6

Browse files
committed
Fix UPath typing for universal-pathlib 0.3
1 parent dcbce9d commit e2943f6

6 files changed

Lines changed: 49 additions & 32 deletions

File tree

src/_pytask/collect.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,7 @@ def pytask_collect_node( # noqa: C901, PLR0912
501501

502502
if isinstance(node, UPath): # pragma: no cover
503503
if not node.protocol:
504-
node = Path(node)
504+
node = Path(str(node))
505505
else:
506506
return PathNode(name=node.name, path=node)
507507

src/_pytask/lockfile.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
if TYPE_CHECKING:
2929
from _pytask.session import Session
30+
from _pytask.typing import NodePath
3031

3132
CURRENT_LOCKFILE_VERSION = "1"
3233

@@ -71,13 +72,14 @@ def _encode_node_path(path: tuple[str | int, ...]) -> str:
7172
return msgspec.json.encode(path).decode()
7273

7374

74-
def _relative_path(path: Path, root: Path) -> str:
75+
def _relative_path(path: NodePath, root: Path) -> str:
7576
if isinstance(path, UPath) and path.protocol:
7677
return str(path)
78+
local_path = Path(str(path)) if isinstance(path, UPath) else path
7779
try:
78-
rel = os.path.relpath(path, root)
80+
rel = os.path.relpath(local_path, root)
7981
except ValueError:
80-
return path.as_posix()
82+
return local_path.as_posix()
8183
return Path(rel).as_posix()
8284

8385

src/_pytask/node_protocols.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from _pytask.mark import Mark
1313
from _pytask.tree_util import PyTree
14+
from _pytask.typing import NodePath
1415

1516

1617
__all__ = ["PNode", "PPathNode", "PProvisionalNode", "PTask", "PTaskWithPath"]
@@ -60,7 +61,7 @@ class PPathNode(PNode, Protocol):
6061
6162
"""
6263

63-
path: Path
64+
path: NodePath
6465

6566

6667
@runtime_checkable

src/_pytask/nodes.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424
from _pytask.node_protocols import PTaskWithPath
2525
from _pytask.path import hash_path
2626
from _pytask.typing import NoDefault
27+
from _pytask.typing import NodePath
2728
from _pytask.typing import no_default
2829

2930
if TYPE_CHECKING:
3031
from collections.abc import Callable
31-
from io import BufferedReader
32-
from io import BufferedWriter
3332
from pathlib import Path
33+
from typing import BinaryIO
3434

3535
from _pytask.mark import Mark
3636
from _pytask.models import NodeInfo
@@ -176,7 +176,7 @@ class PathNode(PPathNode):
176176
177177
"""
178178

179-
path: Path
179+
path: NodePath
180180
name: str = ""
181181
attributes: dict[Any, Any] = field(default_factory=dict)
182182

@@ -187,7 +187,7 @@ def signature(self) -> str:
187187
return hashlib.sha256(raw_key.encode()).hexdigest()
188188

189189
@classmethod
190-
def from_path(cls, path: Path) -> PathNode:
190+
def from_path(cls, path: NodePath) -> PathNode:
191191
"""Instantiate class from path to file."""
192192
return cls(name=path.as_posix(), path=path)
193193

@@ -199,7 +199,7 @@ def state(self) -> str | None:
199199
"""
200200
return get_state_of_path(self.path)
201201

202-
def load(self, is_product: bool = False) -> Path: # noqa: ARG002
202+
def load(self, is_product: bool = False) -> NodePath: # noqa: ARG002
203203
"""Load the value."""
204204
return self.path
205205

@@ -330,11 +330,11 @@ class PickleNode(PPathNode):
330330
331331
"""
332332

333-
path: Path
333+
path: NodePath
334334
name: str = ""
335335
attributes: dict[Any, Any] = field(default_factory=dict)
336-
serializer: Callable[[Any, BufferedWriter], None] = field(default=pickle.dump)
337-
deserializer: Callable[[BufferedReader], Any] = field(default=pickle.load)
336+
serializer: Callable[[Any, BinaryIO], None] = field(default=pickle.dump)
337+
deserializer: Callable[[BinaryIO], Any] = field(default=pickle.load)
338338

339339
@property
340340
def signature(self) -> str:
@@ -343,7 +343,7 @@ def signature(self) -> str:
343343
return hashlib.sha256(raw_key.encode()).hexdigest()
344344

345345
@classmethod
346-
def from_path(cls, path: Path) -> PickleNode:
346+
def from_path(cls, path: NodePath) -> PickleNode:
347347
"""Instantiate class from path to file."""
348348
if not path.is_absolute():
349349
msg = "Node must be instantiated from absolute path."
@@ -409,7 +409,7 @@ def collect(self) -> list[Path]:
409409
return list(self.root_dir.glob(self.pattern)) # type: ignore[union-attr]
410410

411411

412-
def get_state_of_path(path: Path) -> str | None:
412+
def get_state_of_path(path: NodePath) -> str | None:
413413
"""Get state of a path.
414414
415415
A simple function to handle local and remote files.
@@ -436,7 +436,7 @@ def get_state_of_path(path: Path) -> str | None:
436436

437437

438438
@deprecated("Use 'pytask.get_state_of_path' instead.")
439-
def _get_state(path: Path) -> str | None:
439+
def _get_state(path: NodePath) -> str | None:
440440
"""Get state of a path.
441441
442442
A simple function to handle local and remote files.

src/_pytask/path.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
if TYPE_CHECKING:
2222
from collections.abc import Sequence
2323

24+
from _pytask.typing import NodePath
25+
2426
__all__ = [
2527
"find_case_sensitive_path",
2628
"find_closest_ancestor",
@@ -64,24 +66,27 @@ def relative_to(path: Path, source: Path, *, include_source: bool = True) -> Pat
6466
return Path(source_name, path.relative_to(source))
6567

6668

67-
def is_non_local_path(path: Path) -> bool:
69+
def is_non_local_path(path: NodePath) -> bool:
6870
"""Return whether a path points to a non-local `UPath` resource."""
6971
return isinstance(path, UPath) and path.protocol not in _LOCAL_UPATH_PROTOCOLS
7072

7173

72-
def normalize_local_upath(path: Path) -> Path:
74+
def normalize_local_upath(path: NodePath) -> NodePath:
7375
"""Convert local `UPath` variants to a stdlib `Path`."""
74-
if isinstance(path, UPath) and path.protocol in {"file", "local"}:
75-
local_path = path.path
76-
if (
77-
sys.platform == "win32"
78-
and local_path.startswith("/")
79-
and len(local_path) >= _WINDOWS_DRIVE_PREFIX_LENGTH
80-
and local_path[1].isalpha()
81-
and local_path[2] == ":"
82-
):
83-
local_path = local_path[1:]
84-
return Path(local_path)
76+
if isinstance(path, UPath):
77+
if path.protocol in {"file", "local"}:
78+
local_path = path.path
79+
if (
80+
sys.platform == "win32"
81+
and local_path.startswith("/")
82+
and len(local_path) >= _WINDOWS_DRIVE_PREFIX_LENGTH
83+
and local_path[1].isalpha()
84+
and local_path[2] == ":"
85+
):
86+
local_path = local_path[1:]
87+
return Path(local_path)
88+
if not path.protocol:
89+
return Path(str(path))
8590
return path
8691

8792

@@ -451,7 +456,7 @@ def _insert_missing_modules(modules: dict[str, ModuleType], module_name: str) ->
451456
module_name = ".".join(module_parts)
452457

453458

454-
def shorten_path(path: Path, paths: Sequence[Path]) -> str:
459+
def shorten_path(path: NodePath, paths: Sequence[NodePath]) -> str:
455460
"""Shorten a path.
456461
457462
The whole path of a node - which includes the drive letter - can be very long
@@ -466,6 +471,8 @@ def shorten_path(path: Path, paths: Sequence[Path]) -> str:
466471

467472
path = normalize_local_upath(path)
468473
paths = [normalize_local_upath(p) for p in paths]
474+
path = Path(str(path)) if isinstance(path, UPath) else path
475+
paths = [Path(str(p)) if isinstance(p, UPath) else p for p in paths]
469476

470477
ancestor = find_closest_ancestor(path, paths)
471478
if ancestor is None:

src/_pytask/typing.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,25 @@
33
import functools
44
from dataclasses import dataclass
55
from enum import Enum
6+
from pathlib import Path
67
from typing import TYPE_CHECKING
78
from typing import Any
89
from typing import Final
910
from typing import Literal
1011
from typing import Protocol
12+
from typing import TypeAlias
1113
from typing import runtime_checkable
1214

13-
if TYPE_CHECKING:
14-
from typing import TypeAlias
15+
from upath import UPath
1516

17+
if TYPE_CHECKING:
1618
from _pytask.models import CollectionMetadata
1719
from pytask import PTask
1820

1921

2022
__all__ = [
2123
"NoDefault",
24+
"NodePath",
2225
"Product",
2326
"ProductType",
2427
"TaskFunction",
@@ -27,6 +30,10 @@
2730
]
2831

2932

33+
NodePath: TypeAlias = Path | UPath
34+
"""A local stdlib path or a universal-pathlib path."""
35+
36+
3037
@runtime_checkable
3138
class TaskFunction(Protocol):
3239
"""Protocol for callables decorated with @task that have pytask_meta attached.

0 commit comments

Comments
 (0)