Skip to content

Move normalize_fn_sig to TypeErrCtxt #105185

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use rustc_middle::ty::{self, Const, Ty, TyCtxt};
use rustc_session::Session;
use rustc_span::symbol::Ident;
use rustc_span::{self, Span};
use rustc_trait_selection::traits::{ObligationCause, ObligationCauseCode};
use rustc_trait_selection::traits::{ObligationCause, ObligationCauseCode, ObligationCtxt};

use std::cell::{Cell, RefCell};
use std::ops::Deref;
Expand Down Expand Up @@ -162,6 +162,23 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
infcx: &self.infcx,
typeck_results: Some(self.typeck_results.borrow()),
fallback_has_occurred: self.fallback_has_occurred.get(),
normalize_fn_sig: Box::new(|fn_sig| {
if fn_sig.has_escaping_bound_vars() {
return fn_sig;
}
self.probe(|_| {
let ocx = ObligationCtxt::new_in_snapshot(self);
let normalized_fn_sig =
ocx.normalize(&ObligationCause::dummy(), self.param_env, fn_sig);
if ocx.select_all_or_error().is_empty() {
let normalized_fn_sig = self.resolve_vars_if_possible(normalized_fn_sig);
if !normalized_fn_sig.needs_infer() {
return normalized_fn_sig;
}
}
fn_sig
})
}),
}
}

Expand Down
29 changes: 2 additions & 27 deletions compiler/rustc_hir_typeck/src/inherited.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use super::callee::DeferredCallResolution;

use rustc_data_structures::fx::FxHashSet;
use rustc_data_structures::sync::Lrc;
use rustc_hir as hir;
use rustc_hir::def_id::LocalDefId;
use rustc_hir::HirIdMap;
Expand All @@ -11,9 +10,7 @@ use rustc_middle::ty::visit::TypeVisitable;
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_span::def_id::LocalDefIdMap;
use rustc_span::{self, Span};
use rustc_trait_selection::traits::{
self, ObligationCause, ObligationCtxt, TraitEngine, TraitEngineExt as _,
};
use rustc_trait_selection::traits::{self, TraitEngine, TraitEngineExt as _};

use std::cell::RefCell;
use std::ops::Deref;
Expand Down Expand Up @@ -92,29 +89,7 @@ impl<'tcx> Inherited<'tcx> {
infcx: tcx
.infer_ctxt()
.ignoring_regions()
.with_opaque_type_inference(DefiningAnchor::Bind(hir_owner.def_id))
.with_normalize_fn_sig_for_diagnostic(Lrc::new(move |infcx, fn_sig| {
if fn_sig.has_escaping_bound_vars() {
return fn_sig;
}
infcx.probe(|_| {
let ocx = ObligationCtxt::new_in_snapshot(infcx);
let normalized_fn_sig = ocx.normalize(
&ObligationCause::dummy(),
// FIXME(compiler-errors): This is probably not the right param-env...
infcx.tcx.param_env(def_id),
fn_sig,
);
if ocx.select_all_or_error().is_empty() {
let normalized_fn_sig =
infcx.resolve_vars_if_possible(normalized_fn_sig);
if !normalized_fn_sig.needs_infer() {
return normalized_fn_sig;
}
}
fn_sig
})
})),
.with_opaque_type_inference(DefiningAnchor::Bind(hir_owner.def_id)),
def_id,
typeck_results: RefCell::new(ty::TypeckResults::new(hir_owner)),
}
Expand Down
4 changes: 0 additions & 4 deletions compiler/rustc_infer/src/infer/at.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,6 @@ impl<'tcx> InferCtxt<'tcx> {
err_count_on_creation: self.err_count_on_creation,
in_snapshot: self.in_snapshot.clone(),
universe: self.universe.clone(),
normalize_fn_sig_for_diagnostic: self
.normalize_fn_sig_for_diagnostic
.as_ref()
.map(|f| f.clone()),
intercrate: self.intercrate,
}
}
Expand Down
13 changes: 3 additions & 10 deletions compiler/rustc_infer/src/infer/error_reporting/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ pub mod nice_region_error;
pub struct TypeErrCtxt<'a, 'tcx> {
pub infcx: &'a InferCtxt<'tcx>,
pub typeck_results: Option<std::cell::Ref<'a, ty::TypeckResults<'tcx>>>,
pub normalize_fn_sig: Box<dyn Fn(ty::PolyFnSig<'tcx>) -> ty::PolyFnSig<'tcx> + 'a>,
pub fallback_has_occurred: bool,
}

Expand Down Expand Up @@ -1007,22 +1008,14 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> {
}
}

