Skip to content

Commit 9dd5018

Browse files
committed
Remove attrs dependency.
1 parent 17e6a34 commit 9dd5018

29 files changed

Lines changed: 159 additions & 150 deletions

docs_src/how_to_guides/the_data_catalog.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1+
from dataclasses import dataclass
12
from pathlib import Path
23
from typing import Any
34

45
import cloudpickle
5-
from attrs import define
66

77

8-
@define
8+
@dataclass
99
class PickleNode:
1010
"""A node for pickle files.
1111

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ classifiers = [
2121
]
2222
dynamic = ["version"]
2323
dependencies = [
24-
"attrs>=21.3.0",
2524
"click>=8.1.8,!=8.2.0",
2625
"click-default-group>=1.2.4",
2726
"networkx>=2.4.0",

src/_pytask/cache.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,15 @@
55
import functools
66
import hashlib
77
import inspect
8+
from dataclasses import dataclass
9+
from dataclasses import field
810
from inspect import FullArgSpec
911
from typing import TYPE_CHECKING
1012
from typing import Any
1113
from typing import ParamSpec
1214
from typing import Protocol
1315
from typing import TypeVar
1416

15-
from attrs import define
16-
from attrs import field
17-
1817
from _pytask._hashlib import hash_value
1918

2019
if TYPE_CHECKING:
@@ -35,17 +34,17 @@ class HasCache(Protocol):
3534
cache: Cache
3635

3736

38-
@define
37+
@dataclass
3938
class CacheInfo:
4039
hits: int = 0
4140
misses: int = 0
4241

4342

44-
@define
43+
@dataclass
4544
class Cache:
46-
_cache: dict[str, Any] = field(factory=dict)
47-
_sentinel: Any = field(factory=object)
48-
cache_info: CacheInfo = field(factory=CacheInfo)
45+
_cache: dict[str, Any] = field(default_factory=dict)
46+
_sentinel: Any = field(default_factory=object)
47+
cache_info: CacheInfo = field(default_factory=CacheInfo)
4948

5049
def memoize(self, func: Callable[P, R]) -> Memoized[P, R]:
5150
func_module = getattr(func, "__module__", "")

src/_pytask/clean.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
import itertools
77
import shutil
88
import sys
9+
from dataclasses import dataclass
910
from typing import TYPE_CHECKING
1011
from typing import Any
1112

1213
import click
13-
from attrs import define
1414

1515
from _pytask.click import ColoredCommand
1616
from _pytask.click import EnumChoice
@@ -243,7 +243,7 @@ def _find_all_unknown_paths(
243243
)
244244

245245

246-
@define(repr=False)
246+
@dataclass(repr=False)
247247
class _RecursivePathNode:
248248
"""A class for a path to a file or directory which recursively instantiates itself.
249249

src/_pytask/coiled_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
11
from __future__ import annotations
22

3+
from dataclasses import dataclass
34
from typing import TYPE_CHECKING
45
from typing import Any
56

6-
from attrs import define
7-
87
if TYPE_CHECKING:
98
from collections.abc import Callable
109

1110
try:
1211
from coiled.function import Function
1312
except ImportError:
1413

15-
@define
14+
@dataclass
1615
class Function:
1716
cluster_kwargs: dict[str, Any]
1817
environ: dict[str, Any]

src/_pytask/collect_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@
33
from __future__ import annotations
44

55
import inspect
6+
from dataclasses import replace
67
from typing import TYPE_CHECKING
78
from typing import Annotated
89
from typing import Any
910
from typing import get_origin
1011

