|
3 | 3 | from dataclasses import dataclass |
4 | 4 | from typing import TYPE_CHECKING |
5 | 5 | from typing import Any |
| 6 | +from typing import cast |
6 | 7 |
|
7 | 8 | if TYPE_CHECKING: |
8 | 9 | from collections.abc import Callable |
| 10 | + from typing import Protocol |
9 | 11 |
|
10 | 12 | try: |
11 | 13 | from coiled.function import Function |
12 | 14 | except ImportError: |
13 | 15 |
|
14 | 16 | @dataclass |
15 | 17 | class Function: |
16 | | - cluster_kwargs: dict[str, Any] |
17 | | - environ: dict[str, Any] |
| 18 | + _cluster_kwargs: dict[str, Any] |
| 19 | + _environ: dict[str, Any] | None |
| 20 | + _local: bool |
| 21 | + _name: str |
18 | 22 | function: Callable[..., Any] | None |
19 | | - keepalive: str | None |
| 23 | + keepalive: Any | None |
20 | 24 |
|
21 | 25 |
|
22 | 26 | __all__ = ["Function"] |
23 | 27 |
|
24 | 28 |
|
25 | | -_MISSING = object() |
| 29 | +if TYPE_CHECKING: |
| 30 | + |
| 31 | + class _CoiledFunctionPrivateAttrs(Protocol): |
| 32 | + _cluster_kwargs: dict[str, Any] |
| 33 | + _environ: dict[str, Any] | None |
| 34 | + _local: bool |
| 35 | + _name: str |
| 36 | + keepalive: Any | None |
26 | 37 |
|
27 | 38 |
|
28 | | -def _get_coiled_attribute(func: Function, *names: str, default: Any = _MISSING) -> Any: |
29 | | - """Get an attribute from coiled function objects with private/public fallbacks.""" |
30 | | - for name in names: |
31 | | - value = getattr(func, name, _MISSING) |
32 | | - if value is not _MISSING: |
33 | | - return value |
34 | | - if default is not _MISSING: |
35 | | - return default |
36 | | - names_as_text = ", ".join(repr(name) for name in names) |
37 | | - msg = f"Cannot find coiled attribute(s) {names_as_text} on {func!r}." |
38 | | - raise AttributeError(msg) |
| 39 | +def _as_coiled_private_attrs(func: Function) -> _CoiledFunctionPrivateAttrs: |
| 40 | + """Cast to the private-attribute layout used by coiled's Function class.""" |
| 41 | + return cast("_CoiledFunctionPrivateAttrs", func) |
39 | 42 |
|
40 | 43 |
|
41 | 44 | def extract_coiled_function_kwargs(func: Function) -> dict[str, Any]: |
42 | 45 | """Extract the kwargs for a coiled function.""" |
| 46 | + coiled_function = _as_coiled_private_attrs(func) |
43 | 47 | return { |
44 | | - "cluster_kwargs": _get_coiled_attribute( |
45 | | - func, "_cluster_kwargs", "cluster_kwargs" |
46 | | - ), |
47 | | - "keepalive": func.keepalive, |
48 | | - "environ": _get_coiled_attribute(func, "_environ", "environ"), |
49 | | - "local": _get_coiled_attribute(func, "_local", "local", default=None), |
50 | | - "name": _get_coiled_attribute(func, "_name", "name", default=None), |
| 48 | + "cluster_kwargs": coiled_function._cluster_kwargs, |
| 49 | + "keepalive": coiled_function.keepalive, |
| 50 | + "environ": coiled_function._environ, |
| 51 | + "local": coiled_function._local, |
| 52 | + "name": coiled_function._name, |
51 | 53 | } |
0 commit comments