diff --git a/src/attr/_make.py b/src/attr/_make.py index 553e9e559..a6f8d49c1 100644 --- a/src/attr/_make.py +++ b/src/attr/_make.py @@ -3,12 +3,14 @@ import contextlib import copy import enum +import functools import inspect import linecache import sys import types import typing +from _thread import RLock from operator import itemgetter # We need to import _compat itself in addition to the _compat members to avoid @@ -598,6 +600,28 @@ def _transform_attrs( return _Attributes((AttrsClass(attrs), base_attrs, base_attr_map)) +def _make_cached_property_getattr(cached_properties, original_getattr=None): + lock = RLock() # This is global for the class, can that be avoided? + + def __getattr__(instance, item: str): + func = cached_properties.get(item) + if func is not None: + with lock: + try: + # In case another thread has set it. + return object.__getattribute__(instance, item) + except AttributeError: + result = func(instance) + object.__setattr__(instance, item, result) + return result + elif original_getattr is not None: + return original_getattr(instance, item) + else: + raise AttributeError(item) + + return __getattr__ + + def _frozen_setattrs(self, name, value): """ Attached to frozen classes as __setattr__. @@ -858,9 +882,31 @@ def _create_slots_class(self): ): names += ("__weakref__",) + cached_properties = { + name: cached_property.func + for name, cached_property in cd.items() + if isinstance(cached_property, functools.cached_property) + } + + if cached_properties: + # Add cached properties to names for slotting. + names += tuple(cached_properties.keys()) + + if "__annotations__" in cd: + for name, func in cached_properties.items(): + annotation = inspect.signature(func).return_annotation + if annotation is not inspect.Parameter.empty: + cd["__annotations__"][name] = annotation + + cd["__getattr__"] = _make_cached_property_getattr( + cached_properties, cd.get("__getattr__") + ) + del cd[name] + # We only add the names of attributes that aren't inherited. # Setting __slots__ to inherited attributes wastes memory. slot_names = [name for name in names if name not in base_names] + # There are slots for attributes from current class # that are defined in parent classes. # As their descriptors may be overridden by a child class, @@ -874,6 +920,7 @@ def _create_slots_class(self): cd.update(reused_slots) if self._cache_hash: slot_names.append(_hash_cache_field) + cd["__slots__"] = tuple(slot_names) cd["__qualname__"] = self._cls.__qualname__ @@ -910,7 +957,6 @@ def _create_slots_class(self): else: if match: set_closure_cell(cell, cls) - return cls def add_repr(self, ns): diff --git a/tests/test_slots.py b/tests/test_slots.py index f26fdb9de..9fca72478 100644 --- a/tests/test_slots.py +++ b/tests/test_slots.py @@ -3,7 +3,7 @@ """ Unit tests for slots-related functionality. """ - +import functools import pickle import sys import types @@ -14,6 +14,7 @@ import pytest import attr +import attrs from attr._compat import PYPY, just_warn, make_set_closure_cell @@ -747,6 +748,161 @@ def f(self): assert B(17).f == 289 +def test_slots_cached_property_allows_call(): + """ + cached_property in slotted class allows call. + """ + + @attr.s(slots=True) + class A: + x = attr.ib() + + @functools.cached_property + def f(self): + return self.x + + assert A(11).f == 11 + + +def test_slots_cached_property_class_does_not_have__dict__(): + """ + slotted class with cached property has no __dict__ attribute. + """ + + @attr.s(slots=True) + class A: + x = attr.ib() + + @functools.cached_property + def f(self): + return self.x + + assert set(A.__slots__) == {"x", "f", "__weakref__"} + assert "__dict__" not in dir(A) + + +def test_slots_cached_property_infers_type(): + """ + Infers type of cached property. + """ + + @attrs.frozen(slots=True) + class A: + x: int + + @functools.cached_property + def f(self) -> int: + return self.x + + assert A.__annotations__ == {"x": int, "f": int} + + +def test_slots_sub_class_avoids_duplicated_slots(): + """ + Duplicating the slots is a wast of memory. + """ + + @attr.s(slots=True) + class A: + x = attr.ib() + + @functools.cached_property + def f(self): + return self.x + + @attr.s(slots=True) + class B(A): + @functools.cached_property + def f(self): + return self.x * 2 + + assert B(1).f == 2 + assert B.__slots__ == () + + +def test_slots_sub_class_with_actual_slot(): + """ + A sub-class can have an explicit attrs field that replaces a cached property. + """ + + @attr.s(slots=True) + class A: # slots : (x, f) + x = attr.ib() + + @functools.cached_property + def f(self): + return self.x + + @attr.s(slots=True) + class B(A): + f: int = attr.ib() + + assert B(1, 2).f == 2 + assert B.__slots__ == () + + +def test_slots_cached_property_is_not_called_at_construction(): + """ + A cached property function should only be called at property access point. + """ + call_count = 0 + + @attr.s(slots=True) + class A: + x = attr.ib() + + @functools.cached_property + def f(self): + nonlocal call_count + call_count += 1 + return self.x + + A(1) + assert call_count == 0 + + +def test_slots_cached_property_repeat_call_only_once(): + """ + A cached property function should be called only once, on repeated attribute access. + """ + call_count = 0 + + @attr.s(slots=True) + class A: + x = attr.ib() + + @functools.cached_property + def f(self): + nonlocal call_count + call_count += 1 + return self.x + + obj = A(1) + obj.f + obj.f + assert call_count == 1 + + +def test_slots_cached_property_called_independent_across_instances(): + """ + A cached property value should be specific to the given instance. + """ + + @attr.s(slots=True) + class A: + x = attr.ib() + + @functools.cached_property + def f(self): + return self.x + + obj_1 = A(1) + obj_2 = A(2) + + assert obj_1.f == 1 + assert obj_2.f == 2 + + @attr.s(slots=True) class A: x = attr.ib()