Skip to content

Commit

Permalink
Add support for cached_properties to slotted attrs classes.
Browse files Browse the repository at this point in the history
  • Loading branch information
Danny Cooper committed Nov 7, 2023
1 parent 671c53c commit b1ea100
Show file tree
Hide file tree
Showing 2 changed files with 204 additions and 2 deletions.
48 changes: 47 additions & 1 deletion src/attr/_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import contextlib
import copy
import enum
import functools
import inspect
import linecache
import sys
import types
import typing

from _thread import RLock
from operator import itemgetter

# We need to import _compat itself in addition to the _compat members to avoid
Expand Down Expand Up @@ -598,6 +600,28 @@ def _transform_attrs(
return _Attributes((AttrsClass(attrs), base_attrs, base_attr_map))


def _make_cached_property_getattr(cached_properties, original_getattr=None):
lock = RLock() # This is global for the class, can that be avoided?

def __getattr__(instance, item: str):
func = cached_properties.get(item)
if func is not None:
with lock:
try:
# In case another thread has set it.
return object.__getattribute__(instance, item)
except AttributeError:
result = func(instance)
object.__setattr__(instance, item, result)
return result
elif original_getattr is not None:
return original_getattr(instance, item)
else:
raise AttributeError(item)

return __getattr__


def _frozen_setattrs(self, name, value):
"""
Attached to frozen classes as __setattr__.
Expand Down Expand Up @@ -858,9 +882,31 @@ def _create_slots_class(self):
):
names += ("__weakref__",)

cached_properties = {
name: cached_property.func
for name, cached_property in cd.items()
if isinstance(cached_property, functools.cached_property)
}

if cached_properties:
# Add cached properties to names for slotting.
names += tuple(cached_properties.keys())

if "__annotations__" in cd:
for name, func in cached_properties.items():
annotation = inspect.signature(func).return_annotation
if annotation is not inspect.Parameter.empty:
cd["__annotations__"][name] = annotation

cd["__getattr__"] = _make_cached_property_getattr(
cached_properties, cd.get("__getattr__")
)
del cd[name]

# We only add the names of attributes that aren't inherited.
# Setting __slots__ to inherited attributes wastes memory.
slot_names = [name for name in names if name not in base_names]

# There are slots for attributes from current class
# that are defined in parent classes.
# As their descriptors may be overridden by a child class,
Expand All @@ -874,6 +920,7 @@ def _create_slots_class(self):
cd.update(reused_slots)
if self._cache_hash:
slot_names.append(_hash_cache_field)

cd["__slots__"] = tuple(slot_names)

cd["__qualname__"] = self._cls.__qualname__
Expand Down Expand Up @@ -910,7 +957,6 @@ def _create_slots_class(self):
else:
if match:
set_closure_cell(cell, cls)

return cls

def add_repr(self, ns):
Expand Down
158 changes: 157 additions & 1 deletion tests/test_slots.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
Unit tests for slots-related functionality.
"""

import functools
import pickle
import sys
import types
Expand All @@ -14,6 +14,7 @@
import pytest

import attr
import attrs

from attr._compat import PYPY, just_warn, make_set_closure_cell

Expand Down Expand Up @@ -747,6 +748,161 @@ def f(self):
assert B(17).f == 289


def test_slots_cached_property_allows_call():
"""
cached_property in slotted class allows call.
"""

@attr.s(slots=True)
class A:
x = attr.ib()

@functools.cached_property
def f(self):
return self.x

assert A(11).f == 11


def test_slots_cached_property_class_does_not_have__dict__():
"""
slotted class with cached property has no __dict__ attribute.
"""

@attr.s(slots=True)
class A:
x = attr.ib()

@functools.cached_property
def f(self):
return self.x

assert set(A.__slots__) == {"x", "f", "__weakref__"}
assert "__dict__" not in dir(A)


def test_slots_cached_property_infers_type():
"""
Infers type of cached property.
"""

@attrs.frozen(slots=True)
class A:
x: int

@functools.cached_property
def f(self) -> int:
return self.x

assert A.__annotations__ == {"x": int, "f": int}


def test_slots_sub_class_avoids_duplicated_slots():
"""
Duplicating the slots is a wast of memory.
"""

@attr.s(slots=True)
class A:
x = attr.ib()

@functools.cached_property
def f(self):
return self.x

@attr.s(slots=True)
class B(A):
@functools.cached_property
def f(self):
return self.x * 2

assert B(1).f == 2
assert B.__slots__ == ()


def test_slots_sub_class_with_actual_slot():
"""
A sub-class can have an explicit attrs field that replaces a cached property.
"""

@attr.s(slots=True)
class A: # slots : (x, f)
x = attr.ib()

@functools.cached_property
def f(self):
return self.x

@attr.s(slots=True)
class B(A):
f: int = attr.ib()

assert B(1, 2).f == 2
assert B.__slots__ == ()


def test_slots_cached_property_is_not_called_at_construction():
"""
A cached property function should only be called at property access point.
"""
call_count = 0

@attr.s(slots=True)
class A:
x = attr.ib()

@functools.cached_property
def f(self):
nonlocal call_count
call_count += 1
return self.x

A(1)
assert call_count == 0


def test_slots_cached_property_repeat_call_only_once():
"""
A cached property function should be called only once, on repeated attribute access.
"""
call_count = 0

@attr.s(slots=True)
class A:
x = attr.ib()

@functools.cached_property
def f(self):
nonlocal call_count
call_count += 1
return self.x

obj = A(1)
obj.f
obj.f
assert call_count == 1


def test_slots_cached_property_called_independent_across_instances():
"""
A cached property value should be specific to the given instance.
"""

@attr.s(slots=True)
class A:
x = attr.ib()

@functools.cached_property
def f(self):
return self.x

obj_1 = A(1)
obj_2 = A(2)

assert obj_1.f == 1
assert obj_2.f == 2


@attr.s(slots=True)
class A:
x = attr.ib()
Expand Down

0 comments on commit b1ea100

Please sign in to comment.