From 396e44ebd81e7e5d5a4faef297e132fec109bbb3 Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Fri, 10 Mar 2023 22:31:41 +0000 Subject: [PATCH 01/11] [dataclass_transform] support implicit default for "init" parameter in field specifiers --- mypy/plugins/dataclasses.py | 32 +++++++++++++++++-- test-data/unit/check-dataclass-transform.test | 32 +++++++++++++++++++ 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index a68410765367..e2c7bc634462 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -39,6 +39,7 @@ ) from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface from mypy.plugins.common import ( + _get_callee_type, _get_decorator_bool_argument, add_attribute_to_class, add_method_to_class, @@ -47,7 +48,7 @@ from mypy.semanal_shared import find_dataclass_transform_spec, require_bool_literal_argument from mypy.server.trigger import make_wildcard_trigger from mypy.state import state -from mypy.typeops import map_type_from_supertype +from mypy.typeops import map_type_from_supertype, try_getting_literals_from_type from mypy.types import ( AnyType, CallableType, @@ -509,7 +510,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None: is_in_init_param = field_args.get("init") if is_in_init_param is None: - is_in_init = True + is_in_init = self._get_default_init_value_for_field_specifier(stmt.rvalue) else: is_in_init = bool(self._api.parse_bool(is_in_init_param)) @@ -738,6 +739,33 @@ def _get_bool_arg(self, name: str, default: bool) -> bool: return require_bool_literal_argument(self._api, expression, name, default) return default + def _get_default_init_value_for_field_specifier(self, call: Expression) -> bool: + """ + Find a default value for the `init` parameter of the specifier being called. If the + specifier's type signature includes an `init` parameter with a type of `Literal[True]` or + `Literal[False]`, return the appropriate boolean value from the literal. Otherwise, + fall back to the standard default of `True`. + """ + if not isinstance(call, CallExpr): + return True + + specifier_type = _get_callee_type(call) + print("type", specifier_type, type(specifier_type)) + if specifier_type is None: + return True + + parameter = specifier_type.argument_by_name("init") + print("parameter", parameter) + if parameter is None: + return True + + literals = try_getting_literals_from_type(parameter.typ, bool, "builtins.bool") + print("literals", literals) + if literals is None or len(literals) != 1: + return True + + return literals[0] + def add_dataclass_tag(info: TypeInfo) -> None: # The value is ignored, only the existence matters. diff --git a/test-data/unit/check-dataclass-transform.test b/test-data/unit/check-dataclass-transform.test index b0c1cdf56097..e8e7802d3072 100644 --- a/test-data/unit/check-dataclass-transform.test +++ b/test-data/unit/check-dataclass-transform.test @@ -328,6 +328,38 @@ Foo(a=1, b='bye') [typing fixtures/typing-full.pyi] [builtins fixtures/dataclasses.pyi] +[case testDataclassTransformFieldSpecifierImplicitInit] +# flags: --python-version 3.11 +from typing import dataclass_transform, Literal, overload + +def init(*, init: Literal[True] = True): ... +def no_init(*, init: Literal[False] = False): ... + +@overload +def field_overload(*, custom: None, init: Literal[True] = True): ... +@overload +def field_overload(*, custom: str, init: Literal[False] = False): ... +def field_overload(*, custom, init): ... + +@dataclass_transform(field_specifiers=(init, no_init, field_overload)) +def my_dataclass(cls): return cls + +@my_dataclass +class Foo: + a: int = init() + b: int = field_overload(custom=None) + + bad1: int = no_init() + bad2: int = field_overload(custom="bad2") + +reveal_type(Foo) # N: Revealed type is "def (a: builtins.int, b: builtins.int) -> __main__.Foo" +Foo(a=1, b=2) +Foo(a=1, b=2, bad1=0) # E: Unexpected keyword argument "bad1" for "Foo" +Foo(a=1, b=2, bad2=0) # E: Unexpected keyword argument "bad2" for "Foo" + +[typing fixtures/typing-full.pyi] +[builtins fixtures/dataclasses.pyi] + [case testDataclassTransformOverloadsDecoratorOnOverload] # flags: --python-version 3.11 from typing import dataclass_transform, overload, Any, Callable, Type, Literal From 1ce6db106eeff1b0e145e35dd453c62d1c059c59 Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Fri, 10 Mar 2023 23:46:30 +0000 Subject: [PATCH 02/11] remove debugging print statements --- mypy/plugins/dataclasses.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index e2c7bc634462..d1b074ca9e8e 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -750,17 +750,14 @@ def _get_default_init_value_for_field_specifier(self, call: Expression) -> bool: return True specifier_type = _get_callee_type(call) - print("type", specifier_type, type(specifier_type)) if specifier_type is None: return True parameter = specifier_type.argument_by_name("init") - print("parameter", parameter) if parameter is None: return True literals = try_getting_literals_from_type(parameter.typ, bool, "builtins.bool") - print("literals", literals) if literals is None or len(literals) != 1: return True From 54386426b191c0386882a4b6cd8cd5fdf2b546fa Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Tue, 4 Apr 2023 15:43:58 +0100 Subject: [PATCH 03/11] Implement simple overload resolution --- mypy/plugins/common.py | 52 ++++++++++++++++++++++++++++++++++++++++-- mypy/semanal.py | 10 ++++---- mypy/semanal_shared.py | 12 +++++++++- 3 files changed, 65 insertions(+), 9 deletions(-) diff --git a/mypy/plugins/common.py b/mypy/plugins/common.py index 0acf3e3a6369..2fb237c571c3 100644 --- a/mypy/plugins/common.py +++ b/mypy/plugins/common.py @@ -18,12 +18,15 @@ RefExpr, SymbolTableNode, Var, + NameExpr, ) +from mypy.argmap import map_actuals_to_formals from mypy.plugin import CheckerPluginInterface, ClassDefContext, SemanticAnalyzerPluginInterface from mypy.semanal_shared import ( ALLOW_INCOMPATIBLE_OVERRIDE, require_bool_literal_argument, set_callable_name, + parse_bool, ) from mypy.typeops import ( # noqa: F401 # Part of public API try_getting_str_literals as try_getting_str_literals, @@ -34,8 +37,12 @@ Type, TypeType, TypeVarType, + NoneType, deserialize_type, get_proper_type, + AnyType, + TypeOfAny, + LiteralType, ) from mypy.typevars import fill_typevars from mypy.util import get_unique_redefinition_name @@ -87,6 +94,48 @@ def _get_argument(call: CallExpr, name: str) -> Expression | None: return None +def find_shallow_matching_overload_item(overload: Overloaded, + call: CallExpr) -> CallableType | None: + """Perform limited lookup of a matching overload item. + + Full overload resolution is only supported during type checking, but plugins + sometimes need to resolve overloads. This can be used in some such use cases. + + Resolve overloads based on these things only: + + * Match using argument kinds and names + * If formal argument has type None, only accept the "None" expression in the callee + * If formal argument has type Literal[True] or Literal[False], only accept the + relevant bool literal + + Return the first matching overload item. + """ + for item in overload.items[:-1]: + ok = True + mapped = map_actuals_to_formals( + call.arg_kinds, call.arg_names, item.arg_kinds, item.arg_names, + lambda i: AnyType(TypeOfAny.special_form)) + for arg_type, kind, actuals in zip(item.arg_types, item.arg_kinds, mapped): + if kind.is_required() and not actuals: + # Missing required argument + ok = False + break + elif actuals: + args = [call.args[i] for i in actuals] + arg_type = get_proper_type(arg_type) + if isinstance(arg_type, NoneType): + if not any(isinstance(arg, NameExpr) and arg.name == "None" for arg in args): + ok = False + break + elif isinstance(arg_type, LiteralType) and type(arg_type.value) is bool: + if not any(parse_bool(arg) == arg_type.value for arg in args): + ok = False + break + if ok: + return item + return overload.items[-1] + + def _get_callee_type(call: CallExpr) -> CallableType | None: """Return the type of the callee, regardless of its syntatic form.""" @@ -103,8 +152,7 @@ def _get_callee_type(call: CallExpr) -> CallableType | None: if isinstance(callee_node, (Var, SYMBOL_FUNCBASE_TYPES)) and callee_node.type: callee_node_type = get_proper_type(callee_node.type) if isinstance(callee_node_type, Overloaded): - # We take the last overload. - return callee_node_type.items[-1] + return find_shallow_matching_overload_item(callee_node_type, call) elif isinstance(callee_node_type, CallableType): return callee_node_type diff --git a/mypy/semanal.py b/mypy/semanal.py index 4ee18d5ff4d3..beb097c9131a 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -218,6 +218,7 @@ has_placeholder, require_bool_literal_argument, set_callable_name as set_callable_name, + parse_bool, ) from mypy.semanal_typeddict import TypedDictAnalyzer from mypy.tvar_scope import TypeVarLikeScope @@ -6462,12 +6463,9 @@ def is_initial_mangled_global(self, name: str) -> bool: return name == unmangle(name) + "'" def parse_bool(self, expr: Expression) -> bool | None: - if isinstance(expr, NameExpr): - if expr.fullname == "builtins.True": - return True - if expr.fullname == "builtins.False": - return False - return None + # This wrapper is preserved for plugin backward compatibility. New code + # should not use this and call the wrapped parse_bool() function directly. + return parse_bool(expr) def parse_str_literal(self, expr: Expression) -> str | None: """Attempt to find the string literal value of the given expression. Returns `None` if no diff --git a/mypy/semanal_shared.py b/mypy/semanal_shared.py index 03efbe6ca1b8..2e69108c92de 100644 --- a/mypy/semanal_shared.py +++ b/mypy/semanal_shared.py @@ -25,6 +25,7 @@ SymbolTable, SymbolTableNode, TypeInfo, + NameExpr, ) from mypy.plugin import SemanticAnalyzerPluginInterface from mypy.tvar_scope import TypeVarLikeScope @@ -451,7 +452,7 @@ def require_bool_literal_argument( default: bool | None = None, ) -> bool | None: """Attempt to interpret an expression as a boolean literal, and fail analysis if we can't.""" - value = api.parse_bool(expression) + value = parse_bool(expression) if value is None: api.fail( f'"{name}" argument must be a True or False literal', expression, code=LITERAL_REQ @@ -459,3 +460,12 @@ def require_bool_literal_argument( return default return value + + +def parse_bool(expr: Expression) -> bool | None: + if isinstance(expr, NameExpr): + if expr.fullname == "builtins.True": + return True + if expr.fullname == "builtins.False": + return False + return None From 1e718c2610a7bb574350e9fdb1e30b40f26b49c1 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Wed, 5 Apr 2023 10:50:13 +0100 Subject: [PATCH 04/11] Minor tweaks --- mypy/plugins/common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mypy/plugins/common.py b/mypy/plugins/common.py index 2fb237c571c3..021304c97b8c 100644 --- a/mypy/plugins/common.py +++ b/mypy/plugins/common.py @@ -95,7 +95,7 @@ def _get_argument(call: CallExpr, name: str) -> Expression | None: def find_shallow_matching_overload_item(overload: Overloaded, - call: CallExpr) -> CallableType | None: + call: CallExpr) -> CallableType: """Perform limited lookup of a matching overload item. Full overload resolution is only supported during type checking, but plugins @@ -108,7 +108,7 @@ def find_shallow_matching_overload_item(overload: Overloaded, * If formal argument has type Literal[True] or Literal[False], only accept the relevant bool literal - Return the first matching overload item. + Return the first matching overload item, or the last one if nothing matches. """ for item in overload.items[:-1]: ok = True From f4ff5bad3e7a1206ec20151f87a32a5e105aa471 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Wed, 5 Apr 2023 10:50:23 +0100 Subject: [PATCH 05/11] WIP unit tests --- mypy/test/testtypes.py | 64 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 63 insertions(+), 1 deletion(-) diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index ee0256e2057a..a06952184b7e 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -7,7 +7,19 @@ from mypy.indirection import TypeIndirectionVisitor from mypy.join import join_simple, join_types from mypy.meet import meet_types, narrow_declared_type -from mypy.nodes import ARG_OPT, ARG_POS, ARG_STAR, ARG_STAR2, CONTRAVARIANT, COVARIANT, INVARIANT +from mypy.nodes import ( + ARG_NAMED, + ARG_OPT, + ARG_POS, + ARG_STAR, + ARG_STAR2, + CONTRAVARIANT, + COVARIANT, + INVARIANT, + CallExpr, + NameExpr, +) +from mypy.plugins.common import find_shallow_matching_overload_item from mypy.state import state from mypy.subtypes import is_more_precise, is_proper_subtype, is_same_type, is_subtype from mypy.test.helpers import Suite, assert_equal, assert_type, skip @@ -1287,3 +1299,53 @@ def assert_union_result(self, t: ProperType, expected: list[Type]) -> None: t2 = remove_instance_last_known_values(t) assert type(t2) is UnionType assert t2.items == expected + + +class ShallowOverloadMatchingSuite(Suite): + def setUp(self) -> None: + self.fx = TypeFixture() + + def test_simple(self) -> None: + fx = self.fx + ov = self.make_overload([[("x", fx.anyt, ARG_NAMED)], [("y", fx.anyt, ARG_NAMED)]]) + self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "x")), 0) + self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "y")), 1) + self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "z")), 1) + + def assert_find_shallow_matching_overload_item( + self, ov: Overloaded, call: CallExpr, expected_index: int + ) -> None: + c = find_shallow_matching_overload_item(ov, call) + assert c in ov.items + assert ov.items.index(c) == expected_index + + def make_overload(self, items: list[list[tuple[str, Type, ArgKind]]]) -> Overloaded: + result = [] + for item in items: + arg_types = [] + arg_names = [] + arg_kinds = [] + for name, typ, kind in item: + arg_names.append(name) + arg_types.append(typ) + arg_kinds.append(kind) + result.append( + CallableType( + arg_types, arg_kinds, arg_names, ret_type=NoneType(), fallback=self.fx.o + ) + ) + return Overloaded(result) + + +def make_call(*items: tuple[str, str | None]) -> CallExpr: + args = [] + arg_names = [] + arg_kinds = [] + for arg, name in items: + args.append(NameExpr(arg)) + arg_names.append(name) + if name: + arg_kinds.append(ARG_NAMED) + else: + arg_kinds.append(ARG_POS) + return CallExpr(NameExpr("f"), args, arg_kinds, arg_names) From d5a99ea19fce9a3d16f8ed17973abd26a4ae4c7f Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Wed, 5 Apr 2023 13:48:19 +0100 Subject: [PATCH 06/11] Improvements --- mypy/plugins/common.py | 8 ++++++++ mypy/test/testtypes.py | 29 ++++++++++++++++++++++++++++- mypy/test/typefixture.py | 4 ++++ 3 files changed, 40 insertions(+), 1 deletion(-) diff --git a/mypy/plugins/common.py b/mypy/plugins/common.py index 021304c97b8c..8ebd2ab3e9e0 100644 --- a/mypy/plugins/common.py +++ b/mypy/plugins/common.py @@ -115,6 +115,14 @@ def find_shallow_matching_overload_item(overload: Overloaded, mapped = map_actuals_to_formals( call.arg_kinds, call.arg_names, item.arg_kinds, item.arg_names, lambda i: AnyType(TypeOfAny.special_form)) + + # Look for extra actuals + matched_actuals = set() + for actuals in mapped: + matched_actuals.update(actuals) + if any(i not in matched_actuals for i in range(len(call.args))): + ok = False + for arg_type, kind, actuals in zip(item.arg_types, item.arg_kinds, mapped): if kind.is_required() and not actuals: # Missing required argument diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index a06952184b7e..d4802574a003 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -1308,9 +1308,33 @@ def setUp(self) -> None: def test_simple(self) -> None: fx = self.fx ov = self.make_overload([[("x", fx.anyt, ARG_NAMED)], [("y", fx.anyt, ARG_NAMED)]]) + # Match first only self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "x")), 0) + # Match second only self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "y")), 1) + # No match -- invalid keyword arg name self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "z")), 1) + # No match -- missing arg + self.assert_find_shallow_matching_overload_item(ov, make_call(), 1) + # No match -- extra arg + self.assert_find_shallow_matching_overload_item( + ov, make_call(("foo", "x"), ("foo", "z")), 1 + ) + + def test_match_using_types(self) -> None: + fx = self.fx + ov = self.make_overload( + [ + [("x", fx.nonet, ARG_POS)], + [("x", fx.lit_false, ARG_POS)], + [("x", fx.lit_true, ARG_POS)], + [("x", fx.anyt, ARG_POS)], + ] + ) + self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 0) + self.assert_find_shallow_matching_overload_item(ov, make_call(("builtins.False", None)), 1) + self.assert_find_shallow_matching_overload_item(ov, make_call(("builtins.True", None)), 2) + self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", None)), 3) def assert_find_shallow_matching_overload_item( self, ov: Overloaded, call: CallExpr, expected_index: int @@ -1342,7 +1366,10 @@ def make_call(*items: tuple[str, str | None]) -> CallExpr: arg_names = [] arg_kinds = [] for arg, name in items: - args.append(NameExpr(arg)) + shortname = arg.split(".")[-1] + n = NameExpr(shortname) + n.fullname = arg + args.append(n) arg_names.append(name) if name: arg_kinds.append(ARG_NAMED) diff --git a/mypy/test/typefixture.py b/mypy/test/typefixture.py index d12e7abab0e2..1013b87c213f 100644 --- a/mypy/test/typefixture.py +++ b/mypy/test/typefixture.py @@ -136,6 +136,7 @@ def make_type_var( self.type_type = Instance(self.type_typei, []) # type self.function = Instance(self.functioni, []) # function TODO self.str_type = Instance(self.str_type_info, []) + self.bool_type = Instance(self.bool_type_info, []) self.a = Instance(self.ai, []) # A self.b = Instance(self.bi, []) # B self.c = Instance(self.ci, []) # C @@ -197,6 +198,9 @@ def make_type_var( self.lit_str2_inst = Instance(self.str_type_info, [], last_known_value=self.lit_str2) self.lit_str3_inst = Instance(self.str_type_info, [], last_known_value=self.lit_str3) + self.lit_false = LiteralType(False, self.bool_type) + self.lit_true = LiteralType(True, self.bool_type) + self.type_a = TypeType.make_normalized(self.a) self.type_b = TypeType.make_normalized(self.b) self.type_c = TypeType.make_normalized(self.c) From 07c04355b5b20ed0fd4565b877bafda51f34c187 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Wed, 5 Apr 2023 13:55:59 +0100 Subject: [PATCH 07/11] Add unit tests --- mypy/test/testtypes.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index d4802574a003..e708a5c5a896 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -1336,6 +1336,34 @@ def test_match_using_types(self) -> None: self.assert_find_shallow_matching_overload_item(ov, make_call(("builtins.True", None)), 2) self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", None)), 3) + def test_optional_arg(self) -> None: + fx = self.fx + ov = self.make_overload( + [[("x", fx.anyt, ARG_NAMED)], [("y", fx.anyt, ARG_OPT)], [("z", fx.anyt, ARG_NAMED)]] + ) + self.assert_find_shallow_matching_overload_item(ov, make_call(), 1) + self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "x")), 0) + self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "y")), 1) + self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "z")), 2) + + def test_two_args(self) -> None: + fx = self.fx + ov = self.make_overload( + [ + [("x", fx.nonet, ARG_OPT), ("y", fx.anyt, ARG_OPT)], + [("x", fx.anyt, ARG_OPT), ("y", fx.anyt, ARG_OPT)], + ] + ) + self.assert_find_shallow_matching_overload_item(ov, make_call(), 0) + self.assert_find_shallow_matching_overload_item(ov, make_call(("None", "x")), 0) + self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", "x")), 1) + self.assert_find_shallow_matching_overload_item( + ov, make_call(("foo", "y"), ("None", "x")), 0 + ) + self.assert_find_shallow_matching_overload_item( + ov, make_call(("foo", "y"), ("bar", "x")), 1 + ) + def assert_find_shallow_matching_overload_item( self, ov: Overloaded, call: CallExpr, expected_index: int ) -> None: From 1919bcf18aac0acc079f92e0d838ef57f07710dc Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Wed, 5 Apr 2023 13:58:03 +0100 Subject: [PATCH 08/11] Fix type check --- mypy/test/testtypes.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index e708a5c5a896..61a6b813ee2c 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -16,7 +16,9 @@ CONTRAVARIANT, COVARIANT, INVARIANT, + ArgKind, CallExpr, + Expression, NameExpr, ) from mypy.plugins.common import find_shallow_matching_overload_item @@ -1390,7 +1392,7 @@ def make_overload(self, items: list[list[tuple[str, Type, ArgKind]]]) -> Overloa def make_call(*items: tuple[str, str | None]) -> CallExpr: - args = [] + args: list[Expression] = [] arg_names = [] arg_kinds = [] for arg, name in items: From 61259c7ca23d922f94c21dfa21607e67d749d001 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Wed, 5 Apr 2023 14:22:08 +0100 Subject: [PATCH 09/11] Fix style --- mypy/plugins/common.py | 25 ++++++++++++++----------- mypy/semanal.py | 5 ++--- mypy/semanal_shared.py | 2 +- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/mypy/plugins/common.py b/mypy/plugins/common.py index 8ebd2ab3e9e0..c08c2403e2d2 100644 --- a/mypy/plugins/common.py +++ b/mypy/plugins/common.py @@ -1,5 +1,6 @@ from __future__ import annotations +from mypy.argmap import map_actuals_to_formals from mypy.fixup import TypeFixer from mypy.nodes import ( ARG_POS, @@ -13,36 +14,35 @@ Expression, FuncDef, JsonDict, + NameExpr, Node, PassStmt, RefExpr, SymbolTableNode, Var, - NameExpr, ) -from mypy.argmap import map_actuals_to_formals from mypy.plugin import CheckerPluginInterface, ClassDefContext, SemanticAnalyzerPluginInterface from mypy.semanal_shared import ( ALLOW_INCOMPATIBLE_OVERRIDE, + parse_bool, require_bool_literal_argument, set_callable_name, - parse_bool, ) from mypy.typeops import ( # noqa: F401 # Part of public API try_getting_str_literals as try_getting_str_literals, ) from mypy.types import ( + AnyType, CallableType, + LiteralType, + NoneType, Overloaded, Type, + TypeOfAny, TypeType, TypeVarType, - NoneType, deserialize_type, get_proper_type, - AnyType, - TypeOfAny, - LiteralType, ) from mypy.typevars import fill_typevars from mypy.util import get_unique_redefinition_name @@ -94,8 +94,7 @@ def _get_argument(call: CallExpr, name: str) -> Expression | None: return None -def find_shallow_matching_overload_item(overload: Overloaded, - call: CallExpr) -> CallableType: +def find_shallow_matching_overload_item(overload: Overloaded, call: CallExpr) -> CallableType: """Perform limited lookup of a matching overload item. Full overload resolution is only supported during type checking, but plugins @@ -113,8 +112,12 @@ def find_shallow_matching_overload_item(overload: Overloaded, for item in overload.items[:-1]: ok = True mapped = map_actuals_to_formals( - call.arg_kinds, call.arg_names, item.arg_kinds, item.arg_names, - lambda i: AnyType(TypeOfAny.special_form)) + call.arg_kinds, + call.arg_names, + item.arg_kinds, + item.arg_names, + lambda i: AnyType(TypeOfAny.special_form), + ) # Look for extra actuals matched_actuals = set() diff --git a/mypy/semanal.py b/mypy/semanal.py index beb097c9131a..84573790fd68 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -216,9 +216,9 @@ calculate_tuple_fallback, find_dataclass_transform_spec, has_placeholder, + parse_bool, require_bool_literal_argument, set_callable_name as set_callable_name, - parse_bool, ) from mypy.semanal_typeddict import TypedDictAnalyzer from mypy.tvar_scope import TypeVarLikeScope @@ -6463,8 +6463,7 @@ def is_initial_mangled_global(self, name: str) -> bool: return name == unmangle(name) + "'" def parse_bool(self, expr: Expression) -> bool | None: - # This wrapper is preserved for plugin backward compatibility. New code - # should not use this and call the wrapped parse_bool() function directly. + # This wrapper is preserved for plugins. return parse_bool(expr) def parse_str_literal(self, expr: Expression) -> str | None: diff --git a/mypy/semanal_shared.py b/mypy/semanal_shared.py index 2e69108c92de..c86ed828b2b9 100644 --- a/mypy/semanal_shared.py +++ b/mypy/semanal_shared.py @@ -18,6 +18,7 @@ Decorator, Expression, FuncDef, + NameExpr, Node, OverloadedFuncDef, RefExpr, @@ -25,7 +26,6 @@ SymbolTable, SymbolTableNode, TypeInfo, - NameExpr, ) from mypy.plugin import SemanticAnalyzerPluginInterface from mypy.tvar_scope import TypeVarLikeScope From 773ba9abbcaaf6675581210b64370d05e038ca14 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Wed, 5 Apr 2023 14:34:30 +0100 Subject: [PATCH 10/11] Improve overload resolution with None --- mypy/plugins/common.py | 15 ++++++++++++++- mypy/test/testtypes.py | 24 ++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/mypy/plugins/common.py b/mypy/plugins/common.py index c08c2403e2d2..b07899f91b69 100644 --- a/mypy/plugins/common.py +++ b/mypy/plugins/common.py @@ -34,6 +34,7 @@ from mypy.types import ( AnyType, CallableType, + Instance, LiteralType, NoneType, Overloaded, @@ -43,6 +44,7 @@ TypeVarType, deserialize_type, get_proper_type, + is_optional, ) from mypy.typevars import fill_typevars from mypy.util import get_unique_redefinition_name @@ -134,10 +136,21 @@ def find_shallow_matching_overload_item(overload: Overloaded, call: CallExpr) -> elif actuals: args = [call.args[i] for i in actuals] arg_type = get_proper_type(arg_type) + arg_none = any(isinstance(arg, NameExpr) and arg.name == "None" for arg in args) if isinstance(arg_type, NoneType): - if not any(isinstance(arg, NameExpr) and arg.name == "None" for arg in args): + if not arg_none: ok = False break + elif ( + arg_none + and not is_optional(arg_type) + and not ( + isinstance(arg_type, Instance) + and arg_type.type.fullname == "builtins.object" + ) + ): + ok = False + break elif isinstance(arg_type, LiteralType) and type(arg_type.value) is bool: if not any(parse_bool(arg) == arg_type.value for arg in args): ok = False diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index 61a6b813ee2c..5b69a58ec890 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -1338,6 +1338,30 @@ def test_match_using_types(self) -> None: self.assert_find_shallow_matching_overload_item(ov, make_call(("builtins.True", None)), 2) self.assert_find_shallow_matching_overload_item(ov, make_call(("foo", None)), 3) + def test_none_special_cases(self) -> None: + fx = self.fx + ov = self.make_overload( + [[("x", fx.callable(fx.nonet), ARG_POS)], [("x", fx.nonet, ARG_POS)]] + ) + self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 1) + self.assert_find_shallow_matching_overload_item(ov, make_call(("func", None)), 0) + ov = self.make_overload([[("x", fx.str_type, ARG_POS)], [("x", fx.nonet, ARG_POS)]]) + self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 1) + self.assert_find_shallow_matching_overload_item(ov, make_call(("func", None)), 0) + ov = self.make_overload( + [[("x", UnionType([fx.str_type, fx.a]), ARG_POS)], [("x", fx.nonet, ARG_POS)]] + ) + self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 1) + self.assert_find_shallow_matching_overload_item(ov, make_call(("func", None)), 0) + ov = self.make_overload([[("x", fx.o, ARG_POS)], [("x", fx.nonet, ARG_POS)]]) + self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 0) + self.assert_find_shallow_matching_overload_item(ov, make_call(("func", None)), 0) + ov = self.make_overload( + [[("x", UnionType([fx.str_type, fx.nonet]), ARG_POS)], [("x", fx.nonet, ARG_POS)]] + ) + self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 0) + self.assert_find_shallow_matching_overload_item(ov, make_call(("func", None)), 0) + def test_optional_arg(self) -> None: fx = self.fx ov = self.make_overload( From 115e50c4d8f67ea5877ffb1558e5bbb7e8a5bc06 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Wed, 5 Apr 2023 14:38:52 +0100 Subject: [PATCH 11/11] Fix another special case --- mypy/plugins/common.py | 1 + mypy/test/testtypes.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/mypy/plugins/common.py b/mypy/plugins/common.py index b07899f91b69..67796ef15cf3 100644 --- a/mypy/plugins/common.py +++ b/mypy/plugins/common.py @@ -148,6 +148,7 @@ def find_shallow_matching_overload_item(overload: Overloaded, call: CallExpr) -> isinstance(arg_type, Instance) and arg_type.type.fullname == "builtins.object" ) + and not isinstance(arg_type, AnyType) ): ok = False break diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index 5b69a58ec890..6fe65675554b 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -1361,6 +1361,9 @@ def test_none_special_cases(self) -> None: ) self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 0) self.assert_find_shallow_matching_overload_item(ov, make_call(("func", None)), 0) + ov = self.make_overload([[("x", fx.anyt, ARG_POS)], [("x", fx.nonet, ARG_POS)]]) + self.assert_find_shallow_matching_overload_item(ov, make_call(("None", None)), 0) + self.assert_find_shallow_matching_overload_item(ov, make_call(("func", None)), 0) def test_optional_arg(self) -> None: fx = self.fx