Skip to content

check for inference var leaks before rollback #100745

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 58 additions & 52 deletions compiler/rustc_infer/src/infer/fudge.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use rustc_middle::ty::fold::{TypeFoldable, TypeFolder, TypeSuperFoldable};
use rustc_middle::ty::TypeVisitable;
use rustc_middle::ty::{self, ConstVid, FloatVid, IntVid, RegionVid, Ty, TyCtxt, TyVid};

use super::type_variable::TypeVariableOrigin;
Expand Down Expand Up @@ -99,69 +100,74 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
where
F: FnOnce() -> Result<T, E>,
T: TypeFoldable<'tcx>,
E: TypeVisitable<'tcx>,
{
let variable_lengths = self.variable_lengths();
let (mut fudger, value) = self.probe(|_| {
match f() {
Ok(value) => {
let value = self.resolve_vars_if_possible(value);

// At this point, `value` could in principle refer
// to inference variables that have been created during
// the snapshot. Once we exit `probe()`, those are
// going to be popped, so we will have to
// eliminate any references to them.

let mut inner = self.inner.borrow_mut();
let type_vars =
inner.type_variables().vars_since_snapshot(variable_lengths.type_var_len);
let int_vars = vars_since_snapshot(
&mut inner.int_unification_table(),
variable_lengths.int_var_len,
);
let float_vars = vars_since_snapshot(
&mut inner.float_unification_table(),
variable_lengths.float_var_len,
);
let region_vars = inner
.unwrap_region_constraints()
.vars_since_snapshot(variable_lengths.region_constraints_len);
let const_vars = const_vars_since_snapshot(
&mut inner.const_unification_table(),
variable_lengths.const_var_len,
);

let fudger = InferenceFudger {
let snapshot = self.start_snapshot();

match f() {
Ok(value) => {
let value = self.resolve_vars_if_possible(value);

// At this point, `value` could in principle refer
// to inference variables that have been created during
// the snapshot. Once we exit the snapshot, those are
// going to be popped, so we will have to
// eliminate any references to them.

let mut inner = self.inner.borrow_mut();
let type_vars =
inner.type_variables().vars_since_snapshot(variable_lengths.type_var_len);
let int_vars = vars_since_snapshot(
&mut inner.int_unification_table(),
variable_lengths.int_var_len,
);
let float_vars = vars_since_snapshot(
&mut inner.float_unification_table(),
variable_lengths.float_var_len,
);
let region_vars = inner
.unwrap_region_constraints()
.vars_since_snapshot(variable_lengths.region_constraints_len);
let const_vars = const_vars_since_snapshot(
&mut inner.const_unification_table(),
variable_lengths.const_var_len,
);
drop(inner);

self.rollback_to("fudge_inference_if_ok -- ok", snapshot);

// At this point, we need to replace any of the now-popped
// type/region variables that appear in `value` with a fresh
// variable of the appropriate kind. We can't do this during
// the probe because they would just get popped then too. =)

// Micro-optimization: if no variables have been created, then
// `value` can't refer to any of them. =) So we can just return it.
if type_vars.0.is_empty()
&& int_vars.is_empty()
&& float_vars.is_empty()
&& region_vars.0.is_empty()
&& const_vars.0.is_empty()
{
Ok(value)
} else {
Ok(value.fold_with(&mut InferenceFudger {
infcx: self,
type_vars,
int_vars,
float_vars,
region_vars,
const_vars,
};

Ok((fudger, value))
}))
}
Err(e) => Err(e),
}
})?;

// At this point, we need to replace any of the now-popped
// type/region variables that appear in `value` with a fresh
// variable of the appropriate kind. We can't do this during
// the probe because they would just get popped then too. =)

// Micro-optimization: if no variables have been created, then
// `value` can't refer to any of them. =) So we can just return it.
if fudger.type_vars.0.is_empty()
&& fudger.int_vars.is_empty()
&& fudger.float_vars.is_empty()
&& fudger.region_vars.0.is_empty()
&& fudger.const_vars.0.is_empty()
{
Ok(value)
} else {
Ok(value.fold_with(&mut fudger))
Err(e) => {
debug_assert!(!e.needs_infer(), "fudge_inference_if_ok: leaking infer vars: {e:?}");
self.rollback_to("fudge_inference_if_ok -- error", snapshot);
return Err(e);
}
}
}
}
Expand Down
14 changes: 11 additions & 3 deletions compiler/rustc_infer/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,13 @@ pub mod type_variable;
mod undo_log;

#[must_use]
#[derive(Debug)]
#[derive(Debug, Clone, TypeFoldable, TypeVisitable)]
pub struct InferOk<'tcx, T> {
pub value: T,
pub obligations: PredicateObligations<'tcx>,
}
pub type InferResult<'tcx, T> = Result<InferOk<'tcx, T>, TypeError<'tcx>>;

pub type Bound<T> = Option<T>;
pub type UnitResult<'tcx> = RelateResult<'tcx, ()>; // "unify result"
pub type FixupResult<'tcx, T> = Result<T, FixupError<'tcx>>; // "fixup result"

Expand Down Expand Up @@ -845,6 +844,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
pub fn commit_if_ok<T, E, F>(&self, f: F) -> Result<T, E>
where
F: FnOnce(&CombinedSnapshot<'a, 'tcx>) -> Result<T, E>,
E: TypeVisitable<'tcx>,
{
let snapshot = self.start_snapshot();
let r = f(&snapshot);
Expand All @@ -853,7 +853,8 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
Ok(_) => {
self.commit_from(snapshot);
}
Err(_) => {
Err(ref e) => {
debug_assert!(!e.needs_infer(), "commit_if_ok: leaking infer vars: {e:?}");
self.rollback_to("commit_if_ok -- error", snapshot);
}
}
Expand All @@ -865,9 +866,13 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
pub fn probe<R, F>(&self, f: F) -> R
where
F: FnOnce(&CombinedSnapshot<'a, 'tcx>) -> R,
R: TypeVisitable<'tcx>,
{
let snapshot = self.start_snapshot();

let r = f(&snapshot);

debug_assert!(!r.needs_infer(), "probe: leaking infer vars: {r:?}");
self.rollback_to("probe", snapshot);
r
}
Expand All @@ -877,13 +882,16 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> {
pub fn probe_maybe_skip_leak_check<R, F>(&self, should_skip: bool, f: F) -> R
where
F: FnOnce(&CombinedSnapshot<'a, 'tcx>) -> R,
R: TypeVisitable<'tcx>,
{
let snapshot = self.start_snapshot();
let was_skip_leak_check = self.skip_leak_check.get();
if should_skip {
self.skip_leak_check.set(true);
}
let r = f(&snapshot);

debug_assert!(!r.needs_infer(), "probe_maybe_skip_leak_check: leaking infer vars: {r:?}");
self.rollback_to("probe", snapshot);
self.skip_leak_check.set(was_skip_leak_check);
r
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_infer/src/traits/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub use rustc_middle::traits::{EvaluationResult, Reveal};
pub(crate) type UndoLog<'tcx> =
snapshot_map::UndoLog<ProjectionCacheKey<'tcx>, ProjectionCacheEntry<'tcx>>;

#[derive(Clone)]
#[derive(Clone, TypeFoldable, TypeVisitable)]
pub struct MismatchedProjectionTypes<'tcx> {
pub err: ty::error::TypeError<'tcx>,
}
Expand Down
9 changes: 6 additions & 3 deletions compiler/rustc_middle/src/infer/canonical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,12 @@ TrivialTypeTraversalAndLiftImpls! {
}
}

TrivialTypeTraversalImpls! {
for <'tcx> {
crate::infer::canonical::CanonicalVarInfos<'tcx>,
impl<'tcx> ty::TypeFoldable<'tcx> for CanonicalVarInfos<'tcx> {
fn try_fold_with<F: ty::FallibleTypeFolder<'tcx>>(
self,
_: &mut F,
) -> ::std::result::Result<CanonicalVarInfos<'tcx>, F::Error> {
Ok(self)
}
}

