Skip to content

Special-case unions in polymorphic inference #16461

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 44 additions & 9 deletions mypy/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Iterable, Sequence
from typing_extensions import TypeAlias as _TypeAlias

from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, infer_constraints
from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, infer_constraints, neg_op
from mypy.expandtype import expand_type
from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort
from mypy.join import join_types
Expand Down Expand Up @@ -69,6 +69,10 @@ def solve_constraints(
extra_vars.extend([v.id for v in c.extra_tvars if v.id not in vars + extra_vars])
originals.update({v.id: v for v in c.extra_tvars if v.id not in originals})

if allow_polymorphic:
# Constraints inferred from unions require special handling in polymorphic inference.
constraints = skip_reverse_union_constraints(constraints)

# Collect a list of constraints for each type variable.
cmap: dict[TypeVarId, list[Constraint]] = {tv: [] for tv in vars + extra_vars}
for con in constraints:
Expand Down Expand Up @@ -431,19 +435,15 @@ def transitive_closure(
uppers[l] |= uppers[upper]
for lt in lowers[lower]:
for ut in uppers[upper]:
# TODO: what if secondary constraints result in inference
# against polymorphic actual (also in below branches)?
remaining |= set(infer_constraints(lt, ut, SUBTYPE_OF))
remaining |= set(infer_constraints(ut, lt, SUPERTYPE_OF))
add_secondary_constraints(remaining, lt, ut)
elif c.op == SUBTYPE_OF:
if c.target in uppers[c.type_var]:
continue
for l in tvars:
if (l, c.type_var) in graph:
uppers[l].add(c.target)
for lt in lowers[c.type_var]:
remaining |= set(infer_constraints(lt, c.target, SUBTYPE_OF))
remaining |= set(infer_constraints(c.target, lt, SUPERTYPE_OF))
add_secondary_constraints(remaining, lt, c.target)
else:
assert c.op == SUPERTYPE_OF
if c.target in lowers[c.type_var]:
Expand All @@ -452,11 +452,24 @@ def transitive_closure(
if (c.type_var, u) in graph:
lowers[u].add(c.target)
for ut in uppers[c.type_var]:
remaining |= set(infer_constraints(ut, c.target, SUPERTYPE_OF))
remaining |= set(infer_constraints(c.target, ut, SUBTYPE_OF))
add_secondary_constraints(remaining, c.target, ut)
return graph, lowers, uppers


def add_secondary_constraints(cs: set[Constraint], lower: Type, upper: Type) -> None:
"""Add secondary constraints inferred between lower and upper (in place)."""
if isinstance(get_proper_type(upper), UnionType) and isinstance(
get_proper_type(lower), UnionType
):
# When both types are unions, this can lead to inferring spurious constraints,
# for example Union[T, int] <: S <: Union[T, int] may infer T <: int.
# To avoid this, just skip them for now.
return
# TODO: what if secondary constraints result in inference against polymorphic actual?
cs.update(set(infer_constraints(lower, upper, SUBTYPE_OF)))
cs.update(set(infer_constraints(upper, lower, SUPERTYPE_OF)))


def compute_dependencies(
tvars: list[TypeVarId], graph: Graph, lowers: Bounds, uppers: Bounds
) -> dict[TypeVarId, list[TypeVarId]]:
Expand Down Expand Up @@ -494,6 +507,28 @@ def check_linear(scc: set[TypeVarId], lowers: Bounds, uppers: Bounds) -> bool:
return True


def skip_reverse_union_constraints(cs: list[Constraint]) -> list[Constraint]:
"""Avoid ambiguities for constraints inferred from unions during polymorphic inference.

Polymorphic inference implicitly relies on assumption that a reverse of a linear constraint
is a linear constraint. This is however not true in presence of union types, for example
T :> Union[S, int] vs S <: T. Trying to solve such constraints would be detected ambiguous
as (T, S) form a non-linear SCC. However, simply removing the linear part results in a valid
solution T = Union[S, int], S = <free>.

TODO: a cleaner solution may be to avoid inferring such constraints in first place, but
this would require passing around a flag through all infer_constraints() calls.
"""
reverse_union_cs = set()
for c in cs:
p_target = get_proper_type(c.target)
if isinstance(p_target, UnionType):
for item in p_target.items:
if isinstance(item, TypeVarType):
reverse_union_cs.add(Constraint(item, neg_op(c.op), c.origin_type_var))
return [c for c in cs if c not in reverse_union_cs]


def get_vars(target: Type, vars: list[TypeVarId]) -> set[TypeVarId]:
"""Find type variables for which we are solving in a target type."""
return {tv.id for tv in get_all_type_vars(target)} & set(vars)
Expand Down
21 changes: 21 additions & 0 deletions test-data/unit/check-inference.test
Original file line number Diff line number Diff line change
Expand Up @@ -3767,3 +3767,24 @@ def f(values: List[T]) -> T: ...
x = foo(f([C()]))
reveal_type(x) # N: Revealed type is "__main__.C"
[builtins fixtures/list.pyi]

[case testInferenceAgainstGenericCallableUnion]
from typing import Callable, TypeVar, List, Union

T = TypeVar("T")
S = TypeVar("S")

def dec(f: Callable[[S], T]) -> Callable[[S], List[T]]: ...
@dec
def func(arg: T) -> Union[T, str]:
...
reveal_type(func) # N: Revealed type is "def [S] (S`1) -> builtins.list[Union[S`1, builtins.str]]"
reveal_type(func(42)) # N: Revealed type is "builtins.list[Union[builtins.int, builtins.str]]"

def dec2(f: Callable[[S], List[T]]) -> Callable[[S], T]: ...
@dec2
def func2(arg: T) -> List[Union[T, str]]:
...
reveal_type(func2) # N: Revealed type is "def [S] (S`4) -> Union[S`4, builtins.str]"
reveal_type(func2(42)) # N: Revealed type is "Union[builtins.int, builtins.str]"
[builtins fixtures/list.pyi]
22 changes: 22 additions & 0 deletions test-data/unit/check-parameter-specification.test
Original file line number Diff line number Diff line change
Expand Up @@ -2086,3 +2086,25 @@ reveal_type(d(b, f1)) # E: Cannot infer type argument 1 of "d" \
# N: Revealed type is "def (*Any, **Any)"
reveal_type(d(b, f2)) # N: Revealed type is "def (builtins.int)"
[builtins fixtures/paramspec.pyi]

[case testInferenceAgainstGenericCallableUnionParamSpec]
from typing import Callable, TypeVar, List, Union
from typing_extensions import ParamSpec

T = TypeVar("T")
P = ParamSpec("P")

def dec(f: Callable[P, T]) -> Callable[P, List[T]]: ...
@dec
def func(arg: T) -> Union[T, str]:
...
reveal_type(func) # N: Revealed type is "def [T] (arg: T`-1) -> builtins.list[Union[T`-1, builtins.str]]"
reveal_type(func(42)) # N: Revealed type is "builtins.list[Union[builtins.int, builtins.str]]"

def dec2(f: Callable[P, List[T]]) -> Callable[P, T]: ...
@dec2
def func2(arg: T) -> List[Union[T, str]]:
...
reveal_type(func2) # N: Revealed type is "def [T] (arg: T`-1) -> Union[T`-1, builtins.str]"
reveal_type(func2(42)) # N: Revealed type is "Union[builtins.int, builtins.str]"
[builtins fixtures/paramspec.pyi]