Skip to content

Commit

Permalink
Improve the signatures of expand_type and expand_type_by_instance (
Browse files Browse the repository at this point in the history
…#14879)

By adding another overload, `CallableType -> CallableType`, we can avoid
the need for several `cast`s across the code base.
  • Loading branch information
AlexWaygood committed Mar 11, 2023
1 parent 106d57e commit 4b3722f
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 9 deletions.
4 changes: 1 addition & 3 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1753,8 +1753,7 @@ def expand_typevars(
result: list[tuple[FuncItem, CallableType]] = []
for substitutions in itertools.product(*subst):
mapping = dict(substitutions)
expanded = cast(CallableType, expand_type(typ, mapping))
result.append((expand_func(defn, mapping), expanded))
result.append((expand_func(defn, mapping), expand_type(typ, mapping)))
return result
else:
return [(defn, typ)]
Expand Down Expand Up @@ -7111,7 +7110,6 @@ def overload_can_never_match(signature: CallableType, other: CallableType) -> bo
exp_signature = expand_type(
signature, {tvar.id: erase_def_to_union_or_bound(tvar) for tvar in signature.variables}
)
assert isinstance(exp_signature, CallableType)
return is_callable_compatible(
exp_signature, other, is_compat=is_more_precise, ignore_return=True
)
Expand Down
2 changes: 1 addition & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5518,7 +5518,7 @@ def merge_typevars_in_callables_by_name(
variables.append(tv)
rename[tv.id] = unique_typevars[name]

target = cast(CallableType, expand_type(target, rename))
target = expand_type(target, rename)
output.append(target)

return output, variables
Expand Down
2 changes: 1 addition & 1 deletion mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -1150,7 +1150,7 @@ class B(A[str]): pass
t = freshen_all_functions_type_vars(t)
t = bind_self(t, original_type, is_classmethod=True)
assert isuper is not None
t = cast(CallableType, expand_type_by_instance(t, isuper))
t = expand_type_by_instance(t, isuper)
freeze_all_type_vars(t)
return t.copy_modified(variables=list(tvars) + list(t.variables))
elif isinstance(t, Overloaded):
Expand Down
16 changes: 14 additions & 2 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@
)


@overload
def expand_type(
typ: CallableType, env: Mapping[TypeVarId, Type], allow_erased_callables: bool = ...
) -> CallableType:
...


@overload
def expand_type(
typ: ProperType, env: Mapping[TypeVarId, Type], allow_erased_callables: bool = ...
Expand All @@ -70,6 +77,11 @@ def expand_type(
return typ.accept(ExpandTypeVisitor(env, allow_erased_callables))


@overload
def expand_type_by_instance(typ: CallableType, instance: Instance) -> CallableType:
...


@overload
def expand_type_by_instance(typ: ProperType, instance: Instance) -> ProperType:
...
Expand Down Expand Up @@ -133,7 +145,7 @@ def freshen_function_type_vars(callee: F) -> F:
tv = ParamSpecType.new_unification_variable(v)
tvs.append(tv)
tvmap[v.id] = tv
fresh = cast(CallableType, expand_type(callee, tvmap)).copy_modified(variables=tvs)
fresh = expand_type(callee, tvmap).copy_modified(variables=tvs)
return cast(F, fresh)
else:
assert isinstance(callee, Overloaded)
Expand Down Expand Up @@ -346,7 +358,7 @@ def interpolate_args_for_unpack(
)
return (arg_names, arg_kinds, arg_types)

def visit_callable_type(self, t: CallableType) -> Type:
def visit_callable_type(self, t: CallableType) -> CallableType:
param_spec = t.param_spec()
if param_spec is not None:
repl = get_proper_type(self.variables.get(param_spec.id))
Expand Down
4 changes: 2 additions & 2 deletions mypy/server/astdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method'

from __future__ import annotations

from typing import Sequence, Tuple, Union, cast
from typing import Sequence, Tuple, Union
from typing_extensions import TypeAlias as _TypeAlias

from mypy.expandtype import expand_type
Expand Down Expand Up @@ -442,7 +442,7 @@ def normalize_callable_variables(self, typ: CallableType) -> CallableType:
tv = v.copy_modified(id=tid)
tvs.append(tv)
tvmap[v.id] = tv
return cast(CallableType, expand_type(typ, tvmap)).copy_modified(variables=tvs)
return expand_type(typ, tvmap).copy_modified(variables=tvs)

def visit_tuple_type(self, typ: TupleType) -> SnapshotItem:
return ("TupleType", snapshot_types(typ.items))
Expand Down

0 comments on commit 4b3722f

Please sign in to comment.