diff --git a/crates/ty_python_semantic/resources/mdtest/generics/specialize_constrained.md b/crates/ty_python_semantic/resources/mdtest/generics/specialize_constrained.md index 22210f88d2b051..f0e9367a5a1f0e 100644 --- a/crates/ty_python_semantic/resources/mdtest/generics/specialize_constrained.md +++ b/crates/ty_python_semantic/resources/mdtest/generics/specialize_constrained.md @@ -303,3 +303,33 @@ def mutually_bound[T: Base, U](): # revealed: ty_extensions.Specialization[T@mutually_bound = Base, U@mutually_bound = Sub] reveal_type(generic_context(mutually_bound).specialize_constrained(ConstraintSet.range(Never, U, Sub) & ConstraintSet.range(Never, U, T))) ``` + +## Nested typevars + +A typevar's constraint can _mention_ another typevar without _constraining_ it. In this example, `U` +must be specialized to `list[T]`, but it cannot affect what `T` is specialized to. + +```py +from typing import Never +from ty_extensions import ConstraintSet, generic_context + +def mentions[T, U](): + constraints = ConstraintSet.range(Never, T, int) & ConstraintSet.range(list[T], U, list[T]) + # revealed: ty_extensions.ConstraintSet[((T@mentions ≤ int) ∧ (U@mentions = list[T@mentions]))] + reveal_type(constraints) + # revealed: ty_extensions.Specialization[T@mentions = int, U@mentions = list[int]] + reveal_type(generic_context(mentions).specialize_constrained(constraints)) +``` + +If the constraint set contains mutually recursive bounds, specialization inference will not +converge. This test ensures that our cycle detection prevents an endless loop or stack overflow in +this case. + +```py +def divergent[T, U](): + constraints = ConstraintSet.range(list[U], T, list[U]) & ConstraintSet.range(list[T], U, list[T]) + # revealed: ty_extensions.ConstraintSet[((T@divergent = list[U@divergent]) ∧ (U@divergent = list[T@divergent]))] + reveal_type(constraints) + # revealed: None + reveal_type(generic_context(divergent).specialize_constrained(constraints)) +``` diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index b4b691168f62bc..ab21b49bf381ec 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -53,6 +53,7 @@ //! //! [bdd]: https://en.wikipedia.org/wiki/Binary_decision_diagram +use std::cell::RefCell; use std::cmp::Ordering; use std::fmt::Display; use std::ops::Range; @@ -62,9 +63,10 @@ use rustc_hash::{FxHashMap, FxHashSet}; use salsa::plumbing::AsId; use crate::types::generics::{GenericContext, InferableTypeVars, Specialization}; +use crate::types::visitor::{TypeCollector, TypeVisitor, walk_type_with_recursion_guard}; use crate::types::{ BoundTypeVarIdentity, BoundTypeVarInstance, IntersectionType, Type, TypeRelation, - TypeVarBoundOrConstraints, UnionType, + TypeVarBoundOrConstraints, UnionType, walk_bound_type_var_type, }; use crate::{Db, FxOrderSet}; @@ -213,6 +215,100 @@ impl<'db> ConstraintSet<'db> { self.node.is_always_satisfied(db) } + /// Returns whether this constraint set contains any cycles between typevars. If it does, then + /// we cannot create a specialization from this constraint set. + /// + /// We have restrictions in place that ensure that there are no cycles in the _lower and upper + /// bounds_ of each constraint, but it's still possible for a constraint to _mention_ another + /// typevar without _constraining_ it. For instance, `(T ≤ int) ∧ (U ≤ list[T])` is a valid + /// constraint set, which we can create a specialization from (`T = int, U = list[int]`). But + /// `(T ≤ list[U]) ∧ (U ≤ list[T])` does not violate our lower/upper bounds restrictions, since + /// neither bound _is_ a typevar. And it's not something we can create a specialization from, + /// since we would endlessly substitute until we stack overflow. + pub(crate) fn is_cyclic(self, db: &'db dyn Db) -> bool { + #[derive(Default)] + struct CollectReachability<'db> { + reachable_typevars: RefCell>>, + recursion_guard: TypeCollector<'db>, + } + + impl<'db> TypeVisitor<'db> for CollectReachability<'db> { + fn should_visit_lazy_type_attributes(&self) -> bool { + true + } + + fn visit_bound_type_var_type( + &self, + db: &'db dyn Db, + bound_typevar: BoundTypeVarInstance<'db>, + ) { + self.reachable_typevars + .borrow_mut() + .insert(bound_typevar.identity(db)); + walk_bound_type_var_type(db, bound_typevar, self); + } + + fn visit_type(&self, db: &'db dyn Db, ty: Type<'db>) { + walk_type_with_recursion_guard(db, ty, self, &self.recursion_guard); + } + } + + fn visit_dfs<'db>( + reachable_typevars: &mut FxHashMap< + BoundTypeVarIdentity<'db>, + FxHashSet>, + >, + discovered: &mut FxHashSet>, + bound_typevar: BoundTypeVarIdentity<'db>, + ) -> bool { + discovered.insert(bound_typevar); + let outgoing = reachable_typevars + .remove(&bound_typevar) + .expect("should not visit typevar twice in DFS"); + for outgoing in outgoing { + if discovered.contains(&outgoing) { + return true; + } + if reachable_typevars.contains_key(&outgoing) { + if visit_dfs(reachable_typevars, discovered, outgoing) { + return true; + } + } + } + discovered.remove(&bound_typevar); + false + } + + // First find all of the typevars that each constraint directly mentions. + let mut reachable_typevars: FxHashMap< + BoundTypeVarIdentity<'db>, + FxHashSet>, + > = FxHashMap::default(); + self.node.for_each_constraint(db, &mut |constraint| { + let visitor = CollectReachability::default(); + visitor.visit_type(db, constraint.lower(db)); + visitor.visit_type(db, constraint.upper(db)); + reachable_typevars + .entry(constraint.typevar(db).identity(db)) + .or_default() + .extend(visitor.reachable_typevars.into_inner()); + }); + + // Then perform a depth-first search to see if there are any cycles. + let mut discovered: FxHashSet> = FxHashSet::default(); + while let Some(bound_typevar) = reachable_typevars.keys().copied().next() { + if !discovered.contains(&bound_typevar) { + let cycle_found = + visit_dfs(&mut reachable_typevars, &mut discovered, bound_typevar); + if cycle_found { + return true; + } + } + } + + false + } + /// Returns the constraints under which `lhs` is a subtype of `rhs`, assuming that the /// constraints in this constraint set hold. Panics if neither of the types being compared are /// a typevar. (That case is handled by `Type::has_relation_to`.) @@ -2964,6 +3060,12 @@ impl<'db> GenericContext<'db> { db: &'db dyn Db, constraints: ConstraintSet<'db>, ) -> Result, ()> { + // If the constraint set is cyclic, don't even try to construct a specialization. + if constraints.is_cyclic(db) { + // TODO: Better error + return Err(()); + } + // First we intersect with the valid specializations of all of the typevars. We need all of // valid specializations to hold simultaneously, so we do this once before abstracting over // each typevar. @@ -3020,7 +3122,7 @@ impl<'db> GenericContext<'db> { types[i] = least_upper_bound; } - Ok(self.specialize(db, types.into_boxed_slice())) + Ok(self.specialize_recursive(db, types.into_boxed_slice())) } } diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 2ee21a02ba9416..b540f20daf6d44 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -500,9 +500,16 @@ impl<'db> GenericContext<'db> { } /// Creates a specialization of this generic context. Panics if the length of `types` does not - /// match the number of typevars in the generic context. You must provide a specific type for - /// each typevar; no defaults are used. (Use [`specialize_partial`](Self::specialize_partial) - /// if you might not have types for every typevar.) + /// match the number of typevars in the generic context. + /// + /// You must provide a specific type for each typevar; no defaults are used. (Use + /// [`specialize_partial`](Self::specialize_partial) if you might not have types for every + /// typevar.) + /// + /// The types you provide should not mention any of the typevars in this generic context; + /// otherwise, you will be left with a partial specialization. (Use + /// [`specialize_recursive`](Self::specialize_recursive) if your types might mention typevars + /// in this generic context.) pub(crate) fn specialize( self, db: &'db dyn Db, @@ -512,6 +519,41 @@ impl<'db> GenericContext<'db> { Specialization::new(db, self, types, None, None) } + /// Creates a specialization of this generic context. Panics if the length of `types` does not + /// match the number of typevars in the generic context. + /// + /// You are allowed to provide types that mention the typevars in this generic context. + pub(crate) fn specialize_recursive( + self, + db: &'db dyn Db, + mut types: Box<[Type<'db>]>, + ) -> Specialization<'db> { + let len = types.len(); + assert!(self.len(db) == len); + loop { + let mut any_changed = false; + for i in 0..len { + let partial = PartialSpecialization { + generic_context: self, + types: &types, + }; + let updated = types[i].apply_type_mapping( + db, + &TypeMapping::PartialSpecialization(partial), + TypeContext::default(), + ); + if updated != types[i] { + types[i] = updated; + any_changed = true; + } + } + + if !any_changed { + return Specialization::new(db, self, types, None, None); + } + } + } + /// Creates a specialization of this generic context for the `tuple` class. pub(crate) fn specialize_tuple( self,