Skip to content

Commit

Permalink
Foundations for non-linear solver and polymorphic application (#15287)
Browse files Browse the repository at this point in the history
Fixes #1317
Fixes #5738
Fixes #12919 
(also fixes a `FIX` comment that is more than 10 years old according to
git blame)

Note: although this PR fixes most typical use-cases for type inference
against generic functions, it is intentionally incomplete, and it is
made in a way to limit implications to small scope.

This PR has essentially three components (better infer, better solve,
better apply - all three are needed for this MVP to work):
* A "tiny" change to `constraints.py`: if the actual function is
generic, we unify it with template before inferring constraints. This
prevents leaking generic type variables of actual in the solutions
(which makes no sense), but also introduces new kind of constraints `T
<: F[S]`, where type variables we solve for appear in target type. These
are much harder to solve, but also it is a great opportunity to play
with them to prepare for single bin inference (if we will switch to it
in some form later). Note unifying is not the best solution, but a good
first approximation (see below on what is the best solution).
* New more sophisticated constraint solver in `solve.py`. The full
algorithm is outlined in the docstring for `solve_non_linear()`. It
looks like it should be able to solve arbitrary constraints that don't
(indirectly) contain "F-bounded" things like `T <: list[T]`. Very short
the idea is to compute transitive closure, then organize constraints by
topologically sorted SCCs.
* Polymorphic type argument application in `checkexpr.py`. In cases
where solver identifies there are free variables (e.g. we have just one
constraint `S <: list[T]`, so `T` is free, and solution for `S` is
`list[T]`) it will apply the solutions while creating new generic
functions. For example, if we have a function `def [S, T] (fn:
Callable[[S], T]) -> Callable[[S], T]` applied to a function `def [U]
(x: U) -> U`, this will result in `def [T] (T) -> T` as the return.

I want to put here some thoughts on the last ingredient, since it may be
mysterious, but now it seems to me it is actually a very well defined
procedure. The key point here is thinking about generic functions as
about infinite intersections or infinite overloads. Now reducing these
infinite overloads/intersections to finite ones it is easy to understand
what is actually going on. For example, imagine we live in a world with
just two types `int` and `str`. Now we have two functions:
```python
T = TypeVar("T")
S = TypeVar("S")
U = TypeVar("U")

def dec(fn: Callable[[T], S]) -> Callable[[T], S]: ...
def id(x: U) -> U: ...
```
the first one can be seen as overload over
```
((int) -> int) -> ((int) -> int)  # 1
((int) -> str) -> ((int) -> str)  # 2
((str) -> int) -> ((str) -> int)  # 3
((str) -> str) -> ((str) -> str)  # 4
```
and second as an overload over
```
(int) -> int
(str) -> str
```
Now what happens when I apply `dec(id)`? We need to choose an overload
that matches the argument (this is what we call type inference), but
here is a trick, in this case two overloads of `dec` match the argument
type. So (and btw I think we are missing this for real overloads) we
construct a new overload that returns intersection of matching overloads
`# 1` and `# 4`. So if we generalize this intuition to the general case,
the inference is selection of an (infinite) parametrized subset among
the bigger parameterized set of intersecting types. The only question is
whether resulting infinite intersection is representable in our type
system. For example `forall T. dict[T, T]` can make sense but is not
representable, while `forall T. (T) -> T` is a well defined type. And
finally, there is a very easy way to find whether a type is
representable or not, we are already doing this during semantic
analyzis. I use the same logic (that I used to view as ad-hoc because of
lack of good syntax for callables) to bind type variables in the
inferred type.

OK, so here is the list of missing features, and some comments on them:
1. Instead of unifying the actual with template we should include
actual's variables in variable set we solve for, as explained in
#5738 (comment). Note
however, this will work only together with the next item
2. We need to (iteratively) infer secondary constraints after linear
propagation, e.g. `Sequence[T] <: S <: Sequence[U] => T <: U`
3. Support `ParamSpec` (and probably `TypeVarTuple`). Current support
for applying callables with `ParamSpec` to generics is hacky, and kind
of dead-end. Although `(Callable[P, T]) -> Callable[P, List[T]]` works
when applied to `id`, even a slight variation like `(Callable[P,
List[T]]) -> Callable[P, T]` fails. I think it needs to be re-worked in
the framework I propose (the tests I added are just to be sure I don't
break existing code)
4. Support actual types that are generic in type variables with upper
bounds or values (likely we just need to be careful when propagating
constraints and choosing free variable within an SCC).
5. Add backtracking for upper/lower bound choice. In general, in the
current "Hanoi Tower" inference scheme it is very hard to backtrack, but
in in this specific choice in the new solver, it should be totally
possible to switch from lower to upper bound on a previous step, if we
found no solution (or `<nothing>`/`object`).
6. After we polish it, we can use the new solver in more situations,
e.g. for return type context, and for unification during callable
subtyping.
7. Long term we may want to allow instances to bind type variables, at
least for things like `LRUCache[[x: T], T]`. Btw note that I apply force
expansion to type aliases and callback protocols. Since I can't
transform e.g. `A = Callable[[T], T]` into a generic callable without
getting proper type.
8. We need to figure out a solution for scenarios where non-linear
targets with free variables and constant targets mix without secondary
constraints, like `T <: List[int], T <: List[S]`.

I am planning to address at least majority of the above items, but I
think we should move slowly, since in my experience type inference is
really fragile topic with hard to predict long reaching consequences.
Please play with this PR if you want to and have time, and please
suggest tests to add.
  • Loading branch information
ilevkivskyi authored Jun 18, 2023
1 parent 91b6740 commit 0873230
Show file tree
Hide file tree
Showing 17 changed files with 998 additions and 193 deletions.
106 changes: 2 additions & 104 deletions mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,12 @@
Callable,
ClassVar,
Dict,
Iterable,
Iterator,
Mapping,
NamedTuple,
NoReturn,
Sequence,
TextIO,
TypeVar,
)
from typing_extensions import Final, TypeAlias as _TypeAlias

