Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/_pytask/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,10 @@ def pytask_collect_node( # noqa: C901, PLR0912
and not is_non_local_path(node.path)
and not node.path.is_absolute()
):
node.path = path.joinpath(node.path)
local_node_path = (
Path(str(node.path)) if isinstance(node.path, UPath) else node.path
)
node.path = path.joinpath(local_node_path)

# ``normpath`` removes ``../`` from the path which is necessary for the casing
# check which will fail since ``.resolves()`` also normalizes a path.
Expand Down Expand Up @@ -501,7 +504,7 @@ def pytask_collect_node( # noqa: C901, PLR0912

if isinstance(node, UPath): # pragma: no cover
if not node.protocol:
node = Path(node)
node = Path(str(node))
else:
return PathNode(name=node.name, path=node)

Expand Down
8 changes: 5 additions & 3 deletions src/_pytask/lockfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

if TYPE_CHECKING:
from _pytask.session import Session
from _pytask.typing import NodePath

CURRENT_LOCKFILE_VERSION = "1"

Expand Down Expand Up @@ -71,13 +72,14 @@ def _encode_node_path(path: tuple[str | int, ...]) -> str:
return msgspec.json.encode(path).decode()


def _relative_path(path: Path, root: Path) -> str:
def _relative_path(path: NodePath, root: Path) -> str:
if isinstance(path, UPath) and path.protocol:
return str(path)
local_path = Path(str(path)) if isinstance(path, UPath) else path
try:
rel = os.path.relpath(path, root)
rel = os.path.relpath(local_path, root)
except ValueError:
return path.as_posix()
return local_path.as_posix()
return Path(rel).as_posix()


Expand Down
3 changes: 2 additions & 1 deletion src/_pytask/node_protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from _pytask.mark import Mark
from _pytask.tree_util import PyTree
from _pytask.typing import NodePath


__all__ = ["PNode", "PPathNode", "PProvisionalNode", "PTask", "PTaskWithPath"]
Expand Down Expand Up @@ -60,7 +61,7 @@ class PPathNode(PNode, Protocol):

"""

path: Path
path: NodePath


@runtime_checkable
Expand Down
22 changes: 11 additions & 11 deletions src/_pytask/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
from _pytask.node_protocols import PTaskWithPath
from _pytask.path import hash_path
from _pytask.typing import NoDefault
from _pytask.typing import NodePath
from _pytask.typing import no_default

if TYPE_CHECKING:
from collections.abc import Callable
from io import BufferedReader
from io import BufferedWriter
from pathlib import Path
from typing import BinaryIO

from _pytask.mark import Mark
from _pytask.models import NodeInfo
Expand Down Expand Up @@ -176,7 +176,7 @@ class PathNode(PPathNode):

"""

path: Path
path: NodePath
name: str = ""
attributes: dict[Any, Any] = field(default_factory=dict)

Expand All @@ -187,7 +187,7 @@ def signature(self) -> str:
return hashlib.sha256(raw_key.encode()).hexdigest()

@classmethod
def from_path(cls, path: Path) -> PathNode:
def from_path(cls, path: NodePath) -> PathNode:
"""Instantiate class from path to file."""
return cls(name=path.as_posix(), path=path)

Expand All @@ -199,7 +199,7 @@ def state(self) -> str | None:
"""
return get_state_of_path(self.path)

def load(self, is_product: bool = False) -> Path: # noqa: ARG002
def load(self, is_product: bool = False) -> NodePath: # noqa: ARG002
"""Load the value."""
return self.path

Expand Down Expand Up @@ -330,11 +330,11 @@ class PickleNode(PPathNode):

"""

path: Path
path: NodePath
name: str = ""
attributes: dict[Any, Any] = field(default_factory=dict)
serializer: Callable[[Any, BufferedWriter], None] = field(default=pickle.dump)
deserializer: Callable[[BufferedReader], Any] = field(default=pickle.load)
serializer: Callable[[Any, BinaryIO], None] = field(default=pickle.dump)
deserializer: Callable[[BinaryIO], Any] = field(default=pickle.load)

@property
def signature(self) -> str:
Expand All @@ -343,7 +343,7 @@ def signature(self) -> str:
return hashlib.sha256(raw_key.encode()).hexdigest()

@classmethod
def from_path(cls, path: Path) -> PickleNode:
def from_path(cls, path: NodePath) -> PickleNode:
"""Instantiate class from path to file."""
if not path.is_absolute():
msg = "Node must be instantiated from absolute path."
Expand Down Expand Up @@ -409,7 +409,7 @@ def collect(self) -> list[Path]:
return list(self.root_dir.glob(self.pattern)) # type: ignore[union-attr]


def get_state_of_path(path: Path) -> str | None:
def get_state_of_path(path: NodePath) -> str | None:
"""Get state of a path.

A simple function to handle local and remote files.
Expand All @@ -436,7 +436,7 @@ def get_state_of_path(path: Path) -> str | None:


@deprecated("Use 'pytask.get_state_of_path' instead.")
def _get_state(path: Path) -> str | None:
def _get_state(path: NodePath) -> str | None:
"""Get state of a path.

A simple function to handle local and remote files.
Expand Down
35 changes: 21 additions & 14 deletions src/_pytask/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
if TYPE_CHECKING:
from collections.abc import Sequence

from _pytask.typing import NodePath

__all__ = [
"find_case_sensitive_path",
"find_closest_ancestor",
Expand Down Expand Up @@ -64,24 +66,27 @@ def relative_to(path: Path, source: Path, *, include_source: bool = True) -> Pat
return Path(source_name, path.relative_to(source))


def is_non_local_path(path: Path) -> bool:
def is_non_local_path(path: NodePath) -> bool:
"""Return whether a path points to a non-local `UPath` resource."""
return isinstance(path, UPath) and path.protocol not in _LOCAL_UPATH_PROTOCOLS


def normalize_local_upath(path: Path) -> Path:
def normalize_local_upath(path: NodePath) -> NodePath:
"""Convert local `UPath` variants to a stdlib `Path`."""
if isinstance(path, UPath) and path.protocol in {"file", "local"}:
local_path = path.path
if (
sys.platform == "win32"
and local_path.startswith("/")
and len(local_path) >= _WINDOWS_DRIVE_PREFIX_LENGTH
and local_path[1].isalpha()
and local_path[2] == ":"
):
local_path = local_path[1:]
return Path(local_path)
if isinstance(path, UPath):
if path.protocol in {"file", "local"}:
local_path = path.path
if (
sys.platform == "win32"
and local_path.startswith("/")
and len(local_path) >= _WINDOWS_DRIVE_PREFIX_LENGTH
and local_path[1].isalpha()
and local_path[2] == ":"
):
local_path = local_path[1:]
return Path(local_path)
if not path.protocol:
return Path(str(path))
return path


Expand Down Expand Up @@ -451,7 +456,7 @@ def _insert_missing_modules(modules: dict[str, ModuleType], module_name: str) ->
module_name = ".".join(module_parts)


def shorten_path(path: Path, paths: Sequence[Path]) -> str:
def shorten_path(path: NodePath, paths: Sequence[NodePath]) -> str:
"""Shorten a path.

The whole path of a node - which includes the drive letter - can be very long
Expand All @@ -466,6 +471,8 @@ def shorten_path(path: Path, paths: Sequence[Path]) -> str:

path = normalize_local_upath(path)
paths = [normalize_local_upath(p) for p in paths]
path = Path(str(path)) if isinstance(path, UPath) else path
paths = [Path(str(p)) if isinstance(p, UPath) else p for p in paths]

ancestor = find_closest_ancestor(path, paths)
if ancestor is None:
Expand Down
11 changes: 9 additions & 2 deletions src/_pytask/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,25 @@
import functools
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from typing import Final
from typing import Literal
from typing import Protocol
from typing import TypeAlias
from typing import runtime_checkable

if TYPE_CHECKING:
from typing import TypeAlias
from upath import UPath

if TYPE_CHECKING:
from _pytask.models import CollectionMetadata
from pytask import PTask


__all__ = [
"NoDefault",
"NodePath",
"Product",
"ProductType",
"TaskFunction",
Expand All @@ -27,6 +30,10 @@
]


NodePath: TypeAlias = Path | UPath
"""A local stdlib path or a universal-pathlib path."""


@runtime_checkable
class TaskFunction(Protocol):
"""Protocol for callables decorated with @task that have pytask_meta attached.
Expand Down
16 changes: 13 additions & 3 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading