Skip to content

Commit c6f37f2

Browse files
committed
Refactor coiled kwargs extraction adapter
1 parent 1dda71a commit c6f37f2

1 file changed

Lines changed: 24 additions & 22 deletions

File tree

src/_pytask/coiled_utils.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,49 +3,51 @@
33
from dataclasses import dataclass
44
from typing import TYPE_CHECKING
55
from typing import Any
6+
from typing import cast
67

78
if TYPE_CHECKING:
89
from collections.abc import Callable
10+
from typing import Protocol
911

1012
try:
1113
from coiled.function import Function
1214
except ImportError:
1315

1416
@dataclass
1517
class Function:
16-
cluster_kwargs: dict[str, Any]
17-
environ: dict[str, Any]
18+
_cluster_kwargs: dict[str, Any]
19+
_environ: dict[str, Any] | None
20+
_local: bool
21+
_name: str
1822
function: Callable[..., Any] | None
19-
keepalive: str | None
23+
keepalive: Any | None
2024

2125

2226
__all__ = ["Function"]
2327

2428

25-
_MISSING = object()
29+
if TYPE_CHECKING:
30+
31+
class _CoiledFunctionPrivateAttrs(Protocol):
32+
_cluster_kwargs: dict[str, Any]
33+
_environ: dict[str, Any] | None
34+
_local: bool
35+
_name: str
36+
keepalive: Any | None
2637

2738

28-
def _get_coiled_attribute(func: Function, *names: str, default: Any = _MISSING) -> Any:
29-
"""Get an attribute from coiled function objects with private/public fallbacks."""
30-
for name in names:
31-
value = getattr(func, name, _MISSING)
32-
if value is not _MISSING:
33-
return value
34-
if default is not _MISSING:
35-
return default
36-
names_as_text = ", ".join(repr(name) for name in names)
37-
msg = f"Cannot find coiled attribute(s) {names_as_text} on {func!r}."
38-
raise AttributeError(msg)
39+
def _as_coiled_private_attrs(func: Function) -> _CoiledFunctionPrivateAttrs:
40+
"""Cast to the private-attribute layout used by coiled's Function class."""
41+
return cast("_CoiledFunctionPrivateAttrs", func)
3942

4043

4144
def extract_coiled_function_kwargs(func: Function) -> dict[str, Any]:
4245
"""Extract the kwargs for a coiled function."""
46+
coiled_function = _as_coiled_private_attrs(func)
4347
return {
44-
"cluster_kwargs": _get_coiled_attribute(
45-
func, "_cluster_kwargs", "cluster_kwargs"
46-
),
47-
"keepalive": func.keepalive,
48-
"environ": _get_coiled_attribute(func, "_environ", "environ"),
49-
"local": _get_coiled_attribute(func, "_local", "local", default=None),
50-
"name": _get_coiled_attribute(func, "_name", "name", default=None),
48+
"cluster_kwargs": coiled_function._cluster_kwargs,
49+
"keepalive": coiled_function.keepalive,
50+
"environ": coiled_function._environ,
51+
"local": coiled_function._local,
52+
"name": coiled_function._name,
5153
}

0 commit comments

Comments
 (0)