Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-118033: Fix __weakref__ not set for generic dataclasses #118099

Merged
merged 6 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions Lib/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1199,10 +1199,17 @@ def _dataclass_setstate(self, state):

def _get_slots(cls):
match cls.__dict__.get('__slots__'):
# A class which does not define __slots__ at all is equivalent
# to a class defining __slots__ = ('__dict__', '__weakref__')
# `__dictoffset__` and `__weakrefoffset__` can tell us whether
# the base type has dict/weakref slots, in a way that works correctly
# for both Python classes and C extension types. Extension types
# don't use `__slots__` for slot creation
case None:
yield from ('__dict__', '__weakref__')
slots = []
if getattr(cls, '__weakrefoffset__', -1) != 0:
slots.append('__weakref__')
if getattr(cls, '__dictrefoffset__', -1) != 0:
slots.append('__dict__')
yield from slots
case str(slot):
yield slot
sobolevn marked this conversation as resolved.
Show resolved Hide resolved
# Slots may be any iterable, but we cannot handle an iterator
Expand Down
106 changes: 106 additions & 0 deletions Lib/test/test_dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3515,8 +3515,114 @@ class A:
class B(A):
pass

self.assertEqual(B.__slots__, ())
B()

def test_dataclass_derived_generic(self):
T = typing.TypeVar('T')

@dataclass(slots=True, weakref_slot=True)
class A(typing.Generic[T]):
pass
self.assertEqual(A.__slots__, ('__weakref__',))
self.assertTrue(A.__weakref__)
A()

@dataclass(slots=True, weakref_slot=True)
class B[T2]:
pass
self.assertEqual(B.__slots__, ('__weakref__',))
self.assertTrue(B.__weakref__)
B()

def test_dataclass_derived_generic_from_base(self):
T = typing.TypeVar('T')

class RawBase: ...

@dataclass(slots=True, weakref_slot=True)
class C1(typing.Generic[T], RawBase):
pass
self.assertEqual(C1.__slots__, ())
self.assertTrue(C1.__weakref__)
C1()
@dataclass(slots=True, weakref_slot=True)
class C2(RawBase, typing.Generic[T]):
pass
self.assertEqual(C2.__slots__, ())
self.assertTrue(C2.__weakref__)
C2()

@dataclass(slots=True, weakref_slot=True)
class D[T2](RawBase):
pass
self.assertEqual(D.__slots__, ())
self.assertTrue(D.__weakref__)
D()

def test_dataclass_derived_generic_from_slotted_base(self):
T = typing.TypeVar('T')

class WithSlots:
__slots__ = ('a', 'b')

@dataclass(slots=True, weakref_slot=True)
class E1(WithSlots, Generic[T]):
pass
self.assertEqual(E1.__slots__, ('__weakref__',))
self.assertTrue(E1.__weakref__)
E1()
@dataclass(slots=True, weakref_slot=True)
class E2(Generic[T], WithSlots):
pass
self.assertEqual(E2.__slots__, ('__weakref__',))
self.assertTrue(E2.__weakref__)
E2()

@dataclass(slots=True, weakref_slot=True)
class F[T2](WithSlots):
pass
self.assertEqual(F.__slots__, ('__weakref__',))
self.assertTrue(F.__weakref__)
F()

def test_dataclass_derived_generic_from_slotted_base(self):
T = typing.TypeVar('T')

class WithWeakrefSlot:
__slots__ = ('__weakref__',)

@dataclass(slots=True, weakref_slot=True)
class G1(WithWeakrefSlot, Generic[T]):
pass
self.assertEqual(G1.__slots__, ())
self.assertTrue(G1.__weakref__)
G1()
@dataclass(slots=True, weakref_slot=True)
class G2(Generic[T], WithWeakrefSlot):
pass
self.assertEqual(G2.__slots__, ())
self.assertTrue(G2.__weakref__)
G2()

@dataclass(slots=True, weakref_slot=True)
class H[T2](WithWeakrefSlot):
pass
self.assertEqual(H.__slots__, ())
self.assertTrue(H.__weakref__)
H()

def test_dataclass_slot_dict(self):
class WithDictSlot:
__slots__ = ('__dict__',)

@dataclass(slots=True)
class A(WithDictSlot): ...

self.assertEqual(A.__slots__, ())
self.assertEqual(A().__dict__, {})
A()


class TestDescriptors(unittest.TestCase):
def test_set_name(self):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix :func:`dataclasses.dataclass` not creating a ``__weakref__`` slot when
subclassing :class:`typing.Generic`.
Loading