Skip to content

Commit 6b43cfd

Browse files
committed
Preserve data catalog node attributes on reload
1 parent 68d6018 commit 6b43cfd

2 files changed

Lines changed: 23 additions & 1 deletion

File tree

src/_pytask/data_catalog.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -101,7 +101,7 @@ def __post_init__(self) -> None:
101101
# Initialize the data catalog with persisted nodes from previous runs.
102102
for path in self.path.glob("*-node.pkl"):
103103
node = pickle.loads(path.read_bytes()) # noqa: S301
104-
node.attributes = {DATA_CATALOG_NAME_FIELD: self.name}
104+
node.attributes[DATA_CATALOG_NAME_FIELD] = self.name
105105
self._entries[node.name] = node
106106

107107
def __getitem__(self, name: str) -> PNode | PProvisionalNode:

tests/test_data_catalog.py

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,14 @@
11
from __future__ import annotations
22

3+
import hashlib
4+
import pickle
35
import sys
46
import textwrap
57
from pathlib import Path
68

79
import pytest
810

11+
from _pytask.data_catalog_utils import DATA_CATALOG_NAME_FIELD
912
from pytask import DataCatalog
1013
from pytask import ExitCode
1114
from pytask import PathNode
@@ -198,6 +201,25 @@ def test_adding_a_python_node():
198201
assert isinstance(data_catalog["node"], PythonNode)
199202

200203

204+
def test_reloading_data_catalog_preserves_node_attributes(tmp_path):
205+
data_catalog = DataCatalog(_instance_path=tmp_path)
206+
_ = data_catalog["node"]
207+
assert data_catalog.path is not None
208+
209+
filename = hashlib.sha256(b"node").hexdigest()
210+
path_to_node = data_catalog.path / f"{filename}-node.pkl"
211+
212+
node = pickle.loads(path_to_node.read_bytes()) # noqa: S301
213+
node.attributes["custom"] = "value"
214+
path_to_node.write_bytes(pickle.dumps(node))
215+
216+
reloaded_data_catalog = DataCatalog(_instance_path=tmp_path)
217+
reloaded_node = reloaded_data_catalog["node"]
218+
219+
assert reloaded_node.attributes["custom"] == "value"
220+
assert reloaded_node.attributes[DATA_CATALOG_NAME_FIELD] == "default"
221+
222+
201223
def test_use_data_catalog_with_provisional_node(runner, tmp_path):
202224
source = """
203225
from pathlib import Path

0 commit comments

Comments
 (0)