fn normalize_fn_sig_for_diagnostic(&self, sig: ty::PolyFnSig<'tcx>) -> ty::PolyFnSig<'tcx> {
if let Some(normalize) = &self.normalize_fn_sig_for_diagnostic {
normalize(self, sig)
} else {
sig
}
}

/// Given two `fn` signatures highlight only sub-parts that are different.
fn cmp_fn_sig(
&self,
sig1: &ty::PolyFnSig<'tcx>,
sig2: &ty::PolyFnSig<'tcx>,
) -> (DiagnosticStyledString, DiagnosticStyledString) {
let sig1 = &self.normalize_fn_sig_for_diagnostic(*sig1);
let sig2 = &self.normalize_fn_sig_for_diagnostic(*sig2);
let sig1 = &(self.normalize_fn_sig)(*sig1);
let sig2 = &(self.normalize_fn_sig)(*sig2);

let get_lifetimes = |sig| {
use rustc_hir::def::Namespace;
Expand Down
32 changes: 7 additions & 25 deletions compiler/rustc_infer/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,9 +334,6 @@ pub struct InferCtxt<'tcx> {
/// bound.
universe: Cell<ty::UniverseIndex>,

normalize_fn_sig_for_diagnostic:
Option<Lrc<dyn Fn(&InferCtxt<'tcx>, ty::PolyFnSig<'tcx>) -> ty::PolyFnSig<'tcx>>>,

/// During coherence we have to assume that other crates may add
/// additional impls which we currently don't know about.
///
Expand Down Expand Up @@ -573,8 +570,6 @@ pub struct InferCtxtBuilder<'tcx> {
considering_regions: bool,
/// Whether we are in coherence mode.
intercrate: bool,
normalize_fn_sig_for_diagnostic:
Option<Lrc<dyn Fn(&InferCtxt<'tcx>, ty::PolyFnSig<'tcx>) -> ty::PolyFnSig<'tcx>>>,
}

pub trait TyCtxtInferExt<'tcx> {
Expand All @@ -587,7 +582,6 @@ impl<'tcx> TyCtxtInferExt<'tcx> for TyCtxt<'tcx> {
tcx: self,
defining_use_anchor: DefiningAnchor::Error,
considering_regions: true,
normalize_fn_sig_for_diagnostic: None,
intercrate: false,
}
}
Expand Down Expand Up @@ -615,14 +609,6 @@ impl<'tcx> InferCtxtBuilder<'tcx> {
self
}

pub fn with_normalize_fn_sig_for_diagnostic(
mut self,
fun: Lrc<dyn Fn(&InferCtxt<'tcx>, ty::PolyFnSig<'tcx>) -> ty::PolyFnSig<'tcx>>,
) -> Self {
self.normalize_fn_sig_for_diagnostic = Some(fun);
self
}

/// Given a canonical value `C` as a starting point, create an
/// inference context that contains each of the bound values
/// within instantiated as a fresh variable. The `f` closure is
Expand All @@ -644,13 +630,7 @@ impl<'tcx> InferCtxtBuilder<'tcx> {
}

pub fn build(&mut self) -> InferCtxt<'tcx> {
let InferCtxtBuilder {
tcx,
defining_use_anchor,
considering_regions,
ref normalize_fn_sig_for_diagnostic,
intercrate,
} = *self;
let InferCtxtBuilder { tcx, defining_use_anchor, considering_regions, intercrate } = *self;
InferCtxt {
tcx,
defining_use_anchor,
Expand All @@ -666,9 +646,6 @@ impl<'tcx> InferCtxtBuilder<'tcx> {
in_snapshot: Cell::new(false),
skip_leak_check: Cell::new(false),
universe: Cell::new(ty::UniverseIndex::ROOT),
normalize_fn_sig_for_diagnostic: normalize_fn_sig_for_diagnostic
.as_ref()
.map(|f| f.clone()),
intercrate,
}
}
Expand Down Expand Up @@ -709,7 +686,12 @@ impl<'tcx> InferCtxt<'tcx> {
/// Creates a `TypeErrCtxt` for emitting various inference errors.
/// During typeck, use `FnCtxt::err_ctxt` instead.
pub fn err_ctxt(&self) -> TypeErrCtxt<'_, 'tcx> {
TypeErrCtxt { infcx: self, typeck_results: None, fallback_has_occurred: false }
TypeErrCtxt {
infcx: self,
typeck_results: None,
fallback_has_occurred: false,
normalize_fn_sig: Box::new(|fn_sig| fn_sig),
}
}

pub fn is_in_snapshot(&self) -> bool {
Expand Down