Skip to content

Commit

Permalink
pythongh-124176: create_autospec must not change how dataclass defa…
Browse files Browse the repository at this point in the history
…ults are mocked
  • Loading branch information
sobolevn committed Sep 28, 2024
1 parent 165ed68 commit bd73cda
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
15 changes: 15 additions & 0 deletions Lib/test/test_unittest/testmock/testhelpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,21 @@ class WithNonFields:
with self.assertRaisesRegex(AttributeError, msg):
mock.b

def test_dataclass_with_wider_default(self):
# If field defines an actual default, we don't need to change
# the default type. Since this is how it used to work before #124176
@dataclass
class WithWiderDefault:
narrow_default: int | None = field(default=30)

for mock in [
create_autospec(WithWiderDefault, instance=True),
create_autospec(WithWiderDefault()),
]:
with self.subTest(mock=mock):
self.assertIs(mock.narrow_default.__class__, int)


class TestCallList(unittest.TestCase):

def test_args_list_contains_call_list(self):
Expand Down
8 changes: 5 additions & 3 deletions Lib/unittest/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2758,13 +2758,15 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
f'[object={spec!r}]')
is_async_func = _is_async_func(spec)

entries = [(entry, _missing) for entry in dir(spec)]
base_entries = {entry: _missing for entry in dir(spec)}
if is_type and instance and is_dataclass(spec):
dataclass_fields = fields(spec)
entries.extend((f.name, f.type) for f in dataclass_fields)
entries = {f.name: f.type for f in dataclass_fields}
entries.update(base_entries)
_kwargs = {'spec': [f.name for f in dataclass_fields]}
else:
_kwargs = {'spec': spec}
entries = base_entries

if spec_set:
_kwargs = {'spec_set': spec}
Expand Down Expand Up @@ -2822,7 +2824,7 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None,
_name='()', _parent=mock,
wraps=wrapped)

for entry, original in entries:
for entry, original in entries.items():
if _is_magic(entry):
# MagicMock already does the useful magic methods for us
continue
Expand Down

0 comments on commit bd73cda

Please sign in to comment.