diff --git a/mypy/checker_shared.py b/mypy/checker_shared.py index 0014d2c6fc88..d857dc6d7da3 100644 --- a/mypy/checker_shared.py +++ b/mypy/checker_shared.py @@ -13,7 +13,7 @@ from mypy.errors import ErrorWatcher from mypy.message_registry import ErrorMessage from mypy.nodes import ( - ArgKind, + ArgKinds, Context, Expression, FuncItem, @@ -69,7 +69,7 @@ def check_call( self, callee: Type, args: list[Expression], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, context: Context, arg_names: Sequence[str | None] | None = None, callable_node: Expression | None = None, @@ -85,7 +85,7 @@ def transform_callee_type( callable_name: str | None, callee: Type, args: list[Expression], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, context: Context, arg_names: Sequence[str | None] | None = None, object_type: Type | None = None, @@ -102,7 +102,7 @@ def check_method_call_by_name( method: str, base_type: Type, args: list[Expression], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, context: Context, original_type: Type | None = None, ) -> tuple[Type, Type]: diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 73282c94be4e..0a1b03f4dd76 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -43,6 +43,7 @@ REVEAL_LOCALS, REVEAL_TYPE, ArgKind, + ArgKinds, AssertTypeExpr, AssignmentExpr, AwaitExpr, @@ -772,7 +773,7 @@ def check_protocol_issubclass(self, e: CallExpr) -> None: def check_typeddict_call( self, callee: TypedDictType, - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, arg_names: Sequence[str | None], args: list[Expression], context: Context, @@ -1213,7 +1214,7 @@ def try_infer_partial_value_type_from_call( def apply_function_plugin( self, callee: CallableType, - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, arg_types: list[Type], arg_names: Sequence[str | None] | None, formal_to_actual: list[list[int]], @@ -1237,7 +1238,7 @@ def apply_function_plugin( formal_arg_types: list[list[Type]] = [[] for _ in range(num_formals)] formal_arg_exprs: list[list[Expression]] = [[] for _ in range(num_formals)] formal_arg_names: list[list[str | None]] = [[] for _ in range(num_formals)] - formal_arg_kinds: list[list[ArgKind]] = [[] for _ in range(num_formals)] + formal_arg_kinds: list[ArgKinds] = [[] for _ in range(num_formals)] for formal, actuals in enumerate(formal_to_actual): for actual in actuals: formal_arg_types[formal].append(arg_types[actual]) @@ -1287,7 +1288,7 @@ def apply_signature_hook( self, callee: FunctionLike, args: list[Expression], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, arg_names: Sequence[str | None] | None, hook: Callable[[list[list[Expression]], CallableType], FunctionLike], ) -> FunctionLike: @@ -1319,7 +1320,7 @@ def apply_function_signature_hook( self, callee: FunctionLike, args: list[Expression], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, context: Context, arg_names: Sequence[str | None] | None, signature_hook: Callable[[FunctionSigContext], FunctionLike], @@ -1337,7 +1338,7 @@ def apply_method_signature_hook( self, callee: FunctionLike, args: list[Expression], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, context: Context, arg_names: Sequence[str | None] | None, object_type: Type, @@ -1362,7 +1363,7 @@ def transform_callee_type( callable_name: str | None, callee: Type, args: list[Expression], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, context: Context, arg_names: Sequence[str | None] | None = None, object_type: Type | None = None, @@ -1529,7 +1530,7 @@ def check_call( self, callee: Type, args: list[Expression], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, context: Context, arg_names: Sequence[str | None] | None = None, callable_node: Expression | None = None, @@ -1651,7 +1652,7 @@ def check_callable_call( self, callee: CallableType, args: list[Expression], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, context: Context, arg_names: Sequence[str | None] | None, callable_node: Expression | None, @@ -1934,7 +1935,7 @@ def infer_arg_types_in_context( self, callee: CallableType, args: list[Expression], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, formal_to_actual: list[list[int]], ) -> list[Type]: """Infer argument expression types using a callable type as context. @@ -2056,7 +2057,7 @@ def infer_function_type_arguments( self, callee_type: CallableType, args: list[Expression], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, arg_names: Sequence[str | None] | None, formal_to_actual: list[list[int]], need_refresh: bool, @@ -2115,7 +2116,7 @@ def infer_function_type_arguments( if ( callee_type.special_sig == "dict" and len(inferred_args) == 2 - and (ARG_NAMED in arg_kinds or ARG_STAR2 in arg_kinds) + and (ARG_NAMED in arg_kinds or arg_kinds.has_star2) ): # HACK: Infer str key type for dict(...) with keyword args. The type system # can't represent this so we special case it, as this is a pretty common @@ -2195,7 +2196,7 @@ def infer_function_type_arguments_pass2( self, callee_type: CallableType, args: list[Expression], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, arg_names: Sequence[str | None] | None, formal_to_actual: list[list[int]], old_inferred_args: Sequence[Type | None], @@ -2330,7 +2331,7 @@ def check_argument_count( self, callee: CallableType, actual_types: list[Type], - actual_kinds: list[ArgKind], + actual_kinds: ArgKinds, actual_names: Sequence[str | None] | None, formal_to_actual: list[list[int]], context: Context | None, @@ -2410,7 +2411,7 @@ def check_for_extra_actual_arguments( self, callee: CallableType, actual_types: list[Type], - actual_kinds: list[ArgKind], + actual_kinds: ArgKinds, actual_names: Sequence[str | None] | None, all_actuals: dict[int, int], context: Context, @@ -2484,7 +2485,7 @@ def missing_classvar_callable_note( def check_argument_types( self, arg_types: list[Type], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, args: list[Expression], callee: CallableType, formal_to_actual: list[list[int]], @@ -2671,7 +2672,7 @@ def check_overload_call( self, callee: Overloaded, args: list[Expression], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, arg_names: Sequence[str | None] | None, callable_name: str | None, object_type: Type | None, @@ -2813,7 +2814,7 @@ def check_overload_call( def plausible_overload_call_targets( self, arg_types: list[Type], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, arg_names: Sequence[str | None] | None, overload: Overloaded, ) -> list[CallableType]: @@ -2874,7 +2875,7 @@ def infer_overload_return_type( plausible_targets: list[CallableType], args: list[Expression], arg_types: list[Type], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, arg_names: Sequence[str | None] | None, callable_name: str | None, object_type: Type | None, @@ -2957,7 +2958,7 @@ def overload_erased_call_targets( self, plausible_targets: list[CallableType], arg_types: list[Type], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, arg_names: Sequence[str | None] | None, args: list[Expression], context: Context, @@ -3017,7 +3018,7 @@ def union_overload_result( plausible_targets: list[CallableType], args: list[Expression], arg_types: list[Type], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, arg_names: Sequence[str | None] | None, callable_name: str | None, object_type: Type | None, @@ -3217,7 +3218,7 @@ def combine_function_signatures(self, types: list[ProperType]) -> AnyType | Call def erased_signature_similarity( self, arg_types: list[Type], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, arg_names: Sequence[str | None] | None, args: list[Expression], callee: CallableType, @@ -3298,7 +3299,7 @@ def check_union_call( self, callee: UnionType, args: list[Expression], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, arg_names: Sequence[str | None] | None, context: Context, ) -> tuple[Type, Type]: @@ -3857,7 +3858,7 @@ def check_method_call_by_name( method: str, base_type: Type, args: list[Expression], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, context: Context, original_type: Type | None = None, self_type: Type | None = None, @@ -3896,7 +3897,7 @@ def check_union_method_call_by_name( method: str, base_type: UnionType, args: list[Expression], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, context: Context, original_type: Type | None = None, ) -> tuple[Type, Type]: @@ -3925,7 +3926,7 @@ def check_method_call( base_type: Type, method_type: Type, args: list[Expression], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, context: Context, ) -> tuple[Type, Type]: """Type check a call to a method with the given name and type on an object. @@ -5531,7 +5532,7 @@ def infer_lambda_type_using_context( # See https://github.com/python/mypy/issues/9927 return None, None - arg_kinds = [arg.kind for arg in e.arguments] + arg_kinds = ArgKinds(arg.kind for arg in e.arguments) if callable_ctx.is_ellipsis_args or ctx.param_spec() is not None: # Fill in Any arguments to match the arguments of the lambda. @@ -5542,7 +5543,7 @@ def infer_lambda_type_using_context( arg_names=e.arg_names.copy(), ) - if ARG_STAR in arg_kinds or ARG_STAR2 in arg_kinds: + if arg_kinds.has_any_star: # TODO treat this case appropriately return callable_ctx, None @@ -6445,7 +6446,7 @@ def is_non_empty_tuple(t: Type) -> bool: def is_duplicate_mapping( - mapping: list[int], actual_types: list[Type], actual_kinds: list[ArgKind] + mapping: list[int], actual_types: list[Type], actual_kinds: ArgKinds ) -> bool: return ( len(mapping) > 1 @@ -6585,7 +6586,7 @@ def any_causes_overload_ambiguity( items: list[CallableType], return_types: list[Type], arg_types: list[Type], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, arg_names: Sequence[str | None] | None, ) -> bool: """May an argument containing 'Any' cause ambiguous result type on call to overloaded function? diff --git a/mypy/constraints.py b/mypy/constraints.py index 96c0c7ccaf35..56674aa244a9 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -18,7 +18,7 @@ ARG_STAR2, CONTRAVARIANT, COVARIANT, - ArgKind, + ArgKinds, TypeInfo, ) from mypy.type_visitor import ALL_STRATEGY, BoolTypeQuery @@ -110,7 +110,7 @@ def __eq__(self, other: object) -> bool: def infer_constraints_for_callable( callee: CallableType, arg_types: Sequence[Type | None], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, arg_names: Sequence[str | None] | None, formal_to_actual: list[list[int]], context: ArgumentInferContext, diff --git a/mypy/fastparse.py b/mypy/fastparse.py index 6b2eb532003c..d4ccc5d3bf9c 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -21,6 +21,7 @@ TYPE_VAR_KIND, TYPE_VAR_TUPLE_KIND, ArgKind, + ArgKinds, Argument, AssertStmt, AssignmentExpr, @@ -1634,9 +1635,10 @@ def visit_Call(self, n: Call) -> CallExpr: arg_types = self.translate_expr_list( [a.value if isinstance(a, Starred) else a for a in args] + [k.value for k in keywords] ) - arg_kinds = [ARG_STAR if type(a) is Starred else ARG_POS for a in args] + [ - ARG_STAR2 if arg is None else ARG_NAMED for arg in keyword_names - ] + arg_kinds = ArgKinds( + [ARG_STAR if type(a) is Starred else ARG_POS for a in args] + + [ARG_STAR2 if arg is None else ARG_NAMED for arg in keyword_names] + ) e = CallExpr( self.visit(n.func), arg_types, @@ -1688,7 +1690,7 @@ def visit_JoinedStr(self, n: ast3.JoinedStr) -> Expression: del strs_to_join.items[-1:] join_method = MemberExpr(empty_string, "join") join_method.set_line(empty_string) - result_expression = CallExpr(join_method, [strs_to_join], [ARG_POS], [None]) + result_expression = CallExpr(join_method, [strs_to_join], ArgKinds([ARG_POS]), [None]) return self.set_line(result_expression, n) # FormattedValue(expr value) diff --git a/mypy/infer.py b/mypy/infer.py index cdc43797d3b1..48fa9691d389 100644 --- a/mypy/infer.py +++ b/mypy/infer.py @@ -11,7 +11,7 @@ infer_constraints, infer_constraints_for_callable, ) -from mypy.nodes import ArgKind +from mypy.nodes import ArgKinds from mypy.solve import solve_constraints from mypy.types import CallableType, Instance, Type, TypeVarLikeType @@ -33,7 +33,7 @@ class ArgumentInferContext(NamedTuple): def infer_function_type_arguments( callee_type: CallableType, arg_types: Sequence[Type | None], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, arg_names: Sequence[str | None] | None, formal_to_actual: list[list[int]], context: ArgumentInferContext, diff --git a/mypy/messages.py b/mypy/messages.py index 6329cad687f6..51b9651354c4 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -41,6 +41,7 @@ COVARIANT, SYMBOL_FUNCBASE_TYPES, ArgKind, + ArgKinds, CallExpr, ClassDef, Context, @@ -2528,7 +2529,7 @@ def quote_type_string(type_string: str) -> str: def format_callable_args( arg_types: list[Type], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, arg_names: list[str | None], format: Callable[[Type], str], verbosity: int, diff --git a/mypy/nodes.py b/mypy/nodes.py index 040f3fc28dce..1922c1bf0e5c 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -6,12 +6,12 @@ import os from abc import abstractmethod from collections import defaultdict -from collections.abc import Iterator, Sequence +from collections.abc import Iterable, Iterator, Sequence from enum import Enum, unique from typing import TYPE_CHECKING, Any, Callable, Final, Optional, TypeVar, Union, cast from typing_extensions import TypeAlias as _TypeAlias, TypeGuard -from mypy_extensions import trait +from mypy_extensions import mypyc_attr, trait import mypy.strconv from mypy.cache import ( @@ -860,7 +860,7 @@ def __init__( super().__init__() self.arguments = arguments or [] self.arg_names = [None if arg.pos_only else arg.variable.name for arg in self.arguments] - self.arg_kinds: list[ArgKind] = [arg.kind for arg in self.arguments] + self.arg_kinds = ArgKinds(arg.kind for arg in self.arguments) self.max_pos: int = self.arg_kinds.count(ARG_POS) + self.arg_kinds.count(ARG_OPT) self.type_args: list[TypeParam] | None = type_args self.body: Block = body or Block([]) @@ -1015,7 +1015,7 @@ def deserialize(cls, data: JsonDict) -> FuncDef: # NOTE: ret.info is set in the fixup phase. ret.arg_names = data["arg_names"] ret.original_first_arg = data.get("original_first_arg") - ret.arg_kinds = [ARG_KINDS[x] for x in data["arg_kinds"]] + ret.arg_kinds = ArgKinds(ARG_KINDS[x] for x in data["arg_kinds"]) ret.abstract_status = data["abstract_status"] ret.dataclass_transform_spec = ( DataclassTransformSpec.deserialize(data["dataclass_transform_spec"]) @@ -1057,7 +1057,7 @@ def read(cls, data: Buffer) -> FuncDef: read_flags(data, ret, FUNCDEF_FLAGS) # NOTE: ret.info is set in the fixup phase. ret.arg_names = read_str_opt_list(data) - ret.arg_kinds = [ARG_KINDS[ak] for ak in read_int_list(data)] + ret.arg_kinds = ArgKinds(ARG_KINDS[ak] for ak in read_int_list(data)) ret.abstract_status = read_int(data) if read_bool(data): ret.dataclass_transform_spec = DataclassTransformSpec.read(data) @@ -2207,6 +2207,52 @@ def is_star(self) -> bool: return self == ARG_STAR or self == ARG_STAR2 +@mypyc_attr(native_class=False) +class ArgKinds(list[ArgKind]): + def __init__(self, values: Iterable[ArgKind] = None) -> None: + if values is None: + super().__init__() + else: + super().__init__(values) + self.__count_cache: dict[ArgKind, int] = {} + self.__index_cache: dict[ArgKind, int] = {} + self.__positional_only: bool | None = None + + @property + def positional_only(self) -> bool: + pos_only = self.__positional_only + if pos_only is None: + pos_only = self.__positional_only = all(kind == ARG_POS for kind in self) + return pos_only + + @property + def has_star(self) -> bool: + return ARG_STAR in self + + @property + def has_star2(self) -> bool: + return ARG_STAR2 in self + + @property + def has_any_star(self) -> bool: + return any(kind.is_star() for kind in self) + + def copy(self) -> ArgKinds: + return ArgKinds(kind for kind in self) + + def count(self, kind: ArgKind) -> int: + count = self.__count_cache.get(kind) + if count is None: + count = self.__count_cache[kind] = super().count(kind) + return count + + def index(self, kind: ArgKind) -> int: + index = self.__index_cache.get(kind) + if index is None: + index = self.__index_cache[kind] = super().index(kind) + return index + + ARG_POS: Final = ArgKind.ARG_POS ARG_OPT: Final = ArgKind.ARG_OPT ARG_STAR: Final = ArgKind.ARG_STAR @@ -2232,7 +2278,7 @@ def __init__( self, callee: Expression, args: list[Expression], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, arg_names: list[str | None], analyzed: Expression | None = None, ) -> None: @@ -2242,6 +2288,7 @@ def __init__( self.callee = callee self.args = args + assert isinstance(arg_kinds, ArgKinds), type(arg_kinds) self.arg_kinds = arg_kinds # ARG_ constants # Each name can be None if not a keyword argument. self.arg_names: list[str | None] = arg_names @@ -4798,9 +4845,7 @@ def get_member_expr_fullname(expr: MemberExpr) -> str | None: } -def check_arg_kinds( - arg_kinds: list[ArgKind], nodes: list[T], fail: Callable[[str, T], None] -) -> None: +def check_arg_kinds(arg_kinds: ArgKinds, nodes: list[T], fail: Callable[[str, T], None]) -> None: is_var_arg = False is_kw_arg = False seen_named = False diff --git a/mypy/plugin.py b/mypy/plugin.py index 9019e3c2256f..570cba7b5f47 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -128,7 +128,7 @@ class C: pass from mypy.lookup import lookup_fully_qualified from mypy.message_registry import ErrorMessage from mypy.nodes import ( - ArgKind, + ArgKinds, CallExpr, ClassDef, Context, @@ -185,7 +185,7 @@ def analyze_type(self, typ: Type, /) -> Type: @abstractmethod def analyze_callable_args( self, arglist: TypeList - ) -> tuple[list[Type], list[ArgKind], list[str | None]] | None: + ) -> tuple[list[Type], ArgKinds, list[str | None]] | None: """Find types, kinds, and names of arguments from extended callable syntax.""" raise NotImplementedError @@ -446,7 +446,7 @@ class FunctionSigContext(NamedTuple): # callback at least sometimes can infer a more precise type. class FunctionContext(NamedTuple): arg_types: list[list[Type]] # List of actual caller types for each formal argument - arg_kinds: list[list[ArgKind]] # Ditto for argument kinds, see nodes.ARG_* constants + arg_kinds: list[ArgKinds] # Ditto for argument kinds, see nodes.ARG_* constants # Names of formal parameters from the callee definition, # these will be sufficient in most cases. callee_arg_names: list[str | None] @@ -483,7 +483,7 @@ class MethodContext(NamedTuple): type: ProperType # Base object type for method call arg_types: list[list[Type]] # List of actual caller types for each formal argument # see FunctionContext for details about names and kinds - arg_kinds: list[list[ArgKind]] + arg_kinds: list[ArgKinds] callee_arg_names: list[str | None] arg_names: list[list[str | None]] default_return_type: Type # Return type inferred by mypy diff --git a/mypy/semanal.py b/mypy/semanal.py index 17dc9bfadc1f..332c60e95cf3 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -91,6 +91,7 @@ TYPE_VAR_TUPLE_KIND, VARIANCE_NOT_READY, ArgKind, + ArgKinds, AssertStmt, AssertTypeExpr, AssignmentExpr, @@ -4786,7 +4787,7 @@ def process_typevar_parameters( self, args: list[Expression], names: list[str | None], - kinds: list[ArgKind], + kinds: ArgKinds, num_values: int, context: Context, ) -> tuple[int, Type, Type] | None: diff --git a/mypy/suggestions.py b/mypy/suggestions.py index 45aa5ade47a4..12f5f121a11f 100644 --- a/mypy/suggestions.py +++ b/mypy/suggestions.py @@ -42,7 +42,7 @@ from mypy.nodes import ( ARG_STAR, ARG_STAR2, - ArgKind, + ArgKinds, CallExpr, Decorator, Expression, @@ -92,7 +92,7 @@ class PyAnnotateSignature(TypedDict): class Callsite(NamedTuple): path: str line: int - arg_kinds: list[list[ArgKind]] + arg_kinds: list[ArgKinds] callee_arg_names: list[str | None] arg_names: list[list[str | None]] arg_types: list[list[Type]] @@ -521,7 +521,7 @@ def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature: def format_args( self, - arg_kinds: list[list[ArgKind]], + arg_kinds: list[ArgKinds], arg_names: list[list[str | None]], arg_types: list[list[Type]], ) -> str: diff --git a/mypy/test/testinfer.py b/mypy/test/testinfer.py index 9c18624e0283..e1fe0f197582 100644 --- a/mypy/test/testinfer.py +++ b/mypy/test/testinfer.py @@ -5,7 +5,16 @@ from mypy.argmap import map_actuals_to_formals from mypy.checker import DisjointDict, group_comparison_operands from mypy.literals import Key -from mypy.nodes import ARG_NAMED, ARG_OPT, ARG_POS, ARG_STAR, ARG_STAR2, ArgKind, NameExpr +from mypy.nodes import ( + ARG_NAMED, + ARG_OPT, + ARG_POS, + ARG_STAR, + ARG_STAR2, + ArgKind, + ArgKinds, + NameExpr, +) from mypy.test.helpers import Suite, assert_equal from mypy.test.typefixture import TypeFixture from mypy.types import AnyType, TupleType, Type, TypeOfAny @@ -109,8 +118,8 @@ def assert_map( def assert_vararg_map( self, - caller_kinds: list[ArgKind], - callee_kinds: list[ArgKind], + caller_kinds: ArgKinds, + callee_kinds: ArgKinds, expected: list[list[int]], vararg_type: Type, ) -> None: @@ -118,9 +127,7 @@ def assert_vararg_map( assert_equal(result, expected) -def expand_caller_kinds( - kinds_or_names: list[ArgKind | str], -) -> tuple[list[ArgKind], list[str | None]]: +def expand_caller_kinds(kinds_or_names: list[ArgKind | str]) -> tuple[ArgKinds, list[str | None]]: kinds = [] names: list[str | None] = [] for k in kinds_or_names: @@ -135,7 +142,7 @@ def expand_caller_kinds( def expand_callee_kinds( kinds_and_names: list[ArgKind | tuple[ArgKind, str]], -) -> tuple[list[ArgKind], list[str | None]]: +) -> tuple[ArgKinds, list[str | None]]: kinds = [] names: list[str | None] = [] for v in kinds_and_names: diff --git a/mypy/typeanal.py b/mypy/typeanal.py index d7a07c9f48e3..4ea8d9451937 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -33,6 +33,7 @@ MISSING_FALLBACK, SYMBOL_FUNCBASE_TYPES, ArgKind, + ArgKinds, Context, Decorator, ImportFrom, @@ -1157,9 +1158,9 @@ def visit_callable_type( param_spec_invalid = True if param_spec_invalid: - if ARG_STAR in arg_kinds: + if arg_kinds.has_star: arg_types[arg_kinds.index(ARG_STAR)] = AnyType(TypeOfAny.from_error) - if ARG_STAR2 in arg_kinds: + if arg_kinds.has_star2: arg_types[arg_kinds.index(ARG_STAR2)] = AnyType(TypeOfAny.from_error) # If there were multiple (invalid) unpacks, the arg types list will become shorter, @@ -1603,9 +1604,9 @@ def refers_to_full_names(self, arg: UnboundType, names: Sequence[str]) -> bool: def analyze_callable_args( self, arglist: TypeList - ) -> tuple[list[Type], list[ArgKind], list[str | None]] | None: + ) -> tuple[list[Type], ArgKinds, list[str | None]] | None: args: list[Type] = [] - kinds: list[ArgKind] = [] + kinds = ArgKinds() names: list[str | None] = [] seen_unpack = False unpack_types: list[Type] = [] diff --git a/mypy/types.py b/mypy/types.py index 38c17e240ccf..5d3d7d0933b9 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -32,7 +32,16 @@ write_str_opt_list, write_tag, ) -from mypy.nodes import ARG_KINDS, ARG_POS, ARG_STAR, ARG_STAR2, INVARIANT, ArgKind, SymbolNode +from mypy.nodes import ( + ARG_KINDS, + ARG_POS, + ARG_STAR, + ARG_STAR2, + INVARIANT, + ArgKind, + ArgKinds, + SymbolNode, +) from mypy.options import Options from mypy.state import state from mypy.util import IdMapper @@ -1895,7 +1904,7 @@ class Parameters(ProperType): def __init__( self, arg_types: Sequence[Type], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, arg_names: Sequence[str | None], *, variables: Sequence[TypeVarLikeType] | None = None, @@ -1908,7 +1917,11 @@ def __init__( self.arg_types = list(arg_types) self.arg_kinds = arg_kinds self.arg_names = list(arg_names) - assert len(arg_types) == len(arg_kinds) == len(arg_names) + assert len(arg_types) == len(arg_kinds) == len(arg_names), ( + len(arg_types), + len(arg_kinds), + len(arg_names), + ) assert not any(isinstance(t, Parameters) for t in arg_types) self.min_args = arg_kinds.count(ARG_POS) self.is_ellipsis_args = is_ellipsis_args @@ -1918,7 +1931,7 @@ def __init__( def copy_modified( self, arg_types: Bogus[Sequence[Type]] = _dummy, - arg_kinds: Bogus[list[ArgKind]] = _dummy, + arg_kinds: Bogus[ArgKinds] = _dummy, arg_names: Bogus[Sequence[str | None]] = _dummy, *, variables: Bogus[Sequence[TypeVarLikeType]] = _dummy, @@ -2128,7 +2141,7 @@ def __init__( self, # maybe this should be refactored to take a Parameters object arg_types: Sequence[Type], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, arg_names: Sequence[str | None], ret_type: Type, fallback: Instance, @@ -2149,7 +2162,11 @@ def __init__( unpack_kwargs: bool = False, ) -> None: super().__init__(line, column) - assert len(arg_types) == len(arg_kinds) == len(arg_names) + assert len(arg_types) == len(arg_kinds) == len(arg_names), ( + len(arg_types), + len(arg_kinds), + len(arg_names), + ) self.arg_types = list(arg_types) for t in self.arg_types: if isinstance(t, ParamSpecType): @@ -2186,7 +2203,7 @@ def __init__( def copy_modified( self: CT, arg_types: Bogus[Sequence[Type]] = _dummy, - arg_kinds: Bogus[list[ArgKind]] = _dummy, + arg_kinds: Bogus[ArgKinds] = _dummy, arg_names: Bogus[Sequence[str | None]] = _dummy, ret_type: Bogus[Type] = _dummy, fallback: Bogus[Instance] = _dummy, diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index 125382145991..19ab46adc0d8 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -20,6 +20,7 @@ TYPE_VAR_KIND, TYPE_VAR_TUPLE_KIND, ArgKind, + ArgKinds, CallExpr, Decorator, Expression, @@ -371,7 +372,7 @@ def py_call( function: Value, arg_values: list[Value], line: int, - arg_kinds: list[ArgKind] | None = None, + arg_kinds: ArgKinds | None = None, arg_names: Sequence[str | None] | None = None, ) -> Value: return self.builder.py_call(function, arg_values, line, arg_kinds, arg_names) @@ -389,7 +390,7 @@ def gen_method_call( arg_values: list[Value], result_type: RType | None, line: int, - arg_kinds: list[ArgKind] | None = None, + arg_kinds: ArgKinds | None = None, arg_names: list[str | None] | None = None, ) -> Value: return self.builder.gen_method_call( diff --git a/mypyc/irbuild/function.py b/mypyc/irbuild/function.py index 51bdc76495f2..fe3e852acf3b 100644 --- a/mypyc/irbuild/function.py +++ b/mypyc/irbuild/function.py @@ -18,6 +18,7 @@ from mypy.nodes import ( ArgKind, + ArgKinds, ClassDef, Decorator, FuncBase, @@ -619,7 +620,7 @@ def gen_glue( class ArgInfo(NamedTuple): args: list[Value] arg_names: list[str | None] - arg_kinds: list[ArgKind] + arg_kinds: ArgKinds def get_args(builder: IRBuilder, rt_args: Sequence[RuntimeArg], line: int) -> ArgInfo: @@ -633,7 +634,7 @@ def get_args(builder: IRBuilder, rt_args: Sequence[RuntimeArg], line: int) -> Ar arg.name if arg.kind.is_named() or (arg.kind.is_optional() and not arg.pos_only) else None for arg in rt_args ] - arg_kinds = [arg.kind for arg in rt_args] + arg_kinds = ArgKinds(arg.kind for arg in rt_args) return ArgInfo(args, arg_names, arg_kinds) @@ -692,9 +693,7 @@ def f(builder: IRBuilder, x: object) -> int: ... # We can do a passthrough *args/**kwargs with a native call, but if the # args need to get distributed out to arguments, we just let python handle it - if any(kind.is_star() for kind in arg_kinds) and any( - not arg.kind.is_star() for arg in target.decl.sig.args - ): + if arg_kinds.has_any_star and any(not arg.kind.is_star() for arg in target.decl.sig.args): do_pycall = True if do_pycall: diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index 37f2add4abbd..3f8c5b34a5d7 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py @@ -11,7 +11,7 @@ from typing import Callable, Final, Optional from mypy.argmap import map_actuals_to_formals -from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2, ArgKind +from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2, ArgKind, ArgKinds from mypy.operators import op_methods, unary_op_methods from mypy.types import AnyType, TypeOfAny from mypyc.common import ( @@ -958,7 +958,7 @@ def py_call( function: Value, arg_values: list[Value], line: int, - arg_kinds: list[ArgKind] | None = None, + arg_kinds: ArgKinds | None = None, arg_names: Sequence[str | None] | None = None, ) -> Value: """Call a Python function (non-native and slow). @@ -970,7 +970,7 @@ def py_call( return result # If all arguments are positional, we can use py_call_op. - if arg_kinds is None or all(kind == ARG_POS for kind in arg_kinds): + if arg_kinds is None or arg_kinds.positional_only: return self.call_c(py_call_op, [function] + arg_values, line) # Otherwise fallback to py_call_with_posargs_op or py_call_with_kwargs_op. @@ -991,7 +991,7 @@ def _py_vector_call( function: Value, arg_values: list[Value], line: int, - arg_kinds: list[ArgKind] | None = None, + arg_kinds: ArgKinds | None = None, arg_names: Sequence[str | None] | None = None, ) -> Value | None: """Call function using the vectorcall API if possible. @@ -1009,7 +1009,7 @@ def _py_vector_call( arg_ptr = self.setup_rarray(object_rprimitive, coerced_args, object_ptr=True) else: arg_ptr = Integer(0, object_pointer_rprimitive) - num_pos = num_positional_args(arg_values, arg_kinds) + num_pos = arg_kinds.count(ARG_POS) if arg_kinds else len(arg_values) keywords = self._vectorcall_keywords(arg_names) value = self.call_c( py_vectorcall_op, @@ -1041,7 +1041,7 @@ def py_method_call( method_name: str, arg_values: list[Value], line: int, - arg_kinds: list[ArgKind] | None, + arg_kinds: ArgKinds | None, arg_names: Sequence[str | None] | None, ) -> Value: """Call a Python method (non-native and slow).""" @@ -1051,7 +1051,7 @@ def py_method_call( if result is not None: return result - if arg_kinds is None or all(kind == ARG_POS for kind in arg_kinds): + if arg_kinds is None or arg_kinds.positional_only: # Use legacy method call API method_name_reg = self.load_str(method_name) return self.call_c(py_method_call_op, [obj, method_name_reg] + arg_values, line) @@ -1066,7 +1066,7 @@ def _py_vector_method_call( method_name: str, arg_values: list[Value], line: int, - arg_kinds: list[ArgKind] | None, + arg_kinds: ArgKinds | None, arg_names: Sequence[str | None] | None, ) -> Value | None: """Call method using the vectorcall API if possible. @@ -1082,7 +1082,7 @@ def _py_vector_method_call( self.coerce(arg, object_rprimitive, line) for arg in [obj] + arg_values ] arg_ptr = self.setup_rarray(object_rprimitive, coerced_args, object_ptr=True) - num_pos = num_positional_args(arg_values, arg_kinds) + num_pos = arg_kinds.count(ARG_POS) if arg_kinds else len(arg_values) keywords = self._vectorcall_keywords(arg_names) value = self.call_c( py_vectorcall_method_op, @@ -1105,7 +1105,7 @@ def call( self, decl: FuncDecl, args: Sequence[Value], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, arg_names: Sequence[str | None], line: int, *, @@ -1126,7 +1126,7 @@ def call( def native_args_to_positional( self, args: Sequence[Value], - arg_kinds: list[ArgKind], + arg_kinds: ArgKinds, arg_names: Sequence[str | None], sig: FuncSignature, line: int, @@ -1224,13 +1224,13 @@ def gen_method_call( arg_values: list[Value], result_type: RType | None, line: int, - arg_kinds: list[ArgKind] | None = None, + arg_kinds: ArgKinds | None = None, arg_names: list[str | None] | None = None, can_borrow: bool = False, ) -> Value: """Generate either a native or Python method call.""" # If we have *args, then fallback to Python method call. - if arg_kinds is not None and any(kind.is_star() for kind in arg_kinds): + if arg_kinds is not None and arg_kinds.has_any_star: return self.py_method_call(base, name, arg_values, line, arg_kinds, arg_names) # If the base type is one of ours, do a MethodCall @@ -1286,7 +1286,7 @@ def union_method_call( arg_values: list[Value], return_rtype: RType | None, line: int, - arg_kinds: list[ArgKind] | None, + arg_kinds: ArgKinds | None, arg_names: list[str | None] | None, ) -> Value: """Generate a method call with a union type for the object.""" @@ -2736,13 +2736,3 @@ def _create_dict(self, keys: list[Value], values: list[Value], line: int) -> Val def error(self, msg: str, line: int) -> None: assert self.errors is not None, "cannot generate errors in this compiler phase" self.errors.error(msg, self.module_path, line) - - -def num_positional_args(arg_values: list[Value], arg_kinds: list[ArgKind] | None) -> int: - if arg_kinds is None: - return len(arg_values) - num_pos = 0 - for kind in arg_kinds: - if kind == ARG_POS: - num_pos += 1 - return num_pos