Skip to content

Commit bc912a3

Browse files
committed
Remove more type errors.
1 parent 89edbb4 commit bc912a3

File tree

2 files changed

+23
-8
lines changed

2 files changed

+23
-8
lines changed

src/_pytask/tree_util.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44

55
import functools
66
from pathlib import Path
7+
from typing import TYPE_CHECKING
8+
from typing import Any
9+
from typing import TypeVar
710

811
import optree
9-
from optree import PyTree
1012
from optree import tree_flatten_with_path as _optree_tree_flatten_with_path
1113
from optree import tree_leaves as _optree_tree_leaves
1214
from optree import tree_map as _optree_tree_map
@@ -23,6 +25,19 @@
2325
"tree_structure",
2426
]
2527

28+
_T = TypeVar("_T")
29+
30+
if TYPE_CHECKING:
31+
# Use our own recursive type alias for static type checking.
32+
# optree's PyTree uses __class_getitem__ to generate Union types at runtime,
33+
# but type checkers like ty cannot evaluate these dynamic types properly.
34+
# See: https://github.com/metaopt/optree/issues/251
35+
PyTree = (
36+
_T | tuple["PyTree[_T]", ...] | list["PyTree[_T]"] | dict[Any, "PyTree[_T]"]
37+
)
38+
else:
39+
from optree import PyTree
40+
2641
assert optree.__file__ is not None
2742
TREE_UTIL_LIB_DIRECTORY = Path(optree.__file__).parent
2843

tests/test_collect_command.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def task_example_2(path=Path("in_2.txt"), produces=Path("out_2.txt")): ...
306306

307307
@define
308308
class Node:
309-
path: Path
309+
path: str
310310

311311
def state(self): ...
312312

@@ -321,8 +321,8 @@ def test_print_collected_tasks_without_nodes(capsys):
321321
base_name="function",
322322
path=Path("task_path.py"),
323323
function=function,
324-
depends_on={"depends_on": Node("in.txt")}, # type: ignore[arg-type]
325-
produces={"produces": Node("out.txt")}, # type: ignore[arg-type]
324+
depends_on={"depends_on": Node("in.txt")},
325+
produces={"produces": Node("out.txt")},
326326
)
327327
]
328328
}
@@ -343,8 +343,8 @@ def test_print_collected_tasks_with_nodes(capsys):
343343
base_name="function",
344344
path=Path("task_path.py"),
345345
function=function,
346-
depends_on={"depends_on": PathNode(name="in.txt", path=Path("in.txt"))}, # type: ignore[arg-type]
347-
produces={"produces": PathNode(name="out.txt", path=Path("out.txt"))}, # type: ignore[arg-type]
346+
depends_on={"depends_on": PathNode(name="in.txt", path=Path("in.txt"))},
347+
produces={"produces": PathNode(name="out.txt", path=Path("out.txt"))},
348348
)
349349
]
350350
}
@@ -366,10 +366,10 @@ def test_find_common_ancestor_of_all_nodes(show_nodes, expected_add):
366366
base_name="function",
367367
path=Path.cwd() / "src" / "task_path.py",
368368
function=function,
369-
depends_on={ # type: ignore[arg-type]
369+
depends_on={
370370
"depends_on": PathNode.from_path(Path.cwd() / "src" / "in.txt")
371371
},
372-
produces={ # type: ignore[arg-type]
372+
produces={
373373
"produces": PathNode.from_path(
374374
Path.cwd().joinpath("..", "bld", "out.txt").resolve()
375375
)

0 commit comments

Comments
 (0)