Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ releases are available on [conda-forge](https://anaconda.org/conda-forge/dags).

## 0.5.0

- :gh:`77` Fix `decorator_rename_arguments` by calling `get_free_arguments` inside the
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Messed up in multiple ways; fixed in #79

decorator (:ghuser:`hmgaudecker`).

- :gh:`76` Add a couple of tests to bring coverage to 100% (:ghuser:`hmgaudecker`).

- :gh:`75` Streamline public API (:ghuser:`hmgaudecker`).
Expand Down
22 changes: 9 additions & 13 deletions src/dags/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import functools
import inspect
from collections.abc import Callable, Mapping, Sequence
from typing import Any, cast, overload
from typing import Any, overload

from dags.annotations import get_annotations
from dags.annotations import get_annotations, get_free_arguments
from dags.exceptions import DagsError, InvalidFunctionArgumentsError
from dags.typing import P, R

Expand Down Expand Up @@ -230,7 +230,12 @@ def rename_arguments( # noqa: C901

def decorator_rename_arguments(func: Callable[P, R]) -> Callable[..., R]:
old_signature = inspect.signature(func)
old_parameters: dict[str, inspect.Parameter] = dict(old_signature.parameters)
free_arguments = set(get_free_arguments(func))
old_parameters: dict[str, inspect.Parameter] = {
name: param
for name, param in old_signature.parameters.items()
if name in free_arguments
}
old_annotations = get_annotations(func)

parameters: list[inspect.Parameter] = []
Expand Down Expand Up @@ -271,16 +276,7 @@ def wrapper_rename_arguments(*args: P.args, **kwargs: P.kwargs) -> R:
wrapper_rename_arguments.__signature__ = signature # ty: ignore[unresolved-attribute]
wrapper_rename_arguments.__annotations__ = annotations

# Preserve function type
if isinstance(func, functools.partial):
partial_wrapper = functools.partial(
wrapper_rename_arguments, *func.args, **func.keywords
)
out = cast("Callable[P, R]", partial_wrapper)
else:
out = wrapper_rename_arguments

return out
return wrapper_rename_arguments

if func is not None:
return decorator_rename_arguments(func)
Expand Down
55 changes: 55 additions & 0 deletions tests/test_signature.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Required because tests assert that annotations are strings.
from __future__ import annotations

import functools
import inspect

import pytest
Expand Down Expand Up @@ -245,3 +246,57 @@ def test_with_signature_invalid_args_type_int() -> None:
@with_signature(args=42) # type: ignore[arg-type]
def f(*args, **kwargs):
pass


def test_rename_arguments_partial_with_positional_bound_arg() -> None:
def f(a, b):
return a + b

p = functools.partial(f, 1)
renamed = rename_arguments(p, mapper={"b": "x"})
assert renamed(x=2) == 3


def test_rename_arguments_partial_with_keyword_bound_arg() -> None:
def f(a, b):
return a * b

p = functools.partial(f, b=10)
renamed = rename_arguments(p, mapper={"a": "x"})
assert renamed(x=5) == 50


def test_rename_arguments_partial_positional_call() -> None:
def f(a, b):
return f"{a}-{b}"

p = functools.partial(f, "hello")
renamed = rename_arguments(p, mapper={"b": "x"})
assert renamed("world") == "hello-world"


def test_rename_arguments_partial_bound_arg_not_in_mapper() -> None:
def f(a, b, c):
return a + b + c

p = functools.partial(f, 100)
renamed = rename_arguments(p, mapper={"b": "x"})
assert renamed(x=20, c=3) == 123


def test_rename_arguments_partial_multiple_bound_args() -> None:
def f(a, b, c):
return a + b + c

p = functools.partial(f, 10, c=30)
renamed = rename_arguments(p, mapper={"b": "y"})
assert renamed(y=20) == 60


def test_rename_arguments_decorator_on_partial() -> None:
def f(a, b):
return a + b

p = functools.partial(f, 1)
renamed = rename_arguments(mapper={"b": "x"})(p)
assert renamed(x=2) == 3
Loading