Skip to content

Commit

Permalink
Begin unifying logic for constraint building (#14406)
Browse files Browse the repository at this point in the history
Implements support for unpacking varlength tuples from *args, but
because it became apparent that several parts of constraints building
were doing nearly the same thing for typevar tuples, we begin extracting
out some of the logic for re-use. Some existing callsites still should
be switched to the new helpers but it is defered to future PRs.
  • Loading branch information
jhance authored Jan 9, 2023
1 parent 7efe8e5 commit e959565
Show file tree
Hide file tree
Showing 10 changed files with 387 additions and 171 deletions.
31 changes: 25 additions & 6 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@
)
from mypy.typestate import type_state
from mypy.typevars import fill_typevars
from mypy.typevartuples import find_unpack_in_list
from mypy.util import split_module_names
from mypy.visitor import ExpressionVisitor

Expand Down Expand Up @@ -2064,12 +2065,30 @@ def check_argument_types(
actual_kinds = [arg_kinds[a] for a in actuals]
if isinstance(orig_callee_arg_type, UnpackType):
unpacked_type = get_proper_type(orig_callee_arg_type.type)
# Only case we know of thus far.
assert isinstance(unpacked_type, TupleType)
actual_types = [arg_types[a] for a in actuals]
actual_kinds = [arg_kinds[a] for a in actuals]
callee_arg_types = unpacked_type.items
callee_arg_kinds = [ARG_POS] * len(actuals)
if isinstance(unpacked_type, TupleType):
inner_unpack_index = find_unpack_in_list(unpacked_type.items)
if inner_unpack_index is None:
callee_arg_types = unpacked_type.items
callee_arg_kinds = [ARG_POS] * len(actuals)
else:
inner_unpack = get_proper_type(unpacked_type.items[inner_unpack_index])
assert isinstance(inner_unpack, UnpackType)
inner_unpacked_type = get_proper_type(inner_unpack.type)
# We assume heterogenous tuples are desugared earlier
assert isinstance(inner_unpacked_type, Instance)
assert inner_unpacked_type.type.fullname == "builtins.tuple"
callee_arg_types = (
unpacked_type.items[:inner_unpack_index]
+ [inner_unpacked_type.args[0]]
* (len(actuals) - len(unpacked_type.items) + 1)
+ unpacked_type.items[inner_unpack_index + 1 :]
)
callee_arg_kinds = [ARG_POS] * len(actuals)
else:
assert isinstance(unpacked_type, Instance)
assert unpacked_type.type.fullname == "builtins.tuple"
callee_arg_types = [unpacked_type.args[0]] * len(actuals)
callee_arg_kinds = [ARG_POS] * len(actuals)
else:
callee_arg_types = [orig_callee_arg_type] * len(actuals)
callee_arg_kinds = [callee.arg_kinds[i]] * len(actuals)
Expand Down
214 changes: 123 additions & 91 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
from mypy.typevartuples import (
extract_unpack,
find_unpack_in_list,
split_with_instance,
split_with_mapped_and_template,
split_with_prefix_and_suffix,
)
Expand Down Expand Up @@ -566,7 +565,7 @@ def visit_type_var_tuple(self, template: TypeVarTupleType) -> list[Constraint]:
raise NotImplementedError

def visit_unpack_type(self, template: UnpackType) -> list[Constraint]:
raise NotImplementedError
raise RuntimeError("Mypy bug: unpack should be handled at a higher level.")

def visit_parameters(self, template: Parameters) -> list[Constraint]:
# constraining Any against C[P] turns into infer_against_any([P], Any)
Expand Down Expand Up @@ -638,47 +637,22 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
tvars = mapped.type.defn.type_vars

if instance.type.has_type_var_tuple_type:
mapped_prefix, mapped_middle, mapped_suffix = split_with_instance(mapped)
instance_prefix, instance_middle, instance_suffix = split_with_instance(
instance
)

# Add a constraint for the type var tuple, and then
# remove it for the case below.
instance_unpack = extract_unpack(instance_middle)
if instance_unpack is not None:
if isinstance(instance_unpack, TypeVarTupleType):
res.append(
Constraint(
instance_unpack,
SUBTYPE_OF,
TupleType(list(mapped_middle), instance_unpack.tuple_fallback),
)
)
elif (
isinstance(instance_unpack, Instance)
and instance_unpack.type.fullname == "builtins.tuple"
):
for item in mapped_middle:
res.extend(
infer_constraints(
instance_unpack.args[0], item, self.direction
)
)
elif isinstance(instance_unpack, TupleType):
if len(instance_unpack.items) == len(mapped_middle):
for instance_arg, item in zip(
instance_unpack.items, mapped_middle
):
res.extend(
infer_constraints(instance_arg, item, self.direction)
)

mapped_args = mapped_prefix + mapped_suffix
instance_args = instance_prefix + instance_suffix

assert instance.type.type_var_tuple_prefix is not None
assert instance.type.type_var_tuple_suffix is not None
assert mapped.type.type_var_tuple_prefix is not None
assert mapped.type.type_var_tuple_suffix is not None

unpack_constraints, mapped_args, instance_args = build_constraints_for_unpack(
mapped.args,
mapped.type.type_var_tuple_prefix,
mapped.type.type_var_tuple_suffix,
instance.args,
instance.type.type_var_tuple_prefix,
instance.type.type_var_tuple_suffix,
self.direction,
)
res.extend(unpack_constraints)

tvars_prefix, _, tvars_suffix = split_with_prefix_and_suffix(
tuple(tvars),
instance.type.type_var_tuple_prefix,
Expand Down Expand Up @@ -732,57 +706,22 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
mapped = map_instance_to_supertype(instance, template.type)
tvars = template.type.defn.type_vars
if template.type.has_type_var_tuple_type:
mapped_prefix, mapped_middle, mapped_suffix = split_with_instance(mapped)
template_prefix, template_middle, template_suffix = split_with_instance(
template
)
split_result = split_with_mapped_and_template(mapped, template)
assert split_result is not None
(
mapped_prefix,
mapped_middle,
mapped_suffix,
template_prefix,
template_middle,
template_suffix,
) = split_result

# Add a constraint for the type var tuple, and then
# remove it for the case below.
template_unpack = extract_unpack(template_middle)
if template_unpack is not None:
if isinstance(template_unpack, TypeVarTupleType):
res.append(
Constraint(
template_unpack,
SUPERTYPE_OF,
TupleType(list(mapped_middle), template_unpack.tuple_fallback),
)
)
elif (
isinstance(template_unpack, Instance)
and template_unpack.type.fullname == "builtins.tuple"
):
for item in mapped_middle:
res.extend(
infer_constraints(
template_unpack.args[0], item, self.direction
)
)
elif isinstance(template_unpack, TupleType):
if len(template_unpack.items) == len(mapped_middle):
for template_arg, item in zip(
template_unpack.items, mapped_middle
):
res.extend(
infer_constraints(template_arg, item, self.direction)
)

mapped_args = mapped_prefix + mapped_suffix
template_args = template_prefix + template_suffix

assert mapped.type.type_var_tuple_prefix is not None
assert mapped.type.type_var_tuple_suffix is not None
assert template.type.type_var_tuple_prefix is not None
assert template.type.type_var_tuple_suffix is not None

unpack_constraints, mapped_args, template_args = build_constraints_for_unpack(
mapped.args,
mapped.type.type_var_tuple_prefix,
mapped.type.type_var_tuple_suffix,
template.args,
template.type.type_var_tuple_prefix,
template.type.type_var_tuple_suffix,
self.direction,
)
res.extend(unpack_constraints)

tvars_prefix, _, tvars_suffix = split_with_prefix_and_suffix(
tuple(tvars),
template.type.type_var_tuple_prefix,
Expand Down Expand Up @@ -945,12 +884,28 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
# We can't infer constraints from arguments if the template is Callable[..., T]
# (with literal '...').
if not template.is_ellipsis_args:
if find_unpack_in_list(template.arg_types) is not None:
(
unpack_constraints,
cactual_args_t,
template_args_t,
) = find_and_build_constraints_for_unpack(
tuple(cactual.arg_types), tuple(template.arg_types), self.direction
)
template_args = list(template_args_t)
cactual_args = list(cactual_args_t)
res.extend(unpack_constraints)
assert len(template_args) == len(cactual_args)
else:
template_args = template.arg_types
cactual_args = cactual.arg_types
# The lengths should match, but don't crash (it will error elsewhere).
for t, a in zip(template.arg_types, cactual.arg_types):
for t, a in zip(template_args, cactual_args):
# Negate direction due to function argument type contravariance.
res.extend(infer_constraints(t, a, neg_op(self.direction)))
else:
# sometimes, it appears we try to get constraints between two paramspec callables?

# TODO: Direction
# TODO: check the prefixes match
prefix = param_spec.prefix
Expand Down Expand Up @@ -1197,3 +1152,80 @@ def find_matching_overload_items(
# it maintains backward compatibility.
res = items[:]
return res


def find_and_build_constraints_for_unpack(
mapped: tuple[Type, ...], template: tuple[Type, ...], direction: int
) -> tuple[list[Constraint], tuple[Type, ...], tuple[Type, ...]]:
mapped_prefix_len = find_unpack_in_list(mapped)
if mapped_prefix_len is not None:
mapped_suffix_len: int | None = len(mapped) - mapped_prefix_len - 1
else:
mapped_suffix_len = None

template_prefix_len = find_unpack_in_list(template)
assert template_prefix_len is not None
template_suffix_len = len(template) - template_prefix_len - 1

return build_constraints_for_unpack(
mapped,
mapped_prefix_len,
mapped_suffix_len,
template,
template_prefix_len,
template_suffix_len,
direction,
)


def build_constraints_for_unpack(
mapped: tuple[Type, ...],
mapped_prefix_len: int | None,
mapped_suffix_len: int | None,
template: tuple[Type, ...],
template_prefix_len: int,
template_suffix_len: int,
direction: int,
) -> tuple[list[Constraint], tuple[Type, ...], tuple[Type, ...]]:
split_result = split_with_mapped_and_template(
mapped,
mapped_prefix_len,
mapped_suffix_len,
template,
template_prefix_len,
template_suffix_len,
)
assert split_result is not None
(
mapped_prefix,
mapped_middle,
mapped_suffix,
template_prefix,
template_middle,
template_suffix,
) = split_result

template_unpack = extract_unpack(template_middle)
res = []

if template_unpack is not None:
if isinstance(template_unpack, TypeVarTupleType):
res.append(
Constraint(
template_unpack,
direction,
TupleType(list(mapped_middle), template_unpack.tuple_fallback),
)
)
elif (
isinstance(template_unpack, Instance)
and template_unpack.type.fullname == "builtins.tuple"
):
for item in mapped_middle:
res.extend(infer_constraints(template_unpack.args[0], item, direction))

elif isinstance(template_unpack, TupleType):
if len(template_unpack.items) == len(mapped_middle):
for template_arg, item in zip(template_unpack.items, mapped_middle):
res.extend(infer_constraints(template_arg, item, direction))
return (res, mapped_prefix + mapped_suffix, template_prefix + template_suffix)
Loading

0 comments on commit e959565

Please sign in to comment.