11-
import attrs
12-
1312
from _pytask._inspect import get_annotations
1413
from _pytask.exceptions import NodeNotCollectedError
1514
from _pytask.models import NodeInfo
@@ -308,7 +307,7 @@ def collect_dependency(
308307
# If a node is a dependency and its value is not set, the node is a product in
309308
# another task and the value will be set there. Thus, we wrap the original node
310309
# in another node to retrieve the value after it is set.
311-
new_node = attrs.evolve(node, value=node)
310+
new_node = replace(node, value=node)
312311
node_info = node_info._replace(value=new_node)
313312

314313
collected_node = session.hook.pytask_collect_node(

src/_pytask/dag_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
from __future__ import annotations
44

55
import itertools
6+
from dataclasses import dataclass
7+
from dataclasses import field
68
from typing import TYPE_CHECKING
79

810
import networkx as nx
9-
from attrs import define
10-
from attrs import field
1111

1212
from _pytask.mark_utils import has_mark
1313

@@ -61,7 +61,7 @@ def node_and_neighbors(dag: nx.DiGraph, node: str) -> Iterable[str]:
6161
return itertools.chain(dag.predecessors(node), [node], dag.successors(node))
6262

6363

64-
@define
64+
@dataclass
6565
class TopologicalSorter:
6666
"""The topological sorter class.
6767
@@ -78,9 +78,9 @@ class TopologicalSorter:
7878
"""
7979

8080
dag: nx.DiGraph
81-
priorities: dict[str, int] = field(factory=dict)
82-
_nodes_processing: set[str] = field(factory=set)
83-
_nodes_done: set[str] = field(factory=set)
81+
priorities: dict[str, int] = field(default_factory=dict)
82+
_nodes_processing: set[str] = field(default_factory=set)
83+
_nodes_done: set[str] = field(default_factory=set)
8484

8585
@classmethod
8686
def from_dag(cls, dag: nx.DiGraph) -> TopologicalSorter:

src/_pytask/data_catalog.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from _pytask.exceptions import NodeNotCollectedError
2121
from _pytask.models import NodeInfo
2222
from _pytask.node_protocols import PNode
23-
from _pytask.node_protocols import PPathNode
2423
from _pytask.node_protocols import PProvisionalNode
2524
from _pytask.node_protocols import warn_about_upcoming_attributes_field_on_nodes
2625
from _pytask.nodes import PickleNode
@@ -39,6 +38,14 @@ def _get_parent_path_of_data_catalog_module(stacklevel: int = 2) -> Path:
3938
return Path.cwd()
4039

4140

41+
def _is_path_node_type(node_type: type[Any]) -> bool:
42+
"""Return True if the class looks like a path-based node."""
43+
for cls in node_type.__mro__:
44+
if "path" in getattr(cls, "__annotations__", {}):
45+
return True
46+
return False
47+
48+
4249
@dataclass(kw_only=True)
4350
class DataCatalog:
4451
"""A data catalog.
@@ -115,7 +122,7 @@ def add(self, name: str, node: PNode | PProvisionalNode | Any = None) -> None:
115122

116123
if node is None:
117124
filename = hashlib.sha256(name.encode()).hexdigest()
118-
if isinstance(self.default_node, PPathNode):
125+
if _is_path_node_type(self.default_node):
119126
assert self.path is not None
120127
self._entries[name] = self.default_node(
121128
name=name, path=self.path / f"{filename}.pkl"

src/_pytask/explain.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22

33
from __future__ import annotations
44

5+
from dataclasses import dataclass
6+
from dataclasses import field
57
from typing import TYPE_CHECKING
68
from typing import Any
79
from typing import Literal
810

9-
from attrs import define
10-
from attrs import field
1111
from rich.text import Text
1212

1313
from _pytask.console import console
@@ -34,14 +34,14 @@
3434
]
3535

3636

37-
@define
37+
@dataclass
3838
class ChangeReason:
3939
"""Represents a reason why a node changed."""
4040

4141
node_name: str
4242
node_type: NodeType
4343
reason: ReasonType
44-
details: dict[str, Any] = field(factory=dict)
44+
details: dict[str, Any] = field(default_factory=dict)
4545
verbose: int = 1
4646

4747
def __rich_console__(
@@ -71,11 +71,11 @@ def __rich_console__(
7171
yield Text(f" • {self.node_name}: {self.reason}")
7272

7373

74-
@define
74+
@dataclass
7575
class TaskExplanation:
7676
"""Represents the explanation for why a task needs to be executed."""
7777

78-
reasons: list[ChangeReason] = field(factory=list)
78+
reasons: list[ChangeReason] = field(default_factory=list)
7979
task: PTask | None = None
8080
outcome: TaskOutcome | None = None
8181
verbose: int = 1

src/_pytask/live.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@
33
from __future__ import annotations
44

55
from dataclasses import dataclass
6+
from dataclasses import field
67
from typing import TYPE_CHECKING
78
from typing import Any
89
from typing import NamedTuple
910

1011
import click
11-
from attrs import define
12-
from attrs import field
1312
from rich.box import ROUNDED
1413
from rich.errors import LiveError
1514
from rich.live import Live
@@ -85,7 +84,7 @@ def pytask_execute(session: Session) -> Generator[None, None, None]:
8584
return (yield)
8685

8786

88-
@define(eq=False)
87+
@dataclass(eq=False)
8988
class LiveManager:
9089
"""A class for live displays during a session.
9190
@@ -106,7 +105,9 @@ class LiveManager:
106105
"""
107106

108107
_live: Live = field(
109-
factory=lambda: Live(renderable=None, console=console, auto_refresh=False)
108+
default_factory=lambda: Live(
109+
renderable=None, console=console, auto_refresh=False
110+
)
110111
)
111112

112113
def start(self) -> None:
@@ -157,7 +158,7 @@ class _ReportEntry(NamedTuple):
157158
task: PTask
158159

159160

160-
@define(eq=False, kw_only=True)
161+
@dataclass(eq=False, kw_only=True)
161162
class LiveExecution:
162163
"""A class for managing the table displaying task progress during the execution."""
163164

@@ -168,8 +169,8 @@ class LiveExecution:
168169
initial_status: TaskExecutionStatus = TaskExecutionStatus.RUNNING
169170
sort_final_table: bool = False
170171
n_tasks: int | str = "x"
171-
_reports: list[_ReportEntry] = field(factory=list)
172-
_running_tasks: dict[str, _TaskEntry] = field(factory=dict)
172+
_reports: list[_ReportEntry] = field(default_factory=list)
173+
_running_tasks: dict[str, _TaskEntry] = field(default_factory=dict)
173174

174175
@hookimpl(wrapper=True)
175176
def pytask_execute_build(self) -> Generator[None, None, None]:
@@ -306,7 +307,7 @@ def update_report(self, new_report: ExecutionReport) -> None:
306307
self._update_table()
307308

308309

309-
@define(eq=False, kw_only=True)
310+
@dataclass(eq=False, kw_only=True)
310311
class LiveCollection:
311312
"""A class for managing the live status during the collection."""
312313

0 commit comments

Comments
 (0)