From 4b3722fa89505b1663110281c5341adc9a4be754 Mon Sep 17 00:00:00 2001 From: Alex Waygood Date: Sat, 11 Mar 2023 16:30:19 +0000 Subject: [PATCH] Improve the signatures of `expand_type` and `expand_type_by_instance` (#14879) By adding another overload, `CallableType -> CallableType`, we can avoid the need for several `cast`s across the code base. --- mypy/checker.py | 4 +--- mypy/checkexpr.py | 2 +- mypy/checkmember.py | 2 +- mypy/expandtype.py | 16 ++++++++++++++-- mypy/server/astdiff.py | 4 ++-- 5 files changed, 19 insertions(+), 9 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index bd762942da48..c4a8d4205942 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1753,8 +1753,7 @@ def expand_typevars( result: list[tuple[FuncItem, CallableType]] = [] for substitutions in itertools.product(*subst): mapping = dict(substitutions) - expanded = cast(CallableType, expand_type(typ, mapping)) - result.append((expand_func(defn, mapping), expanded)) + result.append((expand_func(defn, mapping), expand_type(typ, mapping))) return result else: return [(defn, typ)] @@ -7111,7 +7110,6 @@ def overload_can_never_match(signature: CallableType, other: CallableType) -> bo exp_signature = expand_type( signature, {tvar.id: erase_def_to_union_or_bound(tvar) for tvar in signature.variables} ) - assert isinstance(exp_signature, CallableType) return is_callable_compatible( exp_signature, other, is_compat=is_more_precise, ignore_return=True ) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 38b5c2419d95..ad1a7cca2074 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -5518,7 +5518,7 @@ def merge_typevars_in_callables_by_name( variables.append(tv) rename[tv.id] = unique_typevars[name] - target = cast(CallableType, expand_type(target, rename)) + target = expand_type(target, rename) output.append(target) return output, variables diff --git a/mypy/checkmember.py b/mypy/checkmember.py index a2c580e13446..b0c422e1e4e0 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -1150,7 +1150,7 @@ class B(A[str]): pass t = freshen_all_functions_type_vars(t) t = bind_self(t, original_type, is_classmethod=True) assert isuper is not None - t = cast(CallableType, expand_type_by_instance(t, isuper)) + t = expand_type_by_instance(t, isuper) freeze_all_type_vars(t) return t.copy_modified(variables=list(tvars) + list(t.variables)) elif isinstance(t, Overloaded): diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 7933283b24d6..d70b7108310f 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -47,6 +47,13 @@ ) +@overload +def expand_type( + typ: CallableType, env: Mapping[TypeVarId, Type], allow_erased_callables: bool = ... +) -> CallableType: + ... + + @overload def expand_type( typ: ProperType, env: Mapping[TypeVarId, Type], allow_erased_callables: bool = ... @@ -70,6 +77,11 @@ def expand_type( return typ.accept(ExpandTypeVisitor(env, allow_erased_callables)) +@overload +def expand_type_by_instance(typ: CallableType, instance: Instance) -> CallableType: + ... + + @overload def expand_type_by_instance(typ: ProperType, instance: Instance) -> ProperType: ... @@ -133,7 +145,7 @@ def freshen_function_type_vars(callee: F) -> F: tv = ParamSpecType.new_unification_variable(v) tvs.append(tv) tvmap[v.id] = tv - fresh = cast(CallableType, expand_type(callee, tvmap)).copy_modified(variables=tvs) + fresh = expand_type(callee, tvmap).copy_modified(variables=tvs) return cast(F, fresh) else: assert isinstance(callee, Overloaded) @@ -346,7 +358,7 @@ def interpolate_args_for_unpack( ) return (arg_names, arg_kinds, arg_types) - def visit_callable_type(self, t: CallableType) -> Type: + def visit_callable_type(self, t: CallableType) -> CallableType: param_spec = t.param_spec() if param_spec is not None: repl = get_proper_type(self.variables.get(param_spec.id)) diff --git a/mypy/server/astdiff.py b/mypy/server/astdiff.py index c942a5eb3b0f..83ae64fbc1a8 100644 --- a/mypy/server/astdiff.py +++ b/mypy/server/astdiff.py @@ -52,7 +52,7 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method' from __future__ import annotations -from typing import Sequence, Tuple, Union, cast +from typing import Sequence, Tuple, Union from typing_extensions import TypeAlias as _TypeAlias from mypy.expandtype import expand_type @@ -442,7 +442,7 @@ def normalize_callable_variables(self, typ: CallableType) -> CallableType: tv = v.copy_modified(id=tid) tvs.append(tv) tvmap[v.id] = tv - return cast(CallableType, expand_type(typ, tvmap)).copy_modified(variables=tvs) + return expand_type(typ, tvmap).copy_modified(variables=tvs) def visit_tuple_type(self, typ: TupleType) -> SnapshotItem: return ("TupleType", snapshot_types(typ.items))