diff --git a/CHANGES.md b/CHANGES.md index 674ef1f..712d156 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -6,6 +6,9 @@ releases are available on [conda-forge](https://anaconda.org/conda-forge/dags). ## 0.5.0 +- :gh:`77` Fix `decorator_rename_arguments` by calling `get_free_arguments` inside the + decorator (:ghuser:`hmgaudecker`). + - :gh:`76` Add a couple of tests to bring coverage to 100% (:ghuser:`hmgaudecker`). - :gh:`75` Streamline public API (:ghuser:`hmgaudecker`). diff --git a/src/dags/signature.py b/src/dags/signature.py index 0392968..1f46ac6 100644 --- a/src/dags/signature.py +++ b/src/dags/signature.py @@ -3,9 +3,9 @@ import functools import inspect from collections.abc import Callable, Mapping, Sequence -from typing import Any, cast, overload +from typing import Any, overload -from dags.annotations import get_annotations +from dags.annotations import get_annotations, get_free_arguments from dags.exceptions import DagsError, InvalidFunctionArgumentsError from dags.typing import P, R @@ -230,7 +230,12 @@ def rename_arguments( # noqa: C901 def decorator_rename_arguments(func: Callable[P, R]) -> Callable[..., R]: old_signature = inspect.signature(func) - old_parameters: dict[str, inspect.Parameter] = dict(old_signature.parameters) + free_arguments = set(get_free_arguments(func)) + old_parameters: dict[str, inspect.Parameter] = { + name: param + for name, param in old_signature.parameters.items() + if name in free_arguments + } old_annotations = get_annotations(func) parameters: list[inspect.Parameter] = [] @@ -271,16 +276,7 @@ def wrapper_rename_arguments(*args: P.args, **kwargs: P.kwargs) -> R: wrapper_rename_arguments.__signature__ = signature # ty: ignore[unresolved-attribute] wrapper_rename_arguments.__annotations__ = annotations - # Preserve function type - if isinstance(func, functools.partial): - partial_wrapper = functools.partial( - wrapper_rename_arguments, *func.args, **func.keywords - ) - out = cast("Callable[P, R]", partial_wrapper) - else: - out = wrapper_rename_arguments - - return out + return wrapper_rename_arguments if func is not None: return decorator_rename_arguments(func) diff --git a/tests/test_signature.py b/tests/test_signature.py index f9d8684..2822cf9 100644 --- a/tests/test_signature.py +++ b/tests/test_signature.py @@ -1,6 +1,7 @@ # Required because tests assert that annotations are strings. from __future__ import annotations +import functools import inspect import pytest @@ -245,3 +246,57 @@ def test_with_signature_invalid_args_type_int() -> None: @with_signature(args=42) # type: ignore[arg-type] def f(*args, **kwargs): pass + + +def test_rename_arguments_partial_with_positional_bound_arg() -> None: + def f(a, b): + return a + b + + p = functools.partial(f, 1) + renamed = rename_arguments(p, mapper={"b": "x"}) + assert renamed(x=2) == 3 + + +def test_rename_arguments_partial_with_keyword_bound_arg() -> None: + def f(a, b): + return a * b + + p = functools.partial(f, b=10) + renamed = rename_arguments(p, mapper={"a": "x"}) + assert renamed(x=5) == 50 + + +def test_rename_arguments_partial_positional_call() -> None: + def f(a, b): + return f"{a}-{b}" + + p = functools.partial(f, "hello") + renamed = rename_arguments(p, mapper={"b": "x"}) + assert renamed("world") == "hello-world" + + +def test_rename_arguments_partial_bound_arg_not_in_mapper() -> None: + def f(a, b, c): + return a + b + c + + p = functools.partial(f, 100) + renamed = rename_arguments(p, mapper={"b": "x"}) + assert renamed(x=20, c=3) == 123 + + +def test_rename_arguments_partial_multiple_bound_args() -> None: + def f(a, b, c): + return a + b + c + + p = functools.partial(f, 10, c=30) + renamed = rename_arguments(p, mapper={"b": "y"}) + assert renamed(y=20) == 60 + + +def test_rename_arguments_decorator_on_partial() -> None: + def f(a, b): + return a + b + + p = functools.partial(f, 1) + renamed = rename_arguments(mapper={"b": "x"})(p) + assert renamed(x=2) == 3