Skip to content

Commit

Permalink
Improve zero argument support for super() in dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
Bobronium committed Sep 27, 2024
1 parent 0a3577b commit 2e0ca01
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 13 deletions.
64 changes: 51 additions & 13 deletions Lib/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,11 +1219,6 @@ def _get_slots(cls):


def _update_func_cell_for__class__(f, oldcls, newcls):
# Returns True if we update a cell, else False.
if f is None:
# f will be None in the case of a property where not all of
# fget, fset, and fdel are used. Nothing to do in that case.
return False
try:
idx = f.__code__.co_freevars.index("__class__")
except ValueError:
Expand All @@ -1232,13 +1227,36 @@ def _update_func_cell_for__class__(f, oldcls, newcls):
# Fix the cell to point to the new class, if it's already pointing
# at the old class. I'm not convinced that the "is oldcls" test
# is needed, but other than performance can't hurt.
closure = f.__closure__[idx]
if closure.cell_contents is oldcls:
closure.cell_contents = newcls
cell = f.__closure__[idx]
if cell.cell_contents is oldcls:
cell.cell_contents = newcls
return True
return False


def _find_inner_functions(obj, _seen=None, _depth=0):
if _seen is None:
_seen = set()
if id(obj) in _seen:
return None
_seen.add(id(obj))

_depth += 1
if _depth > 2:
return None

obj = inspect.unwrap(obj)

for attr in dir(obj):
value = getattr(obj, attr, None)
if value is None:
continue
if isinstance(obj, types.FunctionType):
yield obj
return
yield from _find_inner_functions(value, _seen, _depth)


def _add_slots(cls, is_frozen, weakref_slot):
# Need to create a new class, since we can't set __slots__ after a
# class has been created, and the @dataclass decorator is called
Expand Down Expand Up @@ -1297,7 +1315,10 @@ def _add_slots(cls, is_frozen, weakref_slot):
# (the newly created one, which we're returning) and not the
# original class. We can break out of this loop as soon as we
# make an update, since all closures for a class will share a
# given cell.
# given cell. First we try to find a pure function/properties,
# and then fallback to inspecting custom descriptors.

custom_descriptors_to_check = []
for member in newcls.__dict__.values():
# If this is a wrapped function, unwrap it.
member = inspect.unwrap(member)
Expand All @@ -1306,10 +1327,27 @@ def _add_slots(cls, is_frozen, weakref_slot):
if _update_func_cell_for__class__(member, cls, newcls):
break
elif isinstance(member, property):
if (_update_func_cell_for__class__(member.fget, cls, newcls)
or _update_func_cell_for__class__(member.fset, cls, newcls)
or _update_func_cell_for__class__(member.fdel, cls, newcls)):
break
for f in member.fget, member.fset, member.fdel:
if f is None:
continue
# unwrap once more in case function
# was wrapped before it became property
f = inspect.unwrap(f)
if _update_func_cell_for__class__(f, cls, newcls):
break
elif hasattr(member, "__get__") and not inspect.ismemberdescriptor(
member
):
# we don't want to inspect custom descriptors just yet
# there's still a chance we'll encounter a pure function
# or a property
custom_descriptors_to_check.append(member)
else:
# now let's ensure custom descriptors won't be left out
for descriptor in custom_descriptors_to_check:
for f in _find_inner_functions(descriptor):
if _update_func_cell_for__class__(f, cls, newcls):
break

return newcls

Expand Down
41 changes: 41 additions & 0 deletions Lib/test/test_dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5012,6 +5012,47 @@ def foo(self):

A().foo()

def test_wrapped_property(self):
def mydecorator(f):
@wraps(f)
def wrapper(*args, **kwargs):
return f(*args, **kwargs)
return wrapper

class B:
@property
def foo(self):
return "bar"

@dataclass(slots=True)
class A(B):
@property
@mydecorator
def foo(self):
return super().foo

self.assertEqual(A().foo, "bar")

def test_custom_descriptor(self):
class CustomDescriptor:
def __init__(self, f):
self._f = f

def __get__(self, instance, owner):
return self._f(instance)

class B:
def foo(self):
return "bar"

@dataclass(slots=True)
class A(B):
@CustomDescriptor
def foo(cls):
return super().foo()

self.assertEqual(A().foo, "bar")

def test_remembered_class(self):
# Apply the dataclass decorator manually (not when the class
# is created), so that we can keep a reference to the
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Modify dataclasses to enable zero argument support for ``super()`` when ``slots=True`` is
specified and custom descriptor is used or `property` function is wrapped.

0 comments on commit 2e0ca01

Please sign in to comment.