Expand Down
7 changes: 0 additions & 7 deletions compiler/rustc_middle/src/mir/type_visitable.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
//! `TypeVisitable` implementations for MIR types

use super::*;
use crate::ty;

impl<'tcx> TypeVisitable<'tcx> for Terminator<'tcx> {
fn visit_with<V: TypeVisitor<'tcx>>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
Expand Down Expand Up @@ -67,12 +66,6 @@ impl<'tcx> TypeVisitable<'tcx> for Place<'tcx> {
}
}

impl<'tcx> TypeVisitable<'tcx> for &'tcx ty::List<PlaceElem<'tcx>> {
fn visit_with<V: TypeVisitor<'tcx>>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
self.iter().try_for_each(|t| t.visit_with(visitor))
}
}

impl<'tcx> TypeVisitable<'tcx> for Rvalue<'tcx> {
fn visit_with<V: TypeVisitor<'tcx>>(&self, visitor: &mut V) -> ControlFlow<V::BreakTy> {
use crate::mir::Rvalue::*;
Expand Down
9 changes: 8 additions & 1 deletion compiler/rustc_middle/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ pub enum Reveal {
///
/// We do not want to intern this as there are a lot of obligation causes which
/// only live for a short period of time.
#[derive(Clone, Debug, PartialEq, Eq, Lift)]
#[derive(Clone, Debug, PartialEq, Eq, Lift, TypeVisitable)]
pub struct ObligationCause<'tcx> {
pub span: Span,

Expand Down Expand Up @@ -186,13 +186,15 @@ impl<'tcx> ObligationCause<'tcx> {
}

#[derive(Clone, Debug, PartialEq, Eq, Hash, Lift)]
#[derive(TypeVisitable)]
pub struct UnifyReceiverContext<'tcx> {
pub assoc_item: ty::AssocItem,
pub param_env: ty::ParamEnv<'tcx>,
pub substs: SubstsRef<'tcx>,
}

