diff --git a/crates/ty_python_semantic/resources/mdtest/type_properties/implies_subtype_of.md b/crates/ty_python_semantic/resources/mdtest/type_properties/implies_subtype_of.md index 35768ef76d4686..1a72da9464fe8d 100644 --- a/crates/ty_python_semantic/resources/mdtest/type_properties/implies_subtype_of.md +++ b/crates/ty_python_semantic/resources/mdtest/type_properties/implies_subtype_of.md @@ -173,7 +173,10 @@ def given_constraints[T](): static_assert(given_str.implies_subtype_of(T, str)) ``` -This might require propagating constraints from other typevars. +This might require propagating constraints from other typevars. (Note that we perform the test +twice, with different variable orderings. Our BDD implementation uses the Salsa IDs of each typevar +as part of the variable ordering. Reversing the typevar order helps us verify that we don't have any +BDD logic that is dependent on which variable ordering we end up with.) ```py def mutually_constrained[T, U](): @@ -183,6 +186,19 @@ def mutually_constrained[T, U](): static_assert(not given_int.implies_subtype_of(T, bool)) static_assert(not given_int.implies_subtype_of(T, str)) + # If [T ≤ U ∧ U ≤ int], then [T ≤ int] must be true as well. + given_int = ConstraintSet.range(Never, T, U) & ConstraintSet.range(Never, U, int) + static_assert(given_int.implies_subtype_of(T, int)) + static_assert(not given_int.implies_subtype_of(T, bool)) + static_assert(not given_int.implies_subtype_of(T, str)) + +def mutually_constrained[U, T](): + # If [T = U ∧ U ≤ int], then [T ≤ int] must be true as well. + given_int = ConstraintSet.range(U, T, U) & ConstraintSet.range(Never, U, int) + static_assert(given_int.implies_subtype_of(T, int)) + static_assert(not given_int.implies_subtype_of(T, bool)) + static_assert(not given_int.implies_subtype_of(T, str)) + # If [T ≤ U ∧ U ≤ int], then [T ≤ int] must be true as well. 
given_int = ConstraintSet.range(Never, T, U) & ConstraintSet.range(Never, U, int) static_assert(given_int.implies_subtype_of(T, int)) @@ -236,6 +252,22 @@ def mutually_constrained[T, U](): static_assert(not given_int.implies_subtype_of(Covariant[T], Covariant[bool])) static_assert(not given_int.implies_subtype_of(Covariant[T], Covariant[str])) + # If (T ≤ U ∧ U ≤ int), then (T ≤ int) must be true as well, and therefore + # (Covariant[T] ≤ Covariant[int]). + given_int = ConstraintSet.range(Never, T, U) & ConstraintSet.range(Never, U, int) + static_assert(given_int.implies_subtype_of(Covariant[T], Covariant[int])) + static_assert(not given_int.implies_subtype_of(Covariant[T], Covariant[bool])) + static_assert(not given_int.implies_subtype_of(Covariant[T], Covariant[str])) + +# Repeat the test with a different typevar ordering +def mutually_constrained[U, T](): + # If (T = U ∧ U ≤ int), then (T ≤ int) must be true as well, and therefore + # (Covariant[T] ≤ Covariant[int]). + given_int = ConstraintSet.range(U, T, U) & ConstraintSet.range(Never, U, int) + static_assert(given_int.implies_subtype_of(Covariant[T], Covariant[int])) + static_assert(not given_int.implies_subtype_of(Covariant[T], Covariant[bool])) + static_assert(not given_int.implies_subtype_of(Covariant[T], Covariant[str])) + # If (T ≤ U ∧ U ≤ int), then (T ≤ int) must be true as well, and therefore # (Covariant[T] ≤ Covariant[int]). given_int = ConstraintSet.range(Never, T, U) & ConstraintSet.range(Never, U, int) @@ -281,6 +313,22 @@ def mutually_constrained[T, U](): static_assert(not given_int.implies_subtype_of(Contravariant[bool], Contravariant[T])) static_assert(not given_int.implies_subtype_of(Contravariant[str], Contravariant[T])) + # If (T ≤ U ∧ U ≤ int), then (T ≤ int) must be true as well, and therefore + # (Contravariant[int] ≤ Contravariant[T]). 
+ given_int = ConstraintSet.range(Never, T, U) & ConstraintSet.range(Never, U, int) + static_assert(given_int.implies_subtype_of(Contravariant[int], Contravariant[T])) + static_assert(not given_int.implies_subtype_of(Contravariant[bool], Contravariant[T])) + static_assert(not given_int.implies_subtype_of(Contravariant[str], Contravariant[T])) + +# Repeat the test with a different typevar ordering +def mutually_constrained[U, T](): + # If (T = U ∧ U ≤ int), then (T ≤ int) must be true as well, and therefore + # (Contravariant[int] ≤ Contravariant[T]). + given_int = ConstraintSet.range(U, T, U) & ConstraintSet.range(Never, U, int) + static_assert(given_int.implies_subtype_of(Contravariant[int], Contravariant[T])) + static_assert(not given_int.implies_subtype_of(Contravariant[bool], Contravariant[T])) + static_assert(not given_int.implies_subtype_of(Contravariant[str], Contravariant[T])) + # If (T ≤ U ∧ U ≤ int), then (T ≤ int) must be true as well, and therefore # (Contravariant[int] ≤ Contravariant[T]). given_int = ConstraintSet.range(Never, T, U) & ConstraintSet.range(Never, U, int) @@ -338,6 +386,25 @@ def mutually_constrained[T, U](): static_assert(not given_int.implies_subtype_of(Invariant[T], Invariant[bool])) static_assert(not given_int.implies_subtype_of(Invariant[T], Invariant[str])) + # If (T = U ∧ U = int), then (T = int) must be true as well. That is an equality constraint, so + # even though T is invariant, it does imply that (Invariant[T] ≤ Invariant[int]). 
+ given_int = ConstraintSet.range(U, T, U) & ConstraintSet.range(int, U, int) + static_assert(given_int.implies_subtype_of(Invariant[T], Invariant[int])) + static_assert(given_int.implies_subtype_of(Invariant[int], Invariant[T])) + static_assert(not given_int.implies_subtype_of(Invariant[T], Invariant[bool])) + static_assert(not given_int.implies_subtype_of(Invariant[bool], Invariant[T])) + static_assert(not given_int.implies_subtype_of(Invariant[T], Invariant[str])) + static_assert(not given_int.implies_subtype_of(Invariant[str], Invariant[T])) + +# Repeat the test with a different typevar ordering +def mutually_constrained[U, T](): + # If (T = U ∧ U ≤ int), then (T ≤ int) must be true as well. But because T is invariant, that + # does _not_ imply that (Invariant[T] ≤ Invariant[int]). + given_int = ConstraintSet.range(U, T, U) & ConstraintSet.range(Never, U, int) + static_assert(not given_int.implies_subtype_of(Invariant[T], Invariant[int])) + static_assert(not given_int.implies_subtype_of(Invariant[T], Invariant[bool])) + static_assert(not given_int.implies_subtype_of(Invariant[T], Invariant[str])) + # If (T = U ∧ U = int), then (T = int) must be true as well. That is an equality constraint, so # even though T is invariant, it does imply that (Invariant[T] ≤ Invariant[int]). 
given_int = ConstraintSet.range(U, T, U) & ConstraintSet.range(int, U, int) diff --git a/crates/ty_python_semantic/src/types/constraints.rs b/crates/ty_python_semantic/src/types/constraints.rs index f8b66c33e4aa49..d76b965bd36105 100644 --- a/crates/ty_python_semantic/src/types/constraints.rs +++ b/crates/ty_python_semantic/src/types/constraints.rs @@ -58,17 +58,18 @@ use std::cmp::Ordering; use std::fmt::Display; +use std::ops::Range; use itertools::Itertools; -use rustc_hash::FxHashSet; +use rustc_hash::{FxHashMap, FxHashSet}; use salsa::plumbing::AsId; -use crate::Db; use crate::types::generics::InferableTypeVars; use crate::types::{ BoundTypeVarIdentity, BoundTypeVarInstance, IntersectionType, Type, TypeRelation, TypeVarBoundOrConstraints, UnionType, }; +use crate::{Db, FxOrderSet}; /// An extension trait for building constraint sets from [`Option`] values. pub(crate) trait OptionConstraintsExtension { @@ -206,8 +207,8 @@ impl<'db> ConstraintSet<'db> { } /// Returns whether this constraint set never holds - pub(crate) fn is_never_satisfied(self, _db: &'db dyn Db) -> bool { - self.node.is_never_satisfied() + pub(crate) fn is_never_satisfied(self, db: &'db dyn Db) -> bool { + self.node.is_never_satisfied(db) } /// Returns whether this constraint set always holds @@ -325,7 +326,7 @@ impl<'db> ConstraintSet<'db> { } pub(crate) fn display(self, db: &'db dyn Db) -> impl Display { - self.node.simplify(db).display(db) + self.node.simplify_for_display(db).display(db) } } @@ -716,17 +717,95 @@ impl<'db> Node<'db> { match self { Node::AlwaysTrue => true, Node::AlwaysFalse => false, - Node::Interior(_) => { - let domain = self.domain(db); - let restricted = self.and(db, domain); - restricted == domain + Node::Interior(interior) => { + let map = interior.sequent_map(db); + let mut path = PathAssignments::default(); + self.is_always_satisfied_inner(db, map, &mut path) + } + } + } + + fn is_always_satisfied_inner( + self, + db: &'db dyn Db, + map: &SequentMap<'db>, + path: 
&mut PathAssignments<'db>, + ) -> bool { + match self { + Node::AlwaysTrue => true, + Node::AlwaysFalse => false, + Node::Interior(interior) => { + // walk_edge will return None if this node's constraint (or anything we can derive + // from it) causes the if_true edge to become impossible. We want to ignore + // impossible paths, and so we treat them as passing the "always satisfied" check. + let constraint = interior.constraint(db); + let true_always_satisfied = path + .walk_edge(map, constraint.when_true(), |path, _| { + interior + .if_true(db) + .is_always_satisfied_inner(db, map, path) + }) + .unwrap_or(true); + if !true_always_satisfied { + return false; + } + + // Ditto for the if_false branch + path.walk_edge(map, constraint.when_false(), |path, _| { + interior + .if_false(db) + .is_always_satisfied_inner(db, map, path) + }) + .unwrap_or(true) } } } /// Returns whether this BDD represent the constant function `false`. - fn is_never_satisfied(self) -> bool { - matches!(self, Node::AlwaysFalse) + fn is_never_satisfied(self, db: &'db dyn Db) -> bool { + match self { + Node::AlwaysTrue => false, + Node::AlwaysFalse => true, + Node::Interior(interior) => { + let map = interior.sequent_map(db); + let mut path = PathAssignments::default(); + self.is_never_satisfied_inner(db, map, &mut path) + } + } + } + + fn is_never_satisfied_inner( + self, + db: &'db dyn Db, + map: &SequentMap<'db>, + path: &mut PathAssignments<'db>, + ) -> bool { + match self { + Node::AlwaysTrue => false, + Node::AlwaysFalse => true, + Node::Interior(interior) => { + // walk_edge will return None if this node's constraint (or anything we can derive + // from it) causes the if_true edge to become impossible. We want to ignore + // impossible paths, and so we treat them as passing the "never satisfied" check. 
+ let constraint = interior.constraint(db); + let true_never_satisfied = path + .walk_edge(map, constraint.when_true(), |path, _| { + interior.if_true(db).is_never_satisfied_inner(db, map, path) + }) + .unwrap_or(true); + if !true_never_satisfied { + return false; + } + + // Ditto for the if_false branch + path.walk_edge(map, constraint.when_false(), |path, _| { + interior + .if_false(db) + .is_never_satisfied_inner(db, map, path) + }) + .unwrap_or(true) + } + } } /// Returns the negation of this BDD. @@ -806,13 +885,6 @@ impl<'db> Node<'db> { .or(db, self.negate(db).and(db, else_node)) } - fn satisfies(self, db: &'db dyn Db, other: Self) -> Self { - let simplified_self = self.simplify(db); - let implication = simplified_self.implies(db, other); - let (simplified, domain) = implication.simplify_and_domain(db); - simplified.and(db, domain) - } - fn implies_subtype_of(self, db: &'db dyn Db, lhs: Type<'db>, rhs: Type<'db>) -> Self { // When checking subtyping involving a typevar, we can turn the subtyping check into a // constraint (i.e, "is `T` a subtype of `int` becomes the constraint `T ≤ int`), and then @@ -838,7 +910,7 @@ impl<'db> Node<'db> { _ => panic!("at least one type should be a typevar"), }; - self.satisfies(db, constraint) + self.implies(db, constraint) } fn satisfied_by_all_typevars( @@ -859,19 +931,13 @@ impl<'db> Node<'db> { // Returns if some specialization satisfies this constraint set. let some_specialization_satisfies = move |specializations: Node<'db>| { - let when_satisfied = specializations - .satisfies(db, self) - .and(db, specializations) - .simplify(db); - !when_satisfied.is_never_satisfied() + let when_satisfied = specializations.implies(db, self).and(db, specializations); + !when_satisfied.is_never_satisfied(db) }; // Returns if all specializations satisfy this constraint set. 
let all_specializations_satisfy = move |specializations: Node<'db>| { - let when_satisfied = specializations - .satisfies(db, self) - .and(db, specializations) - .simplify(db); + let when_satisfied = specializations.implies(db, self).and(db, specializations); when_satisfied .iff(db, specializations) .is_always_satisfied(db) @@ -923,7 +989,7 @@ impl<'db> Node<'db> { ) -> Self { bound_typevars .into_iter() - .fold(self.simplify(db), |abstracted, bound_typevar| { + .fold(self, |abstracted, bound_typevar| { abstracted.exists_one(db, bound_typevar) }) } @@ -936,6 +1002,20 @@ impl<'db> Node<'db> { } } + fn exists_one_inner( + self, + db: &'db dyn Db, + bound_typevar: BoundTypeVarIdentity<'db>, + map: &SequentMap<'db>, + path: &mut PathAssignments<'db>, + ) -> Self { + match self { + Node::AlwaysTrue => Node::AlwaysTrue, + Node::AlwaysFalse => Node::AlwaysFalse, + Node::Interior(interior) => interior.exists_one_inner(db, bound_typevar, map, path), + } + } + /// Returns a new BDD that returns the same results as `self`, but with some inputs fixed to /// particular values. (Those variables will not be checked when evaluating the result, and /// will not be present in the result.) @@ -1095,26 +1175,34 @@ impl<'db> Node<'db> { interior.if_false(db).for_each_constraint(db, f); } - /// Returns a simplified version of a BDD, along with the BDD's domain. - fn simplify_and_domain(self, db: &'db dyn Db) -> (Self, Self) { + /// Simplifies a BDD, replacing constraints with simpler or smaller constraints where possible. + /// + /// TODO: [Historical note] This is now used only for display purposes, but previously was also + /// used to ensure that we added the "transitive closure" to each BDD. The constraints in a BDD + /// are not independent; some combinations of constraints can imply other constraints. This + /// affects us in two ways: First, it means that certain combinations are impossible. (If + /// `a → b` then `a ∧ ¬b` can never happen.) 
Second, it means that certain constraints can be + /// inferred even if they do not explicitly appear in the BDD. It is important to take this + /// into account in several BDD operations (satisfiability, existential quantification, etc). + /// Before, we used this method to _add_ the transitive closure to a BDD, in an attempt to make + /// sure that it holds "all the facts" that would be needed to satisfy any query we might make. + /// We also used this method to calculate the "domain" of the BDD to help rule out invalid + /// inputs. However, this was at odds with using this method for display purposes, where our + /// goal is to _remove_ redundant information, so as to not clutter up the display. To resolve + /// this dilemma, all of the correctness uses have been refactored to use [`SequentMap`] + /// instead. It tracks the same information in a more efficient and lazy way, and never tries + /// to remove redundant information. For expediency, however, we did not make any changes to + /// this method, other than to stop tracking the domain (which was never used for display + /// purposes). That means we have some tech debt here, since there is a lot of duplicate logic + /// between `simplify_for_display` and `SequentMap`. It would be nice to update our display + /// logic to use the sequent map as much as possible. But that can happen later. + fn simplify_for_display(self, db: &'db dyn Db) -> Self { match self { - Node::AlwaysTrue | Node::AlwaysFalse => (self, Node::AlwaysTrue), + Node::AlwaysTrue | Node::AlwaysFalse => self, Node::Interior(interior) => interior.simplify(db), } } - /// Simplifies a BDD, replacing constraints with simpler or smaller constraints where possible. - fn simplify(self, db: &'db dyn Db) -> Self { - let (simplified, _) = self.simplify_and_domain(db); - simplified - } - - /// Returns the domain (the set of allowed inputs) for a BDD. 
- fn domain(self, db: &'db dyn Db) -> Self { - let (_, domain) = self.simplify_and_domain(db); - domain - } - /// Returns clauses describing all of the variable assignments that cause this BDD to evaluate /// to `true`. (This translates the boolean function that this BDD represents into DNF form.) fn satisfied_clauses(self, db: &'db dyn Db) -> SatisfiedClauses<'db> { @@ -1351,25 +1439,91 @@ impl<'db> InteriorNode<'db> { #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] fn exists_one(self, db: &'db dyn Db, bound_typevar: BoundTypeVarIdentity<'db>) -> Node<'db> { + let map = self.sequent_map(db); + let mut path = PathAssignments::default(); + self.exists_one_inner(db, bound_typevar, map, &mut path) + } + + fn exists_one_inner( + self, + db: &'db dyn Db, + bound_typevar: BoundTypeVarIdentity<'db>, + map: &SequentMap<'db>, + path: &mut PathAssignments<'db>, + ) -> Node<'db> { let self_constraint = self.constraint(db); - let self_typevar = self_constraint.typevar(db).identity(db); - match bound_typevar.cmp(&self_typevar) { + let self_typevar = self_constraint.typevar(db); + match bound_typevar.cmp(&self_typevar.identity(db)) { // If the typevar that this node checks is "later" than the typevar we're abstracting // over, then we have reached a point in the BDD where the abstraction can no longer // affect the result, and we can return early. Ordering::Less => Node::Interior(self), + // If the typevar that this node checks _is_ the typevar we're abstracting over, then // we replace this node with the OR of its if_false/if_true edges. That is, the result // is true if there's any assignment of this node's constraint that is true. + // + // We also have to check if there are any derived facts that depend on the constraint + // we're about to remove. If so, we need to "remember" them by AND-ing them in with the + // corresponding branch. 
Ordering::Equal => { - let if_true = self.if_true(db).exists_one(db, bound_typevar); - let if_false = self.if_false(db).exists_one(db, bound_typevar); + let if_true = path + .walk_edge(map, self_constraint.when_true(), |path, new_range| { + let branch = + self.if_true(db) + .exists_one_inner(db, bound_typevar, map, path); + path.assignments[new_range] + .iter() + .filter(|assignment| { + // Don't add back any derived facts if they reference the typevar + // that we're trying to remove! + !assignment + .constraint() + .typevar(db) + .is_same_typevar_as(db, self_typevar) + }) + .fold(branch, |branch, assignment| { + branch.and(db, Node::new_satisfied_constraint(db, *assignment)) + }) + }) + .unwrap_or(Node::AlwaysFalse); + let if_false = path + .walk_edge(map, self_constraint.when_false(), |path, new_range| { + let branch = + self.if_false(db) + .exists_one_inner(db, bound_typevar, map, path); + path.assignments[new_range] + .iter() + .filter(|assignment| { + // Don't add back any derived facts if they reference the typevar + // that we're trying to remove! + !assignment + .constraint() + .typevar(db) + .is_same_typevar_as(db, self_typevar) + }) + .fold(branch, |branch, assignment| { + branch.and(db, Node::new_satisfied_constraint(db, *assignment)) + }) + }) + .unwrap_or(Node::AlwaysFalse); if_true.or(db, if_false) } + // Otherwise, we abstract the if_false/if_true edges recursively. 
Ordering::Greater => { - let if_true = self.if_true(db).exists_one(db, bound_typevar); - let if_false = self.if_false(db).exists_one(db, bound_typevar); + let if_true = path + .walk_edge(map, self_constraint.when_true(), |path, _| { + self.if_true(db) + .exists_one_inner(db, bound_typevar, map, path) + }) + .unwrap_or(Node::AlwaysFalse); + let if_false = path + .walk_edge(map, self_constraint.when_false(), |path, _| { + self.if_false(db) + .exists_one_inner(db, bound_typevar, map, path) + }) + .unwrap_or(Node::AlwaysFalse); Node::new(db, self_constraint, if_true, if_false) } } @@ -1405,14 +1559,24 @@ impl<'db> InteriorNode<'db> { } } - /// Returns a simplified version of a BDD, along with the BDD's domain. + /// Returns a sequent map for this BDD, which records the relationships between the constraints + /// that appear in the BDD. + #[salsa::tracked(returns(ref), heap_size=ruff_memory_usage::heap_size)] + fn sequent_map(self, db: &'db dyn Db) -> SequentMap<'db> { + let mut map = SequentMap::default(); + Node::Interior(self).for_each_constraint(db, &mut |constraint| { + map.add(db, constraint); + }); + map + } + + /// Returns a simplified version of a BDD. /// - /// Both are calculated by looking at the relationships that exist between the constraints that + /// This is calculated by looking at the relationships that exist between the constraints that /// are mentioned in the BDD. For instance, if one constraint implies another (`x → y`), then - /// `x ∧ ¬y` is not a valid input, and is excluded from the BDD's domain. At the same time, we - /// can rewrite any occurrences of `x ∨ y` into `y`. + /// `x ∧ ¬y` is not a valid input, and we can rewrite any occurrences of `x ∨ y` into `y`. #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] - fn simplify(self, db: &'db dyn Db) -> (Node<'db>, Node<'db>) { + fn simplify(self, db: &'db dyn Db) -> Node<'db> { // To simplify a non-terminal BDD, we find all pairs of constraints that are mentioned in // the BDD. 
If any of those pairs can be simplified to some other BDD, we perform a // substitution to replace the pair with the simplification. @@ -1439,7 +1603,6 @@ impl<'db> InteriorNode<'db> { // Repeatedly pop constraint pairs off of the visit queue, checking whether each pair can // be simplified. let mut simplified = Node::Interior(self); - let mut domain = Node::AlwaysTrue; while let Some((left_constraint, right_constraint)) = to_visit.pop() { // If the constraints refer to different typevars, the only simplifications we can make // are of the form `S ≤ T ∧ T ≤ int → S ≤ int`. @@ -1512,9 +1675,6 @@ impl<'db> InteriorNode<'db> { let positive_right_node = Node::new_satisfied_constraint(db, right_constraint.when_true()); let lhs = positive_left_node.and(db, positive_right_node); - let implication = lhs.implies(db, new_node); - domain = domain.and(db, implication); - let intersection = new_node.ite(db, lhs, Node::AlwaysFalse); simplified = simplified.and(db, intersection); continue; @@ -1548,13 +1708,6 @@ impl<'db> InteriorNode<'db> { let negative_larger_node = Node::new_satisfied_constraint(db, larger_constraint.when_false()); - let positive_smaller_node = - Node::new_satisfied_constraint(db, smaller_constraint.when_true()); - - // smaller → larger - let implication = positive_smaller_node.implies(db, positive_larger_node); - domain = domain.and(db, implication); - // larger ∨ smaller = larger simplified = simplified.substitute_union( db, @@ -1620,11 +1773,6 @@ impl<'db> InteriorNode<'db> { let negative_right_node = Node::new_satisfied_constraint(db, right_constraint.when_false()); - // (left ∧ right) → intersection - let implication = (positive_left_node.and(db, positive_right_node)) - .implies(db, positive_intersection_node); - domain = domain.and(db, implication); - // left ∧ right = intersection simplified = simplified.substitute_intersection( db, @@ -1689,11 +1837,6 @@ impl<'db> InteriorNode<'db> { let positive_right_node = Node::new_satisfied_constraint(db, 
right_constraint.when_true()); - // (left ∧ right) → false - let implication = (positive_left_node.and(db, positive_right_node)) - .implies(db, Node::AlwaysFalse); - domain = domain.and(db, implication); - // left ∧ right = false simplified = simplified.substitute_intersection( db, @@ -1731,7 +1874,7 @@ impl<'db> InteriorNode<'db> { } } - (simplified, domain) + simplified } } @@ -1844,6 +1987,502 @@ impl<'db> ConstraintAssignment<'db> { } } +/// A collection of _sequents_ that describe how the constraints mentioned in a BDD relate to each +/// other. These are used in several BDD operations that need to know about "derived facts" even if +/// they are not mentioned in the BDD directly. These operations involve walking one or more paths +/// from the root node to a terminal node. Each sequent describes paths that are invalid (which are +/// pruned from the search), and new constraints that we can assume to be true even if we haven't +/// seen them directly. +/// +/// We support several kinds of sequent: +/// +/// - `C₁ ∧ C₂ → false`: This indicates that `C₁` and `C₂` are disjoint: it is not possible for +/// both to hold. Any path that assumes both is impossible and can be pruned. +/// +/// - `C₁ ∧ C₂ → D`: This indicates that the intersection of `C₁` and `C₂` can be simplified to +/// `D`. Any path that assumes both `C₁` and `C₂` hold, but assumes `D` does _not_, is impossible +/// and can be pruned. +/// +/// - `C → D`: This indicates that `C` on its own is enough to imply `D`. Any path that assumes `C` +/// holds but `D` does _not_ is impossible and can be pruned. 
+#[derive(Debug, Default, Eq, PartialEq, get_size2::GetSize, salsa::Update)] +struct SequentMap<'db> { + /// Sequents of the form `C₁ ∧ C₂ → false` + impossibilities: FxHashSet<(ConstrainedTypeVar<'db>, ConstrainedTypeVar<'db>)>, + /// Sequents of the form `C₁ ∧ C₂ → D` + pair_implications: + FxHashMap<(ConstrainedTypeVar<'db>, ConstrainedTypeVar<'db>), Vec>>, + /// Sequents of the form `C → D` + single_implications: FxHashMap, Vec>>, + /// Constraints that we have already processed + processed: FxHashSet>, +} + +impl<'db> SequentMap<'db> { + fn add(&mut self, db: &'db dyn Db, constraint: ConstrainedTypeVar<'db>) { + // If we've already seen this constraint, we can skip it. + if !self.processed.insert(constraint) { + return; + } + + // Otherwise, check this constraint against all of the other ones we've seen so far, seeing + // if they're related to each other. + let processed = std::mem::take(&mut self.processed); + for other in &processed { + if constraint != *other { + self.add_sequents_for_pair(db, constraint, *other); + } + } + self.processed = processed; + + // And see if we can create any sequents from the constraint on its own. 
+ self.add_sequents_for_single(db, constraint); + } + + fn pair_key( + db: &'db dyn Db, + ante1: ConstrainedTypeVar<'db>, + ante2: ConstrainedTypeVar<'db>, + ) -> (ConstrainedTypeVar<'db>, ConstrainedTypeVar<'db>) { + if ante1.ordering(db) < ante2.ordering(db) { + (ante1, ante2) + } else { + (ante2, ante1) + } + } + + fn add_impossibility( + &mut self, + db: &'db dyn Db, + ante1: ConstrainedTypeVar<'db>, + ante2: ConstrainedTypeVar<'db>, + ) { + self.impossibilities + .insert(Self::pair_key(db, ante1, ante2)); + } + + fn add_pair_implication( + &mut self, + db: &'db dyn Db, + ante1: ConstrainedTypeVar<'db>, + ante2: ConstrainedTypeVar<'db>, + post: ConstrainedTypeVar<'db>, + ) { + if ante1 == post || ante2 == post { + return; + } + self.pair_implications + .entry(Self::pair_key(db, ante1, ante2)) + .or_default() + .push(post); + } + + fn add_single_implication( + &mut self, + ante: ConstrainedTypeVar<'db>, + post: ConstrainedTypeVar<'db>, + ) { + if ante == post { + return; + } + self.single_implications.entry(ante).or_default().push(post); + } + + fn add_sequents_for_single(&mut self, db: &'db dyn Db, constraint: ConstrainedTypeVar<'db>) { + // If the lower or upper bound of this constraint is a typevar, we can propagate the + // constraint: + // + // 1. `(S ≤ T ≤ U) → (S ≤ U)` + // 2. `(S ≤ T ≤ τ) → (S ≤ τ)` + // 3. `(τ ≤ T ≤ U) → (τ ≤ U)` + // + // Technically, (1) also allows `(S = T) → (S = S)`, but the rhs of that is vacuously true, + // so we don't add a sequent for that case. 
+ + let lower = constraint.lower(db); + let upper = constraint.upper(db); + match (lower, upper) { + // Case 1 + (Type::TypeVar(lower_typevar), Type::TypeVar(upper_typevar)) => { + if !lower_typevar.is_same_typevar_as(db, upper_typevar) { + let post_constraint = + ConstrainedTypeVar::new(db, lower_typevar, Type::Never, upper); + self.add_single_implication(constraint, post_constraint); + } + } + + // Case 2 + (Type::TypeVar(lower_typevar), _) => { + let post_constraint = + ConstrainedTypeVar::new(db, lower_typevar, Type::Never, upper); + self.add_single_implication(constraint, post_constraint); + } + + // Case 3 + (_, Type::TypeVar(upper_typevar)) => { + let post_constraint = + ConstrainedTypeVar::new(db, upper_typevar, lower, Type::object()); + self.add_single_implication(constraint, post_constraint); + } + + _ => {} + } + } + + fn add_sequents_for_pair( + &mut self, + db: &'db dyn Db, + left_constraint: ConstrainedTypeVar<'db>, + right_constraint: ConstrainedTypeVar<'db>, + ) { + // If either of the constraints has another typevar as a lower/upper bound, the only + // sequents we can add are for the transitive closure. For instance, if we have + // `(S ≤ T) ∧ (T ≤ int)`, then `(S ≤ int)` will also hold, and we should add a sequent for + // this implication. These are the `mutual_sequents` mentioned below — sequents that come + // about because two typevars are mutually constrained. + // + // Complicating things is that `(S ≤ T)` will be encoded differently depending on how `S` + // and `T` compare in our arbitrary BDD variable ordering. + // + // When `S` comes before `T`, `(S ≤ T)` will be encoded as `(Never ≤ S ≤ T)`, and the + // overall antecedent will be `(Never ≤ S ≤ T) ∧ (T ≤ int)`. Those two individual + // constraints constrain different typevars (`S` and `T`, respectively), and are handled by + // `add_mutual_sequents_for_different_typevars`. 
+ // + // When `T` comes before `S`, `(S ≤ T)` will be encoded as `(S ≤ T ≤ object)`, and the + // overall antecedent will be `(S ≤ T ≤ object) ∧ (T ≤ int)`. Those two individual + // constraints both constrain `T`, and are handled by + // `add_mutual_sequents_for_same_typevars`. + // + // If all of the lower and upper bounds are concrete (i.e., not typevars), then there + // several _other_ sequents that we can add, as handled by `add_concrete_sequents`. + let left_typevar = left_constraint.typevar(db); + let right_typevar = right_constraint.typevar(db); + if !left_typevar.is_same_typevar_as(db, right_typevar) { + self.add_mutual_sequents_for_different_typevars(db, left_constraint, right_constraint); + } else if left_constraint.lower(db).is_type_var() + || left_constraint.upper(db).is_type_var() + || right_constraint.lower(db).is_type_var() + || right_constraint.upper(db).is_type_var() + { + self.add_mutual_sequents_for_same_typevars(db, left_constraint, right_constraint); + } else { + self.add_concrete_sequents(db, left_constraint, right_constraint); + } + } + + fn add_mutual_sequents_for_different_typevars( + &mut self, + db: &'db dyn Db, + left_constraint: ConstrainedTypeVar<'db>, + right_constraint: ConstrainedTypeVar<'db>, + ) { + // We've structured our constraints so that a typevar's upper/lower bound can only + // be another typevar if the bound is "later" in our arbitrary ordering. That means + // we only have to check this pair of constraints in one direction — though we do + // have to figure out which of the two typevars is constrained, and which one is + // the upper/lower bound. 
+ let left_typevar = left_constraint.typevar(db); + let right_typevar = right_constraint.typevar(db); + let (bound_typevar, bound_constraint, constrained_typevar, constrained_constraint) = + if left_typevar.can_be_bound_for(db, right_typevar) { + ( + left_typevar, + left_constraint, + right_typevar, + right_constraint, + ) + } else { + ( + right_typevar, + right_constraint, + left_typevar, + left_constraint, + ) + }; + + // We then look for cases where the "constrained" typevar's upper and/or lower bound + // matches the "bound" typevar. If so, we're going to add an implication sequent that + // replaces the upper/lower bound that matched with the bound constraint's corresponding + // bound. + let (new_lower, new_upper) = match ( + constrained_constraint.lower(db), + constrained_constraint.upper(db), + ) { + // (B ≤ C ≤ B) ∧ (BL ≤ B ≤ BU) → (BL ≤ C ≤ BU) + (Type::TypeVar(constrained_lower), Type::TypeVar(constrained_upper)) + if constrained_lower.is_same_typevar_as(db, bound_typevar) + && constrained_upper.is_same_typevar_as(db, bound_typevar) => + { + (bound_constraint.lower(db), bound_constraint.upper(db)) + } + + // (CL ≤ C ≤ B) ∧ (BL ≤ B ≤ BU) → (CL ≤ C ≤ BU) + (constrained_lower, Type::TypeVar(constrained_upper)) + if constrained_upper.is_same_typevar_as(db, bound_typevar) => + { + (constrained_lower, bound_constraint.upper(db)) + } + + // (B ≤ C ≤ CU) ∧ (BL ≤ B ≤ BU) → (BL ≤ C ≤ CU) + (Type::TypeVar(constrained_lower), constrained_upper) + if constrained_lower.is_same_typevar_as(db, bound_typevar) => + { + (bound_constraint.lower(db), constrained_upper) + } + + _ => return, + }; + + let post_constraint = + ConstrainedTypeVar::new(db, constrained_typevar, new_lower, new_upper); + self.add_pair_implication(db, left_constraint, right_constraint, post_constraint); + } + + fn add_mutual_sequents_for_same_typevars( + &mut self, + db: &'db dyn Db, + left_constraint: ConstrainedTypeVar<'db>, + right_constraint: ConstrainedTypeVar<'db>, + ) { + let mut 
try_one_direction =
+ |left_constraint: ConstrainedTypeVar<'db>,
+ right_constraint: ConstrainedTypeVar<'db>| {
+ let left_lower = left_constraint.lower(db);
+ let left_upper = left_constraint.upper(db);
+ let right_lower = right_constraint.lower(db);
+ let right_upper = right_constraint.upper(db);
+ let post_constraint = match (left_lower, left_upper) {
+ // (B ≤ T ≤ B) ∧ (RL ≤ T ≤ RU) → (RL ≤ B ≤ RU)
+ (Type::TypeVar(bound_typevar), Type::TypeVar(other_bound_typevar))
+ if bound_typevar.is_same_typevar_as(db, other_bound_typevar) =>
+ {
+ ConstrainedTypeVar::new(db, bound_typevar, right_lower, right_upper)
+ }
+ // (B ≤ T) ∧ (T ≤ RU) → (B ≤ RU), i.e. (Never ≤ B ≤ RU)
+ (Type::TypeVar(bound_typevar), _) => {
+ ConstrainedTypeVar::new(db, bound_typevar, Type::Never, right_upper)
+ }
+ // (T ≤ B) ∧ (RL ≤ T) → (RL ≤ B), i.e. (RL ≤ B ≤ object)
+ (_, Type::TypeVar(bound_typevar)) => {
+ ConstrainedTypeVar::new(db, bound_typevar, right_lower, Type::object())
+ }
+ // Neither bound of the left constraint is a typevar; nothing to derive in
+ // this direction.
+ _ => return,
+ };
+ self.add_pair_implication(db, left_constraint, right_constraint, post_constraint);
+ };
+
+ // The helper only inspects the left constraint's bounds, so run it both ways.
+ try_one_direction(left_constraint, right_constraint);
+ try_one_direction(right_constraint, left_constraint);
+ }
+
+ fn add_concrete_sequents(
+ &mut self,
+ db: &'db dyn Db,
+ left_constraint: ConstrainedTypeVar<'db>,
+ right_constraint: ConstrainedTypeVar<'db>,
+ ) {
+ match left_constraint.intersect(db, right_constraint) {
+ Some(intersection_constraint) => {
+ self.add_pair_implication(
+ db,
+ left_constraint,
+ right_constraint,
+ intersection_constraint,
+ );
+ self.add_single_implication(intersection_constraint, left_constraint);
+ self.add_single_implication(intersection_constraint, right_constraint);
+ }
+ None => {
+ // Disjoint concrete ranges: the pair can never hold simultaneously.
+ self.add_impossibility(db, left_constraint, right_constraint);
+ }
+ }
+ }
+
+ #[expect(dead_code)] // Keep this around for debugging purposes
+ fn display<'a>(&'a self, db: &'db dyn Db, prefix: &'a dyn Display) -> impl Display + 'a {
+ struct DisplaySequentMap<'a, 'db> {
+ map: &'a SequentMap<'db>,
+ prefix: &'a dyn Display,
+ db: &'db dyn Db,
+ }
+
+ impl Display for DisplaySequentMap<'_, '_> {
+ fn fmt(&self, f: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result {
+ let mut first = true;
+ let mut maybe_write_prefix = |f: &mut std::fmt::Formatter<'_>| {
+ if first {
+ first = false;
+ Ok(())
+ } else {
+ write!(f, "\n{}", self.prefix)
+ }
+ };
+
+ for (ante1, ante2) in &self.map.impossibilities {
+ maybe_write_prefix(f)?;
+ write!(
+ f,
+ "{} ∧ {} → false",
+ ante1.display(self.db),
+ ante2.display(self.db),
+ )?;
+ }
+
+ for ((ante1, ante2), posts) in &self.map.pair_implications {
+ for post in posts {
+ maybe_write_prefix(f)?;
+ write!(
+ f,
+ "{} ∧ {} → {}",
+ ante1.display(self.db),
+ ante2.display(self.db),
+ post.display(self.db),
+ )?;
+ }
+ }
+
+ for (ante, posts) in &self.map.single_implications {
+ for post in posts {
+ maybe_write_prefix(f)?;
+ write!(f, "{} → {}", ante.display(self.db), post.display(self.db))?;
+ }
+ }
+
+ if first {
+ f.write_str("[no sequents]")?;
+ }
+ Ok(())
+ }
+ }
+
+ DisplaySequentMap {
+ map: self,
+ prefix,
+ db,
+ }
+ }
+}
+
+/// The collection of constraints that we know to be true or false at a certain point when
+/// traversing a BDD.
+#[derive(Debug, Default)]
+struct PathAssignments<'db> {
+ // Insertion-ordered set, used like a stack: `walk_edge` snapshots its length and
+ // truncates back to that snapshot when the edge has been fully explored.
+ assignments: FxOrderSet<ConstraintAssignment<'db>>,
+}
+
+impl<'db> PathAssignments<'db> {
+ /// Walks one of the outgoing edges of an internal BDD node. `assignment` describes the
+ /// constraint that the BDD node checks, and whether we are following the `if_true` or
+ /// `if_false` edge.
+ ///
+ /// This new assignment might cause this path to become impossible — for instance, if we were
+ /// already assuming (from an earlier edge in the path) a constraint that is disjoint with this
+ /// one. We might also be able to infer _other_ assignments that do not appear in the BDD
+ /// directly, but which are implied from a combination of constraints that we _have_ seen.
+ ///
+ /// To handle all of this, you provide a callback. If the path has become impossible, we will
+ /// return `None` _without invoking the callback_. 
If the path does not contain any + /// contradictions, we will invoke the callback and return its result (wrapped in `Some`). + /// + /// Your callback will also be provided a slice of all of the constraints that we were able to + /// infer from `assignment` combined with the information we already knew. (For borrow-check + /// reasons, we provide this as a [`Range`]; use that range to index into `self.assignments` to + /// get the list of all of the assignments that we learned from this edge.) + /// + /// You will presumably end up making a recursive call of some kind to keep progressing through + /// the BDD. You should make this call from inside of your callback, so that as you get further + /// down into the BDD structure, we remember all of the information that we have learned from + /// the path we're on. + fn walk_edge( + &mut self, + map: &SequentMap<'db>, + assignment: ConstraintAssignment<'db>, + f: impl FnOnce(&mut Self, Range) -> R, + ) -> Option { + // Record a snapshot of the assignments that we already knew held — both so that we can + // pass along the range of which assignments are new, and so that we can reset back to this + // point before returning. + let start = self.assignments.len(); + + // Add the new assignment and anything we can derive from it. + let result = if self.add_assignment(map, assignment).is_err() { + // If that results in the path now being impossible due to a contradiction, return + // without invoking the callback. + None + } else { + // Otherwise invoke the callback to keep traversing the BDD. The callback will likely + // traverse additional edges, which might add more to our `assignments` set. But even + // if that happens, `start..end` will mark the assignments that were added by the + // `add_assignment` call above — that is, the new assignment for this edge along with + // the derived information we inferred from it. 
+ let end = self.assignments.len();
+ Some(f(self, start..end))
+ };
+
+ // Reset back to where we were before following this edge, so that the caller can reuse a
+ // single instance for the entire BDD traversal.
+ self.assignments.truncate(start);
+ result
+ }
+
+ fn assignment_holds(&self, assignment: ConstraintAssignment<'db>) -> bool {
+ self.assignments.contains(&assignment)
+ }
+
+ /// Adds a new assignment, along with any derived information that we can infer from the new
+ /// assignment combined with the assignments we've already seen. If any of this causes the path
+ /// to become invalid, due to a contradiction, returns a [`PathAssignmentConflict`] error.
+ fn add_assignment(
+ &mut self,
+ map: &SequentMap<'db>,
+ assignment: ConstraintAssignment<'db>,
+ ) -> Result<(), PathAssignmentConflict> {
+ // First add this assignment. If it causes a conflict, return that as an error. If we
+ // already know this assignment holds, just return.
+ if self.assignments.contains(&assignment.negated()) {
+ return Err(PathAssignmentConflict);
+ }
+ if !self.assignments.insert(assignment) {
+ return Ok(());
+ }
+
+ // Then use our sequents to add additional facts that we know to be true.
+ // TODO: This is very naive at the moment, partly for expediency, and partly because we
+ // don't anticipate the sequent maps to be very large. We might consider avoiding the
+ // brute-force search.
+
+ for (ante1, ante2) in &map.impossibilities {
+ if self.assignment_holds(ante1.when_true()) && self.assignment_holds(ante2.when_true())
+ {
+ // The sequent map says (ante1 ∧ ante2) is an impossible combination, and the
+ // current path asserts that both are true.
+ return Err(PathAssignmentConflict);
+ }
+ }
+
+ for ((ante1, ante2), posts) in &map.pair_implications {
+ for post in posts {
+ // NOTE(review): this antecedent check does not mention `post`, so it looks
+ // hoistable above the inner loop; confirm nothing relies on re-checking after
+ // each recursive `add_assignment` call before changing it.
+ if self.assignment_holds(ante1.when_true())
+ && self.assignment_holds(ante2.when_true())
+ {
+ self.add_assignment(map, post.when_true())?;
+ }
+ }
+ }
+
+ for (ante, posts) in &map.single_implications {
+ for post in posts {
+ if self.assignment_holds(ante.when_true()) {
+ self.add_assignment(map, post.when_true())?;
+ }
+ }
+ }
+
+ Ok(())
+ }
+}
+
+/// Marker error: the current BDD path's assignments contain a contradiction, so the path
+/// is impossible and should be abandoned.
+#[derive(Debug)]
+struct PathAssignmentConflict;
+
 /// A single clause in the DNF representation of a BDD
 #[derive(Clone, Debug, Default, Eq, PartialEq)]
 struct SatisfiedClause<'db> {