-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy path test_node_protocols.py
More file actions
122 lines (95 loc) · 3.32 KB
/
test_node_protocols.py
File metadata and controls
122 lines (95 loc) · 3.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from __future__ import annotations
import pickle
import textwrap
from pytask import ExitCode
from pytask import cli
def test_node_protocol_for_custom_nodes(runner, tmp_path):
    """Check that a custom dataclass node can serve as a task input.

    The generated task module defines ``CustomNode`` with ``name``, ``value``,
    ``signature`` and ``attributes`` fields plus ``state``/``load``/``save``
    methods. ``load`` returns the stored value, which the task writes to
    ``out.txt``.

    ``runner`` is assumed to be a CLI-runner fixture from the project's conftest
    (``runner.invoke`` returns an object with ``exit_code`` and ``output``).
    """
    # The embedded module is indented to match this function body;
    # textwrap.dedent strips that common margin before it is written to disk.
    source = """
    from typing import Annotated
    from typing import Any
    from pytask import Product
    from dataclasses import dataclass
    from dataclasses import field
    from pathlib import Path

    @dataclass
    class CustomNode:
        name: str
        value: str
        signature: str = "id"
        attributes: dict[Any, Any] = field(default_factory=dict)

        def state(self):
            return self.value

        def load(self, is_product):
            return self.value

        def save(self, value):
            self.value = value

    def task_example(
        data = CustomNode("custom", "text"),
        out: Annotated[Path, Product] = Path("out.txt"),
    ) -> None:
        out.write_text(data)
    """
    tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))

    result = runner.invoke(cli, [tmp_path.as_posix()])

    # Include the CLI output so a failure shows why pytask did not succeed.
    assert result.exit_code == ExitCode.OK, result.output
    assert tmp_path.joinpath("out.txt").read_text() == "text"
    # A node that defines every expected field should not trigger a FutureWarning.
    assert "FutureWarning" not in result.output
def test_node_protocol_for_custom_nodes_with_paths(runner, tmp_path):
    """Check that a custom node backed by a pickle file can serve as a task input.

    In the generated task module, ``PickleFile.state`` reports the file's
    modification time as a string. ``load`` unpickles the file, and the task
    writes the loaded value to ``out.txt``.

    ``runner`` is assumed to be a CLI-runner fixture from the project's conftest.
    """
    # Indented to match this function body; textwrap.dedent removes the margin.
    source = """
    from typing import Annotated
    from typing import Any
    from pytask import Product
    from pathlib import Path
    from dataclasses import dataclass
    from dataclasses import field
    import pickle

    @dataclass
    class PickleFile:
        name: str
        path: Path
        value: Path
        signature: str = "id"
        attributes: dict[Any, Any] = field(default_factory=dict)

        def state(self):
            return str(self.path.stat().st_mtime)

        def load(self, is_product):
            with self.path.open("rb") as f:
                out = pickle.load(f)
            return out

        def save(self, value):
            with self.path.open("wb") as f:
                pickle.dump(value, f)

    _PATH = Path(__file__).parent.joinpath("in.pkl")

    def task_example(
        data = PickleFile(_PATH.as_posix(), _PATH, _PATH),
        out: Annotated[Path, Product] = Path("out.txt"),
    ) -> None:
        out.write_text(data)
    """
    tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
    # Input consumed by PickleFile.load. The test writes it itself, so
    # unpickling it inside the task is safe.
    tmp_path.joinpath("in.pkl").write_bytes(pickle.dumps("text"))

    result = runner.invoke(cli, [tmp_path.as_posix()])

    # Include the CLI output so a failure shows why pytask did not succeed.
    assert result.exit_code == ExitCode.OK, result.output
    assert tmp_path.joinpath("out.txt").read_text() == "text"
def test_node_protocol_for_custom_nodes_requires_attributes(runner, tmp_path):
    """Check that a custom node without an ``attributes`` field is rejected.

    The generated ``CustomNode`` omits ``attributes`` and is attached to the
    task's return annotation. pytask is expected to fail collection and report
    that the node does not follow the ``pytask.PNode`` protocol.

    ``runner`` is assumed to be a CLI-runner fixture from the project's conftest.
    """
    # Indented to match this function body; textwrap.dedent removes the margin.
    source = """
    from typing import Annotated
    from dataclasses import dataclass

    @dataclass
    class CustomNode:
        name: str
        value: str
        signature: str = "id"

        def state(self):
            return self.value

        def load(self, is_product):
            return self.value

        def save(self, value):
            self.value = value

    def task_example() -> Annotated[str, CustomNode("custom", "text")]:
        return "text"
    """
    tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))

    result = runner.invoke(cli, [tmp_path.as_posix()])

    # Include the CLI output so an unexpected exit code is diagnosable.
    assert result.exit_code == ExitCode.COLLECTION_FAILED, result.output
    assert "does not follow the 'pytask.PNode' protocol" in result.output