Skip to content

Commit cac27fe

Browse files
committed
Fix lockfile validation scope
1 parent 66736ae commit cac27fe

2 files changed

Lines changed: 79 additions & 14 deletions

File tree

src/_pytask/lockfile.py

Lines changed: 14 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,14 @@ class _JournalEntry(msgspec.Struct):
5959
produces: dict[str, str] = msgspec.field(default_factory=dict)
6060

6161

62+
def _should_initialize_lockfile_state(command: str | None) -> bool:
63+
return command in (None, "build")
64+
65+
66+
def _should_validate_lockfile_ids(command: str | None) -> bool:
67+
return command in (None, "build", "collect")
68+
69+
6270
def _encode_node_path(path: tuple[str | int, ...]) -> str:
6371
return msgspec.json.encode(path).decode()
6472

@@ -276,16 +284,6 @@ def _raise_error_if_lockfile_ids_are_ambiguous(tasks: list[PTask], root: Path) -
276284
)
277285

278286
for kind, node in chain(dependencies, products):
279-
try:
280-
node_state = node.state()
281-
except Exception: # noqa: BLE001
282-
# Preserve existing behavior where state() errors are raised during
283-
# task execution, not during collection-time lockfile validation.
284-
node_state = None
285-
286-
if node_state is None:
287-
continue
288-
289287
node_id = build_portable_node_id(node, root)
290288
current = (node.signature, kind, node.name)
291289
previous = seen.get(node_id)
@@ -306,9 +304,8 @@ def _raise_error_if_lockfile_ids_are_ambiguous(tasks: list[PTask], root: Path) -
306304

307305
if errors:
308306
msg = (
309-
"Ambiguous lockfile ids detected. Each dependency/product that contributes "
310-
"state must map to a unique lockfile id within a task.\n\n"
311-
+ "\n".join(errors)
307+
"Ambiguous lockfile ids detected. Each dependency/product must map to a "
308+
"unique lockfile id within a task.\n\n" + "\n".join(errors)
312309
)
313310
raise ValueError(msg)
314311

@@ -423,6 +420,8 @@ def flush(self) -> None:
423420
@hookimpl
424421
def pytask_post_parse(config: dict[str, Any]) -> None:
425422
"""Initialize the lockfile state."""
423+
if not _should_initialize_lockfile_state(config.get("command")):
424+
return
426425
path = config["root"] / "pytask.lock"
427426
config["lockfile_path"] = path
428427
config["lockfile_state"] = LockfileState.from_path(path, config["root"])
@@ -431,6 +430,8 @@ def pytask_post_parse(config: dict[str, Any]) -> None:
431430
@hookimpl(trylast=True)
432431
def pytask_collect_modify_tasks(session: Session, tasks: list[PTask]) -> None:
433432
"""Validate that lockfile ids are unambiguous for collected tasks."""
433+
if not _should_validate_lockfile_ids(session.config.get("command")):
434+
return
434435
_raise_error_if_lockfile_ids_are_ambiguous(tasks, session.config["root"])
435436

436437

tests/test_lockfile.py

Lines changed: 65 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -100,7 +100,14 @@ def test_python_node_id_is_collision_free(tmp_path):
100100
assert left_id != right_id
101101

102102

103-
def test_collection_fails_for_ambiguous_lockfile_ids(runner, tmp_path):
103+
@pytest.mark.parametrize(
104+
"args",
105+
[
106+
pytest.param(lambda path: [path.as_posix()], id="build"),
107+
pytest.param(lambda path: ["collect", path.as_posix()], id="collect"),
108+
],
109+
)
110+
def test_collection_fails_for_ambiguous_lockfile_ids(runner, tmp_path, args):
104111
source = """
105112
from dataclasses import dataclass, field
106113
from pathlib import Path
@@ -131,6 +138,63 @@ def task_example(
131138
"""
132139
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
133140

141+
result = runner.invoke(cli, args(tmp_path))
142+
143+
assert result.exit_code == ExitCode.COLLECTION_FAILED
144+
assert "Ambiguous lockfile ids detected" in result.output
145+
assert "lockfile id 'dup'" in result.output
146+
147+
148+
def test_markers_command_ignores_invalid_lockfile(runner, tmp_path):
149+
tmp_path.joinpath("pytask.lock").write_text("{not toml")
150+
151+
result = runner.invoke(cli, ["markers", tmp_path.as_posix()])
152+
153+
assert result.exit_code == ExitCode.OK
154+
assert "persist" in result.output
155+
156+
157+
def test_collection_fails_for_ambiguous_lockfile_ids_with_missing_product_state(
158+
runner, tmp_path
159+
):
160+
source = """
161+
from dataclasses import dataclass, field
162+
from pathlib import Path
163+
from typing import Annotated, Any
164+
165+
from pytask import Product
166+
167+
@dataclass
168+
class CustomNode:
169+
name: str
170+
filepath: Path
171+
signature: str
172+
attributes: dict[Any, Any] = field(default_factory=dict)
173+
174+
def state(self):
175+
if not self.filepath.exists():
176+
return None
177+
return self.filepath.read_text()
178+
179+
def load(self, is_product=False):
180+
return self if is_product else self.filepath.read_text()
181+
182+
def save(self, value):
183+
self.filepath.write_text(value)
184+
185+
def task_example(
186+
dependency=CustomNode(
187+
name="dup", filepath=Path("in.txt"), signature="signature-a"
188+
),
189+
product: Annotated[CustomNode, Product] = CustomNode(
190+
name="dup", filepath=Path("out.txt"), signature="signature-b"
191+
),
192+
):
193+
product.save(dependency.upper())
194+
"""
195+
tmp_path.joinpath("task_module.py").write_text(textwrap.dedent(source))
196+
tmp_path.joinpath("in.txt").write_text("hello")
197+
134198
result = runner.invoke(cli, [tmp_path.as_posix()])
135199

136200
assert result.exit_code == ExitCode.COLLECTION_FAILED

0 commit comments

Comments
 (0)