diff --git a/mypy/message_registry.py b/mypy/message_registry.py index 0f14d706ccca..eef1b15f18ba 100644 --- a/mypy/message_registry.py +++ b/mypy/message_registry.py @@ -257,3 +257,12 @@ def format(self, *args: object, **kwargs: object) -> "ErrorMessage": CLASS_PATTERN_UNKNOWN_KEYWORD: Final = 'Class "{}" has no attribute "{}"' MULTIPLE_ASSIGNMENTS_IN_PATTERN: Final = 'Multiple assignments to name "{}" in pattern' CANNOT_MODIFY_MATCH_ARGS: Final = 'Cannot assign to "__match_args__"' + +# Dataclass plugin +DATACLASS_VERSION_DEPENDENT_KEYWORD: Final = ( + 'Keyword argument "{}" for "dataclass" is only valid in Python {} and higher' +) +DATACLASS_TWO_KINDS_OF_SLOTS: Final = ( + '"{}" both defines "__slots__" and is used with "slots=True"' +) +DATACLASS_HASH_OVERRIDE: Final = 'Cannot overwrite attribute "__hash__" in class "{}"' diff --git a/mypy/plugins/common.py b/mypy/plugins/common.py index 985a3f0fa6c7..974e46f50d1d 100644 --- a/mypy/plugins/common.py +++ b/mypy/plugins/common.py @@ -15,6 +15,16 @@ from mypy.fixup import TypeFixer +def _has_decorator_argument(ctx: ClassDefContext, name: str) -> bool: + """Returns whether or not some argument was passed to a decorator. + + We mostly need this because some arguments are version specific. + """ + if isinstance(ctx.reason, CallExpr): + return bool(name) and name in ctx.reason.arg_names + return False + + def _get_decorator_bool_argument( ctx: ClassDefContext, name: str, @@ -157,7 +167,7 @@ def add_method_to_class( def add_attribute_to_class( - api: SemanticAnalyzerPluginInterface, + api: Union[SemanticAnalyzerPluginInterface, CheckerPluginInterface], cls: ClassDef, name: str, typ: Type, diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index 87b42a499a1c..3ab34bd58d92 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -3,6 +3,7 @@ from typing import Dict, List, Set, Tuple, Optional from typing_extensions import Final +from mypy import message_registry from mypy.nodes import ( ARG_OPT, ARG_NAMED, ARG_NAMED_OPT, ARG_POS, ARG_STAR, ARG_STAR2, MDEF, Argument, AssignmentStmt, CallExpr, TypeAlias, Context, Expression, JsonDict, @@ -11,7 +12,8 @@ ) from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface from mypy.plugins.common import ( - add_method, _get_decorator_bool_argument, deserialize_and_fixup_type, add_attribute_to_class, + add_method, _get_decorator_bool_argument, deserialize_and_fixup_type, + _has_decorator_argument, add_attribute_to_class, ) from mypy.typeops import map_type_from_supertype from mypy.types import ( @@ -31,7 +33,6 @@ 'dataclasses.field', } - SELF_TVAR_NAME: Final = "_DT" @@ -138,10 +139,14 @@ def transform(self) -> bool: 'eq': _get_decorator_bool_argument(self._ctx, 'eq', True), 'order': _get_decorator_bool_argument(self._ctx, 'order', False), 'frozen': _get_decorator_bool_argument(self._ctx, 'frozen', False), - 'slots': _get_decorator_bool_argument(self._ctx, 'slots', False), 'match_args': _get_decorator_bool_argument(self._ctx, 'match_args', True), + 'unsafe_hash': _get_decorator_bool_argument(self._ctx, 'unsafe_hash', False), } py_version = self._ctx.api.options.python_version + if py_version >= (3, 10): + decorator_arguments.update({ + 'slots': _get_decorator_bool_argument(self._ctx, 'slots', False), + }) # If there are no attributes, it may be that the semantic analyzer has not # processed them yet. In order to work around this, we can simply skip generating @@ -222,8 +227,8 @@ def transform(self) -> bool: else: self._propertize_callables(attributes) - if decorator_arguments['slots']: - self.add_slots(info, attributes, correct_version=py_version >= (3, 10)) + self.add_slots(info, decorator_arguments, attributes, current_version=py_version) + self.add_hash(info, decorator_arguments) self.reset_init_only_vars(info, attributes) @@ -249,18 +254,20 @@ def transform(self) -> bool: def add_slots(self, info: TypeInfo, + decorator_arguments: Dict[str, bool], attributes: List[DataclassAttribute], *, - correct_version: bool) -> None: - if not correct_version: + current_version: Tuple[int, ...]) -> None: + if _has_decorator_argument(self._ctx, 'slots') and current_version < (3, 10): # This means that version is lower than `3.10`, # it is just a non-existent argument for `dataclass` function. self._ctx.api.fail( - 'Keyword argument "slots" for "dataclass" ' - 'is only valid in Python 3.10 and higher', + message_registry.DATACLASS_VERSION_DEPENDENT_KEYWORD.format('slots', '3.10'), self._ctx.reason, ) return + if not decorator_arguments.get('slots'): + return # `slots` is not provided, skip. generated_slots = {attr.name for attr in attributes} if ((info.slots is not None and info.slots != generated_slots) @@ -270,15 +277,72 @@ def add_slots(self, # And `@dataclass(slots=True)` is used. # In runtime this raises a type error. self._ctx.api.fail( - '"{}" both defines "__slots__" and is used with "slots=True"'.format( - self._ctx.cls.name, - ), + message_registry.DATACLASS_TWO_KINDS_OF_SLOTS.format(self._ctx.cls.name), self._ctx.cls, ) return info.slots = generated_slots + def add_hash(self, + info: TypeInfo, + decorator_arguments: Dict[str, bool]) -> None: + unsafe_hash = decorator_arguments.get('unsafe_hash', False) + eq = decorator_arguments['eq'] + frozen = decorator_arguments['frozen'] + + existing_hash = info.names.get('__hash__') + existing = existing_hash and not existing_hash.plugin_generated + + # https://github.com/python/cpython/blob/24af9a40a8f85af813ea89998aa4e931fcc78cd9/Lib/dataclasses.py#L846 + if ((not unsafe_hash and not eq and not frozen) + or (not unsafe_hash and not eq and frozen)): + # "No __eq__, use the base class __hash__" + # It will use the base's class `__hash__` method by default. + # Nothing to do here. + pass + elif not unsafe_hash and eq and not frozen: + # "the default, not hashable" + # In this case, we just add `__hash__: None` to the body of the class + if not existing: + add_attribute_to_class( + self._ctx.api, + self._ctx.cls, + name='__hash__', + typ=NoneType(), + ) + elif not unsafe_hash and eq and frozen: + # "Frozen, so hashable, allows override" + # In this case we never raise an error, even if superclass definition + # is incompatible. + if not existing: + add_method( + self._ctx, + name='__hash__', + args=[], + return_type=self._ctx.api.named_type('builtins.int'), + ) + else: + # "Has no __eq__, but hashable" or + # "Not frozen, but hashable" or + # "Frozen, so hashable" + if existing: + # When class already has `__hash__` defined, we do not allow + # to override it. So, raise an error and do nothing. + self._ctx.api.fail( + message_registry.DATACLASS_HASH_OVERRIDE.format(self._ctx.cls.name), + self._ctx.cls, + ) + return + + # This means that class does not have `__hash__`, but we can add it. + add_method( + self._ctx, + name='__hash__', + args=[], + return_type=self._ctx.api.named_type('builtins.int'), + ) + def reset_init_only_vars(self, info: TypeInfo, attributes: List[DataclassAttribute]) -> None: """Remove init-only vars from the class and reset init var declarations.""" for attr in attributes: @@ -326,6 +390,12 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]: lhs = stmt.lvalues[0] if not isinstance(lhs, NameExpr): continue + if lhs.name == '__hash__': + # Special case, annotation like `__hash__: None` is fine + # It is not a field, it is: + # + # https://github.com/python/mypy/issues/11495 + continue sym = cls.info.names.get(lhs.name) if sym is None: diff --git a/test-data/unit/check-dataclasses.test b/test-data/unit/check-dataclasses.test index 40c6b66d5c39..e24e715254b8 100644 --- a/test-data/unit/check-dataclasses.test +++ b/test-data/unit/check-dataclasses.test @@ -1483,10 +1483,15 @@ class DynamicDef: # E: "DynamicDef" both defines "__slots__" and is used with " x: int [builtins fixtures/dataclasses.pyi] + [case testDataclassWithSlotsArgBefore310] # flags: --python-version 3.9 from dataclasses import dataclass +@dataclass() # ok +class Correct: + x: int + @dataclass(slots=True) # E: Keyword argument "slots" for "dataclass" is only valid in Python 3.10 and higher class Some: x: int @@ -1496,6 +1501,139 @@ class Some: class Other: __slots__ = ('x',) x: int + +@dataclass(slots=False) # E: Keyword argument "slots" for "dataclass" is only valid in Python 3.10 and higher +class Third: + x: int +[builtins fixtures/dataclasses.pyi] + + +[case testDataclassWithUnsafeHashFalse] +# flags: --python-version 3.7 +from dataclasses import dataclass + +@dataclass(unsafe_hash=False, eq=False, frozen=False) +class FirstCase1: + x: int +reveal_type(FirstCase1(1).__hash__) # N: Revealed type is "def () -> builtins.int" + +@dataclass(unsafe_hash=False, eq=False, frozen=False) +class FirstCase2: + x: int + def __hash__(self) -> int: pass +reveal_type(FirstCase2(1).__hash__) # N: Revealed type is "def () -> builtins.int" + +@dataclass(unsafe_hash=False, eq=False, frozen=False) +class FirstCase3: + x: int + __hash__: None # E: Incompatible types in assignment (expression has type "None", base class "object" defined the type as "Callable[[object], int]") +reveal_type(FirstCase3(1).__hash__) # N: Revealed type is "None" + +@dataclass(unsafe_hash=False, eq=False, frozen=True) +class FirstCase4: + x: int +reveal_type(FirstCase4(1).__hash__) # N: Revealed type is "def () -> builtins.int" + +@dataclass(unsafe_hash=False, eq=False, frozen=True) +class FirstCase5: + x: int + def __hash__(self) -> int: pass +reveal_type(FirstCase5(1).__hash__) # N: Revealed type is "def () -> builtins.int" + +@dataclass(unsafe_hash=False, eq=False, frozen=True) +class FirstCase6: + x: int + __hash__: None # E: Incompatible types in assignment (expression has type "None", base class "object" defined the type as "Callable[[object], int]") +reveal_type(FirstCase6(1).__hash__) # N: Revealed type is "None" + + +@dataclass(unsafe_hash=False, eq=True, frozen=False) +class SecondCase1: + x: int +reveal_type(SecondCase1(1).__hash__) # N: Revealed type is "None" + +@dataclass(unsafe_hash=False, eq=True, frozen=False) +class SecondCase2: + x: int + __hash__: None # E: Incompatible types in assignment (expression has type "None", base class "object" defined the type as "Callable[[object], int]") +reveal_type(SecondCase2(1).__hash__) # N: Revealed type is "None" + +@dataclass(unsafe_hash=False, eq=True, frozen=False) +class SecondCase3: + x: int + def __hash__(self) -> int: pass # Custom impl +reveal_type(SecondCase3(1).__hash__) # N: Revealed type is "def () -> builtins.int" + + +@dataclass(unsafe_hash=False, eq=True, frozen=True) +class ThirdCase1: + x: int +reveal_type(ThirdCase1(1).__hash__) # N: Revealed type is "def () -> builtins.int" + +@dataclass(unsafe_hash=False, eq=True, frozen=True) +class ThirdCase2: + x: int + def __hash__(self) -> int: pass # Custom impl +reveal_type(ThirdCase2(1).__hash__) # N: Revealed type is "def () -> builtins.int" + +@dataclass(unsafe_hash=False, eq=True, frozen=True) +class ThirdCase3: + x: int + __hash__: None # E: Incompatible types in assignment (expression has type "None", base class "object" defined the type as "Callable[[object], int]") +reveal_type(ThirdCase3(1).__hash__) # N: Revealed type is "None" +[builtins fixtures/dataclasses.pyi] + + +[case testDataclassWithUnsafeHashTrue] +# flags: --python-version 3.7 +from dataclasses import dataclass + +@dataclass(unsafe_hash=True, eq=False, frozen=False) +class FirstCase1: + x: int +reveal_type(FirstCase1(1).__hash__) # N: Revealed type is "def () -> builtins.int" + +@dataclass(unsafe_hash=True, eq=False, frozen=False) +class FirstCase2: # E: Cannot overwrite attribute "__hash__" in class "FirstCase2" + x: int + def __hash__(self) -> int: pass +reveal_type(FirstCase1(1).__hash__) # N: Revealed type is "def () -> builtins.int" + + +@dataclass(unsafe_hash=True, eq=False, frozen=True) +class SecondCase1: + x: int +reveal_type(SecondCase1(1).__hash__) # N: Revealed type is "def () -> builtins.int" + +@dataclass(unsafe_hash=True, eq=False, frozen=True) +class SecondCase2: # E: Cannot overwrite attribute "__hash__" in class "SecondCase2" + x: int + def __hash__(self) -> int: pass +reveal_type(SecondCase1(1).__hash__) # N: Revealed type is "def () -> builtins.int" + + +@dataclass(unsafe_hash=True, eq=True, frozen=False) +class ThirdCase1: + x: int +reveal_type(ThirdCase1(1).__hash__) # N: Revealed type is "def () -> builtins.int" + +@dataclass(unsafe_hash=True, eq=True, frozen=False) +class ThirdCase2: # E: Cannot overwrite attribute "__hash__" in class "ThirdCase2" + x: int + def __hash__(self) -> int: pass +reveal_type(ThirdCase1(1).__hash__) # N: Revealed type is "def () -> builtins.int" + + +@dataclass(unsafe_hash=True, eq=True, frozen=True) +class FourthCase1: + x: int +reveal_type(FourthCase1(1).__hash__) # N: Revealed type is "def () -> builtins.int" + +@dataclass(unsafe_hash=True, eq=True, frozen=True) +class FourthCase2: # E: Cannot overwrite attribute "__hash__" in class "FourthCase2" + x: int + def __hash__(self) -> int: pass +reveal_type(FourthCase1(1).__hash__) # N: Revealed type is "def () -> builtins.int" [builtins fixtures/dataclasses.pyi] @@ -1578,6 +1716,34 @@ A(1) A(a="foo") # E: Argument "a" to "A" has incompatible type "str"; expected "int" [builtins fixtures/dataclasses.pyi] +[case testDataclassHashDefault] +# flags: --python-version 3.7 +from dataclasses import dataclass + +@dataclass +class FirstCase: + x: int +reveal_type(FirstCase(1).__hash__) # N: Revealed type is "None" + +@dataclass +class SecondCase: + x: int + def __hash__(self) -> int: pass +reveal_type(SecondCase(1).__hash__) # N: Revealed type is "def () -> builtins.int" +[builtins fixtures/dataclasses.pyi] + +[case testSemanalTwoPassesWhileHashGeneration] +# flags: --python-version 3.7 +from dataclasses import dataclass + +a = f() # Forward ref to force two semantic analysis passes + +@dataclass(unsafe_hash=True) # Used to be an error here: +class C: # Cannot overwrite attribute "__hash__" in class "C" + x: str + +def f(): pass + [case testDataclassesCallableFrozen] # flags: --python-version 3.7 from dataclasses import dataclass diff --git a/test-data/unit/deps.test b/test-data/unit/deps.test index 884b10f166b0..644cb145aa46 100644 --- a/test-data/unit/deps.test +++ b/test-data/unit/deps.test @@ -1438,6 +1438,7 @@ class B(A): [out] -> , m -> + -> -> , m.B.__init__ -> -> diff --git a/test-data/unit/fixtures/dataclasses.pyi b/test-data/unit/fixtures/dataclasses.pyi index 206843a88b24..57186f454baf 100644 --- a/test-data/unit/fixtures/dataclasses.pyi +++ b/test-data/unit/fixtures/dataclasses.pyi @@ -12,6 +12,7 @@ class object: def __init__(self) -> None: pass def __eq__(self, o: object) -> bool: pass def __ne__(self, o: object) -> bool: pass + def __hash__(self) -> int: pass class type: pass class ellipsis: pass