#[derive(Clone, Debug, PartialEq, Eq, Hash, Lift, Default)]
#[derive(TypeVisitable)]
pub struct InternedObligationCauseCode<'tcx> {
/// `None` for `ObligationCauseCode::MiscObligation` (a common case, occurs ~60% of
/// the time). `Some` otherwise.
Expand Down Expand Up @@ -221,6 +223,7 @@ impl<'tcx> std::ops::Deref for InternedObligationCauseCode<'tcx> {
}

#[derive(Clone, Debug, PartialEq, Eq, Hash, Lift)]
#[derive(TypeVisitable)]
pub enum ObligationCauseCode<'tcx> {
/// Not well classified or should be obvious from the span.
MiscObligation,
Expand Down Expand Up @@ -415,6 +418,7 @@ pub enum ObligationCauseCode<'tcx> {
/// we can walk in order to obtain precise spans for any
/// 'nested' types (e.g. `Foo` in `Option<Foo>`).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, HashStable)]
#[derive(TypeVisitable)]
pub enum WellFormedLoc {
/// Use the type of the provided definition.
Ty(LocalDefId),
Expand All @@ -432,6 +436,7 @@ pub enum WellFormedLoc {
}

#[derive(Clone, Debug, PartialEq, Eq, Hash, Lift)]
#[derive(TypeVisitable)]
pub struct ImplDerivedObligationCause<'tcx> {
pub derived: DerivedObligationCause<'tcx>,
pub impl_def_id: DefId,
Expand Down Expand Up @@ -479,6 +484,7 @@ impl<'tcx> ty::Lift<'tcx> for StatementAsExpression {
}

#[derive(Clone, Debug, PartialEq, Eq, Hash, Lift)]
#[derive(TypeVisitable)]
pub struct MatchExpressionArmCause<'tcx> {
pub arm_block_id: Option<hir::HirId>,
pub arm_ty: Ty<'tcx>,
Expand All @@ -505,6 +511,7 @@ pub struct IfExpressionCause<'tcx> {
}

#[derive(Clone, Debug, PartialEq, Eq, Hash, Lift)]
#[derive(TypeVisitable)]
pub struct DerivedObligationCause<'tcx> {
/// The trait predicate of the parent obligation that led to the
/// current obligation. Note that only trait obligations lead to
Expand Down
8 changes: 4 additions & 4 deletions compiler/rustc_middle/src/traits/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ pub type CanonicalTypeOpProvePredicateGoal<'tcx> =
pub type CanonicalTypeOpNormalizeGoal<'tcx, T> =
Canonical<'tcx, ty::ParamEnvAnd<'tcx, type_op::Normalize<T>>>;

#[derive(Copy, Clone, Debug, HashStable)]
#[derive(Copy, Clone, Debug, HashStable, TypeVisitable)]
pub struct NoSolution;

pub type Fallible<T> = Result<T, NoSolution>;
Expand Down Expand Up @@ -178,7 +178,7 @@ impl<'tcx> FromIterator<DropckConstraint<'tcx>> for DropckConstraint<'tcx> {
}
}

#[derive(Debug, HashStable)]
#[derive(Debug, HashStable, TypeVisitable)]
pub struct CandidateStep<'tcx> {
pub self_ty: Canonical<'tcx, QueryResponse<'tcx, Ty<'tcx>>>,
pub autoderefs: usize,
Expand All @@ -191,7 +191,7 @@ pub struct CandidateStep<'tcx> {
pub unsize: bool,
}

#[derive(Copy, Clone, Debug, HashStable)]
#[derive(Copy, Clone, Debug, HashStable, TypeVisitable)]
pub struct MethodAutoderefStepsResult<'tcx> {
/// The valid autoderef steps that could be find.
pub steps: &'tcx [CandidateStep<'tcx>],
Expand All @@ -202,7 +202,7 @@ pub struct MethodAutoderefStepsResult<'tcx> {
pub reached_recursion_limit: bool,
}

#[derive(Debug, HashStable)]
#[derive(Debug, HashStable, TypeVisitable)]
pub struct MethodAutoderefBadTy<'tcx> {
pub reached_raw_pointer: bool,
pub ty: Canonical<'tcx, QueryResponse<'tcx, Ty<'tcx>>>,
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_middle/src/traits/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ pub enum SelectionCandidate<'tcx> {
/// so they are noops when unioned with a definite error, and within
/// the categories it's easy to see that the unions are correct.
#[derive(Copy, Clone, Debug, PartialOrd, Ord, PartialEq, Eq, HashStable)]
#[derive(TypeVisitable)]
pub enum EvaluationResult {
/// Evaluation successful.
EvaluatedToOk,
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/ty/fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ use std::collections::BTreeMap;
///
/// To implement this conveniently, use the derive macro located in
/// `rustc_macros`.
pub trait TypeFoldable<'tcx>: TypeVisitable<'tcx> {
pub trait TypeFoldable<'tcx>: TypeVisitable<'tcx> + Clone {
/// The entry point for folding. To fold a value `t` with a folder `f`
/// call: `t.try_fold_with(f)`.
///
Expand Down
Loading