Expand All @@ -47,6 +45,7 @@
import mypy.semanal_main
from mypy.checker import TypeChecker
from mypy.errors import CompileError, ErrorInfo, Errors, report_internal_error
from mypy.graph_utils import prepare_sccs, strongly_connected_components, topsort
from mypy.indirection import TypeIndirectionVisitor
from mypy.messages import MessageBuilder
from mypy.nodes import Import, ImportAll, ImportBase, ImportFrom, MypyFile, SymbolTable, TypeInfo
Expand Down Expand Up @@ -3466,15 +3465,8 @@ def sorted_components(
edges = {id: deps_filtered(graph, vertices, id, pri_max) for id in vertices}
sccs = list(strongly_connected_components(vertices, edges))
# Topsort.
sccsmap = {id: frozenset(scc) for scc in sccs for id in scc}
data: dict[AbstractSet[str], set[AbstractSet[str]]] = {}
for scc in sccs:
deps: set[AbstractSet[str]] = set()
for id in scc:
deps.update(sccsmap[x] for x in deps_filtered(graph, vertices, id, pri_max))
data[frozenset(scc)] = deps
res = []
for ready in topsort(data):
for ready in topsort(prepare_sccs(sccs, edges)):
# Sort the sets in ready by reversed smallest State.order. Examples:
#
# - If ready is [{x}, {y}], x.order == 1, y.order == 2, we get
Expand All @@ -3499,100 +3491,6 @@ def deps_filtered(graph: Graph, vertices: AbstractSet[str], id: str, pri_max: in
]


def strongly_connected_components(
vertices: AbstractSet[str], edges: dict[str, list[str]]
) -> Iterator[set[str]]:
"""Compute Strongly Connected Components of a directed graph.
Args:
vertices: the labels for the vertices
edges: for each vertex, gives the target vertices of its outgoing edges
Returns:
An iterator yielding strongly connected components, each
represented as a set of vertices. Each input vertex will occur
exactly once; vertices not part of a SCC are returned as
singleton sets.
From https://code.activestate.com/recipes/578507/.
"""
identified: set[str] = set()
stack: list[str] = []
index: dict[str, int] = {}
boundaries: list[int] = []

def dfs(v: str) -> Iterator[set[str]]:
index[v] = len(stack)
stack.append(v)
boundaries.append(index[v])

for w in edges[v]:
if w not in index:
yield from dfs(w)
elif w not in identified:
while index[w] < boundaries[-1]:
boundaries.pop()

if boundaries[-1] == index[v]:
boundaries.pop()
scc = set(stack[index[v] :])
del stack[index[v] :]
identified.update(scc)
yield scc

for v in vertices:
if v not in index:
yield from dfs(v)


T = TypeVar("T")


def topsort(data: dict[T, set[T]]) -> Iterable[set[T]]:
"""Topological sort.
Args:
data: A map from vertices to all vertices that it has an edge
connecting it to. NOTE: This data structure
is modified in place -- for normalization purposes,
self-dependencies are removed and entries representing
orphans are added.
Returns:
An iterator yielding sets of vertices that have an equivalent
ordering.
Example:
Suppose the input has the following structure:
{A: {B, C}, B: {D}, C: {D}}
This is normalized to:
{A: {B, C}, B: {D}, C: {D}, D: {}}
The algorithm will yield the following values:
{D}
{B, C}
{A}
From https://code.activestate.com/recipes/577413/.
"""
# TODO: Use a faster algorithm?
for k, v in data.items():
v.discard(k) # Ignore self dependencies.
for item in set.union(*data.values()) - set(data.keys()):
data[item] = set()
while True:
ready = {item for item, dep in data.items() if not dep}
if not ready:
break
yield ready
data = {item: (dep - ready) for item, dep in data.items() if item not in ready}
assert not data, f"A cyclic dependency exists amongst {data!r}"


def missing_stubs_file(cache_dir: str) -> str:
return os.path.join(cache_dir, "missing_stubs")

Expand Down
145 changes: 143 additions & 2 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import mypy.errorcodes as codes
from mypy import applytype, erasetype, join, message_registry, nodes, operators, types
from mypy.argmap import ArgTypeExpander, map_actuals_to_formals, map_formals_to_actuals
from mypy.checkmember import analyze_member_access, type_object_type
from mypy.checkmember import analyze_member_access, freeze_all_type_vars, type_object_type
from mypy.checkstrformat import StringFormatterChecker
from mypy.erasetype import erase_type, remove_instance_last_known_values, replace_meta_vars
from mypy.errors import ErrorWatcher, report_internal_error
Expand Down Expand Up @@ -98,8 +98,15 @@
)
from mypy.semanal_enum import ENUM_BASES
from mypy.state import state
from mypy.subtypes import is_equivalent, is_same_type, is_subtype, non_method_protocol_members
from mypy.subtypes import (
find_member,
is_equivalent,
is_same_type,
is_subtype,
non_method_protocol_members,
)
from mypy.traverser import has_await_expression
from mypy.type_visitor import TypeTranslator
from mypy.typeanal import (
check_for_explicit_any,
has_any_from_unimported_type,
Expand All @@ -114,6 +121,7 @@
false_only,
fixup_partial_type,
function_type,
get_type_vars,
is_literal_type_like,
make_simplified_union,
simple_literal_type,
Expand Down Expand Up @@ -146,6 +154,7 @@
TypedDictType,
TypeOfAny,
TypeType,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
UninhabitedType,
Expand Down Expand Up @@ -300,6 +309,7 @@ def __init__(
# on whether current expression is a callee, to give better error messages
# related to type context.
self.is_callee = False
type_state.infer_polymorphic = self.chk.options.new_type_inference

def reset(self) -> None:
self.resolved_type = {}
Expand Down Expand Up @@ -1791,6 +1801,51 @@ def infer_function_type_arguments(
inferred_args[0] = self.named_type("builtins.str")
elif not first_arg or not is_subtype(self.named_type("builtins.str"), first_arg):
self.chk.fail(message_registry.KEYWORD_ARGUMENT_REQUIRES_STR_KEY_TYPE, context)

if self.chk.options.new_type_inference and any(
a is None
or isinstance(get_proper_type(a), UninhabitedType)
or set(get_type_vars(a)) & set(callee_type.variables)
for a in inferred_args
):
# If the regular two-phase inference didn't work, try inferring type
# variables while allowing for polymorphic solutions, i.e. for solutions
# potentially involving free variables.
# TODO: support the similar inference for return type context.
poly_inferred_args = infer_function_type_arguments(
callee_type,
arg_types,
arg_kinds,
formal_to_actual,
context=self.argument_infer_context(),
strict=self.chk.in_checked_function(),
allow_polymorphic=True,
)
for i, pa in enumerate(get_proper_types(poly_inferred_args)):
if isinstance(pa, (NoneType, UninhabitedType)) or has_erased_component(pa):
# Indicate that free variables should not be applied in the call below.
poly_inferred_args[i] = None
poly_callee_type = self.apply_generic_arguments(
callee_type, poly_inferred_args, context
)
yes_vars = poly_callee_type.variables
no_vars = {v for v in callee_type.variables if v not in poly_callee_type.variables}
if not set(get_type_vars(poly_callee_type)) & no_vars:
# Try applying inferred polymorphic type if possible, e.g. Callable[[T], T] can
# be interpreted as def [T] (T) -> T, but dict[T, T] cannot be expressed.
applied = apply_poly(poly_callee_type, yes_vars)
if applied is not None and poly_inferred_args != [UninhabitedType()] * len(
poly_inferred_args
):
freeze_all_type_vars(applied)
return applied
# If it didn't work, erase free variables as <nothing>, to avoid confusing errors.
inferred_args = [
expand_type(a, {v.id: UninhabitedType() for v in callee_type.variables})
if a is not None
else None
for a in inferred_args
]
else:
# In dynamically typed functions use implicit 'Any' types for
# type variables.
Expand Down Expand Up @@ -5393,6 +5448,92 @@ def replace_callable_return_type(c: CallableType, new_ret_type: Type) -> Callabl
return c.copy_modified(ret_type=new_ret_type)


def apply_poly(tp: CallableType, poly_tvars: Sequence[TypeVarLikeType]) -> Optional[CallableType]:
"""Make free type variables generic in the type if possible.
This will translate the type `tp` while trying to create valid bindings for
type variables `poly_tvars` while traversing the type. This follows the same rules
as we do during semantic analysis phase, examples:
* Callable[Callable[[T], T], T] -> def [T] (def (T) -> T) -> T
* Callable[[], Callable[[T], T]] -> def () -> def [T] (T -> T)
* List[T] -> None (not possible)
"""
try:
return tp.copy_modified(
arg_types=[t.accept(PolyTranslator(poly_tvars)) for t in tp.arg_types],
ret_type=tp.ret_type.accept(PolyTranslator(poly_tvars)),
variables=[],
)
except PolyTranslationError:
return None


class PolyTranslationError(Exception):
pass


class PolyTranslator(TypeTranslator):
"""Make free type variables generic in the type if possible.
See docstring for apply_poly() for details.
"""

def __init__(self, poly_tvars: Sequence[TypeVarLikeType]) -> None:
self.poly_tvars = set(poly_tvars)
# This is a simplified version of TypeVarScope used during semantic analysis.
self.bound_tvars: set[TypeVarLikeType] = set()
self.seen_aliases: set[TypeInfo] = set()

def visit_callable_type(self, t: CallableType) -> Type:
found_vars = set()
for arg in t.arg_types:
found_vars |= set(get_type_vars(arg)) & self.poly_tvars

found_vars -= self.bound_tvars
self.bound_tvars |= found_vars
result = super().visit_callable_type(t)
self.bound_tvars -= found_vars

assert isinstance(result, ProperType) and isinstance(result, CallableType)
result.variables = list(result.variables) + list(found_vars)
return result

def visit_type_var(self, t: TypeVarType) -> Type:
if t in self.poly_tvars and t not in self.bound_tvars:
raise PolyTranslationError()
return super().visit_type_var(t)

def visit_param_spec(self, t: ParamSpecType) -> Type:
# TODO: Support polymorphic apply for ParamSpec.
raise PolyTranslationError()

def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
# TODO: Support polymorphic apply for TypeVarTuple.
raise PolyTranslationError()

def visit_type_alias_type(self, t: TypeAliasType) -> Type:
if not t.args:
return t.copy_modified()
if not t.is_recursive:
return get_proper_type(t).accept(self)
# We can't handle polymorphic application for recursive generic aliases
# without risking an infinite recursion, just give up for now.
raise PolyTranslationError()

def visit_instance(self, t: Instance) -> Type:
# There is the same problem with callback protocols as with aliases
# (callback protocols are essentially more flexible aliases to callables).
# Note: consider supporting bindings in instances, e.g. LRUCache[[x: T], T].
if t.args and t.type.is_protocol and t.type.protocol_members == ["__call__"]:
if t.type in self.seen_aliases:
raise PolyTranslationError()
self.seen_aliases.add(t.type)
call = find_member("__call__", t, t, is_operator=True)
assert call is not None
return call.accept(self)
return super().visit_instance(t)


class ArgInferSecondPassQuery(types.BoolTypeQuery):
"""Query whether an argument type should be inferred in the second pass.
Expand Down
25 changes: 24 additions & 1 deletion mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,7 +886,30 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
param_spec = template.param_spec()
if param_spec is None:
# FIX verify argument counts
# FIX what if one of the functions is generic
# TODO: Erase template variables if it is generic?
if (
type_state.infer_polymorphic
and cactual.variables
and cactual.param_spec() is None
# Technically, the correct inferred type for application of e.g.
# Callable[..., T] -> Callable[..., T] (with literal ellipsis), to a generic
# like U -> U, should be Callable[..., Any], but if U is a self-type, we can
# allow it to leak, to be later bound to self. A bunch of existing code
# depends on this old behaviour.
and not any(tv.id.raw_id == 0 for tv in cactual.variables)
):
# If actual is generic, unify it with template. Note: this is
# not an ideal solution (which would be adding the generic variables
# to the constraint inference set), but it's a good first approximation,
# and this will prevent leaking these variables in the solutions.
# Note: this may infer constraints like T <: S or T <: List[S]
# that contain variables in the target.
unified = mypy.subtypes.unify_generic_callable(
cactual, template, ignore_return=True
)
if unified is not None:
cactual = unified
res.extend(infer_constraints(cactual, template, neg_op(self.direction)))

# We can't infer constraints from arguments if the template is Callable[..., T]
# (with literal '...').
Expand Down
Loading

0 comments on commit 0873230

Please sign in to comment.