diff --git a/src/attr/_make.py b/src/attr/_make.py index 793bfd89d..10da0cc0f 100644 --- a/src/attr/_make.py +++ b/src/attr/_make.py @@ -556,6 +556,60 @@ def _make_cached_property_getattr(cached_properties, original_getattr, cls): )["__getattr__"] +def _closure_cells(item, seen): + """ + Yield closure cells for *item* and any wrapped functions it closes over. + """ + item_id = id(item) + if item_id in seen: + return + seen.add(item_id) + + if isinstance(item, (classmethod, staticmethod)): + item = item.__func__ + elif isinstance(item, property): + for accessor in (item.fget, item.fset, item.fdel): + if accessor is not None: + yield from _closure_cells(accessor, seen) + return + elif isinstance(item, cached_property): + item = item.func + elif isinstance(item, types.MethodType): + item = item.__func__ + elif not isinstance(item, types.FunctionType): + wrapped = getattr(item, "__wrapped__", None) + if wrapped is not None: + yield from _closure_cells(wrapped, seen) + return + + closure = item.__closure__ + if closure: + yield from closure + + for cell in closure: + try: + cell_contents = cell.cell_contents + except ValueError: + continue + + if isinstance( + cell_contents, + ( + types.FunctionType, + types.MethodType, + classmethod, + staticmethod, + property, + cached_property, + ), + ): + yield from _closure_cells(cell_contents, seen) + + wrapped = getattr(item, "__wrapped__", None) + if wrapped is not None: + yield from _closure_cells(wrapped, seen) + + def _frozen_setattrs(self, name, value): """ Attached to frozen classes as __setattr__. @@ -973,20 +1027,7 @@ def _create_slots_class(self): for item in itertools.chain( cls.__dict__.values(), additional_closure_functions_to_update ): - if isinstance(item, (classmethod, staticmethod)): - # Class- and staticmethods hide their functions inside. - # These might need to be rewritten as well. - closure_cells = getattr(item.__func__, "__closure__", None) - elif isinstance(item, property): - # Workaround for property `super()` shortcut (PY3-only). - # There is no universal way for other descriptors. - closure_cells = getattr(item.fget, "__closure__", None) - else: - closure_cells = getattr(item, "__closure__", None) - - if not closure_cells: # Catch None or the empty list. - continue - for cell in closure_cells: + for cell in _closure_cells(item, set()): try: match = cell.cell_contents is self._cls except ValueError: # noqa: PERF203 diff --git a/tests/test_slots.py b/tests/test_slots.py index a74c32b03..6e9ff462f 100644 --- a/tests/test_slots.py +++ b/tests/test_slots.py @@ -6,6 +6,7 @@ import functools import pickle +import types import weakref from unittest import mock @@ -16,6 +17,7 @@ import attrs from attr._compat import PY_3_14_PLUS, PYPY +from attr._make import _closure_cells # Pympler doesn't work on PyPy. @@ -426,7 +428,64 @@ class C2(C1Bare): assert {"x": 1, "y": 2, "z": "test"} == attr.asdict(c2) +def _function_closing_over(value): + def function(): + return value + + return function + + class TestClosureCellRewriting: + def test_closure_cells_handles_wrapped_callables(self): + """ + Closure cells are found inside wrapped function shapes. + """ + + marker = object() + function = _function_closing_over(marker) + wrapped_function = _function_closing_over(marker) + + def wrapper(): + pass + + wrapper.__wrapped__ = wrapped_function + cached_function = functools.lru_cache()(_function_closing_over(marker)) + + for item in ( + functools.cached_property(function), + types.MethodType(function, object()), + wrapper, + cached_function, + ): + assert any( + cell.cell_contents is marker + for cell in _closure_cells(item, set()) + ) + + def test_closure_cells_does_not_inspect_arbitrary_cell_contents(self): + """ + Closure cell discovery doesn't introspect unrelated user objects. + """ + + class Unrelated: + def __getattr__(self, name): + msg = f"unexpected attribute lookup: {name}" + raise AssertionError(msg) + + list(_closure_cells(_function_closing_over(Unrelated()), set())) + + def test_closure_cells_stops_on_wrapped_cycles(self): + """ + Closure cell discovery avoids revisiting wrapped functions. + """ + + def wrapper(): + pass + + wrapper.__wrapped__ = wrapper + + assert list(_closure_cells(wrapper, set())) == [] + def test_closure_cell_rewriting(self): """ Slotted classes support proper closure cell rewriting. @@ -497,6 +556,30 @@ def statmethod(): assert D.statmethod() is D + def test_decorated_method(self, slots): + """ + Slotted classes rewrite closure cells in decorated methods. + """ + + def decorated(method): + def wrapped(self, *args, **kwargs): + return method(self, *args, **kwargs) + + return wrapped + + @attr.s(slots=slots) + class A: + def method(self): + return "A" + + @attr.s(slots=slots) + class B(A): + @decorated + def method(self): + return super().method() + + assert B().method() == "A" + @pytest.mark.skipif(PYPY, reason="__slots__ only block weakref on CPython") def test_not_weakrefable():