Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove sub_relations from the InferCtxt #119989

Merged
merged 5 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1522,10 +1522,13 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
if self.next_trait_solver()
&& let ty::Alias(..) = ty.kind()
{
match self
// We need to use a separate variable here as otherwise the temporary for
// `self.fulfillment_cx.borrow_mut()` is alive in the `Err` branch, resulting
// in a reentrant borrow, causing an ICE.
let result = self
.at(&self.misc(sp), self.param_env)
.structurally_normalize(ty, &mut **self.fulfillment_cx.borrow_mut())
{
.structurally_normalize(ty, &mut **self.fulfillment_cx.borrow_mut());
match result {
Ok(normalized_ty) => normalized_ty,
Err(errors) => {
let guar = self.err_ctxt().report_fulfillment_errors(errors);
Expand Down
7 changes: 7 additions & 0 deletions compiler/rustc_hir_typeck/src/fn_ctxt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use rustc_hir as hir;
use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_hir_analysis::astconv::AstConv;
use rustc_infer::infer;
use rustc_infer::infer::error_reporting::sub_relations::SubRelations;
use rustc_infer::infer::error_reporting::TypeErrCtxt;
use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
use rustc_middle::infer::unify_key::{ConstVariableOrigin, ConstVariableOriginKind};
Expand Down Expand Up @@ -155,8 +156,14 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
///
/// [`InferCtxt::err_ctxt`]: infer::InferCtxt::err_ctxt
pub fn err_ctxt(&'a self) -> TypeErrCtxt<'a, 'tcx> {
let mut sub_relations = SubRelations::default();
sub_relations.add_constraints(
self,
self.fulfillment_cx.borrow_mut().pending_obligations().iter().map(|o| o.predicate),
);
TypeErrCtxt {
infcx: &self.infcx,
sub_relations: RefCell::new(sub_relations),
typeck_results: Some(self.typeck_results.borrow()),
fallback_has_occurred: self.fallback_has_occurred.get(),
normalize_fn_sig: Box::new(|fn_sig| {
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_infer/src/infer/error_reporting/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ mod note_and_explain;
mod suggest;

pub(crate) mod need_type_info;
pub mod sub_relations;
pub use need_type_info::TypeAnnotationNeeded;

pub mod nice_region_error;
Expand Down Expand Up @@ -123,6 +124,8 @@ fn escape_literal(s: &str) -> String {
/// methods which should not be used during the happy path.
pub struct TypeErrCtxt<'a, 'tcx> {
pub infcx: &'a InferCtxt<'tcx>,
pub sub_relations: std::cell::RefCell<sub_relations::SubRelations>,

pub typeck_results: Option<std::cell::Ref<'a, ty::TypeckResults<'tcx>>>,
pub fallback_has_occurred: bool,

Expand Down
43 changes: 20 additions & 23 deletions compiler/rustc_infer/src/infer/error_reporting/need_type_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ impl<'tcx> TypeErrCtxt<'_, 'tcx> {
parent_name,
});

let args = if self.infcx.tcx.get_diagnostic_item(sym::iterator_collect_fn)
let args = if self.tcx.get_diagnostic_item(sym::iterator_collect_fn)
== Some(generics_def_id)
{
"Vec<_>".to_string()
Expand Down Expand Up @@ -710,7 +710,7 @@ struct InsertableGenericArgs<'tcx> {
/// While doing so, the currently best spot is stored in `infer_source`.
/// For details on how we rank spots, see [Self::source_cost]
struct FindInferSourceVisitor<'a, 'tcx> {
infcx: &'a InferCtxt<'tcx>,
tecx: &'a TypeErrCtxt<'a, 'tcx>,
typeck_results: &'a TypeckResults<'tcx>,

target: GenericArg<'tcx>,
Expand All @@ -722,12 +722,12 @@ struct FindInferSourceVisitor<'a, 'tcx> {

impl<'a, 'tcx> FindInferSourceVisitor<'a, 'tcx> {
fn new(
infcx: &'a InferCtxt<'tcx>,
tecx: &'a TypeErrCtxt<'a, 'tcx>,
typeck_results: &'a TypeckResults<'tcx>,
target: GenericArg<'tcx>,
) -> Self {
FindInferSourceVisitor {
infcx,
tecx,
typeck_results,

target,
Expand Down Expand Up @@ -778,7 +778,7 @@ impl<'a, 'tcx> FindInferSourceVisitor<'a, 'tcx> {
}

// The sources are listed in order of preference here.
let tcx = self.infcx.tcx;
let tcx = self.tecx.tcx;
let ctx = CostCtxt { tcx };
match source.kind {
InferSourceKind::LetBinding { ty, .. } => ctx.ty_cost(ty),
Expand Down Expand Up @@ -829,12 +829,12 @@ impl<'a, 'tcx> FindInferSourceVisitor<'a, 'tcx> {

fn node_args_opt(&self, hir_id: HirId) -> Option<GenericArgsRef<'tcx>> {
let args = self.typeck_results.node_args_opt(hir_id);
self.infcx.resolve_vars_if_possible(args)
self.tecx.resolve_vars_if_possible(args)
}

fn opt_node_type(&self, hir_id: HirId) -> Option<Ty<'tcx>> {
let ty = self.typeck_results.node_type_opt(hir_id);
self.infcx.resolve_vars_if_possible(ty)
self.tecx.resolve_vars_if_possible(ty)
}

// Check whether this generic argument is the inference variable we
Expand All @@ -849,20 +849,17 @@ impl<'a, 'tcx> FindInferSourceVisitor<'a, 'tcx> {
use ty::{Infer, TyVar};
match (inner_ty.kind(), target_ty.kind()) {
(&Infer(TyVar(a_vid)), &Infer(TyVar(b_vid))) => {
self.infcx.inner.borrow_mut().type_variables().sub_unified(a_vid, b_vid)
self.tecx.sub_relations.borrow_mut().unified(self.tecx, a_vid, b_vid)
}
_ => false,
}
}
(GenericArgKind::Const(inner_ct), GenericArgKind::Const(target_ct)) => {
use ty::InferConst::*;
match (inner_ct.kind(), target_ct.kind()) {
(ty::ConstKind::Infer(Var(a_vid)), ty::ConstKind::Infer(Var(b_vid))) => self
.infcx
.inner
.borrow_mut()
.const_unification_table()
.unioned(a_vid, b_vid),
(ty::ConstKind::Infer(Var(a_vid)), ty::ConstKind::Infer(Var(b_vid))) => {
self.tecx.inner.borrow_mut().const_unification_table().unioned(a_vid, b_vid)
}
_ => false,
}
}
Expand Down Expand Up @@ -917,7 +914,7 @@ impl<'a, 'tcx> FindInferSourceVisitor<'a, 'tcx> {
&self,
expr: &'tcx hir::Expr<'tcx>,
) -> Box<dyn Iterator<Item = InsertableGenericArgs<'tcx>> + 'a> {
let tcx = self.infcx.tcx;
let tcx = self.tecx.tcx;
match expr.kind {
hir::ExprKind::Path(ref path) => {
if let Some(args) = self.node_args_opt(expr.hir_id) {
Expand Down Expand Up @@ -980,7 +977,7 @@ impl<'a, 'tcx> FindInferSourceVisitor<'a, 'tcx> {
path: &'tcx hir::Path<'tcx>,
args: GenericArgsRef<'tcx>,
) -> impl Iterator<Item = InsertableGenericArgs<'tcx>> + 'a {
let tcx = self.infcx.tcx;
let tcx = self.tecx.tcx;
let have_turbofish = path.segments.iter().any(|segment| {
segment.args.is_some_and(|args| args.args.iter().any(|arg| arg.is_ty_or_const()))
});
Expand Down Expand Up @@ -1034,7 +1031,7 @@ impl<'a, 'tcx> FindInferSourceVisitor<'a, 'tcx> {
args: GenericArgsRef<'tcx>,
qpath: &'tcx hir::QPath<'tcx>,
) -> Box<dyn Iterator<Item = InsertableGenericArgs<'tcx>> + 'a> {
let tcx = self.infcx.tcx;
let tcx = self.tecx.tcx;
match qpath {
hir::QPath::Resolved(_self_ty, path) => {
Box::new(self.resolved_path_inferred_arg_iter(path, args))
Expand Down Expand Up @@ -1107,7 +1104,7 @@ impl<'a, 'tcx> Visitor<'tcx> for FindInferSourceVisitor<'a, 'tcx> {
type NestedFilter = nested_filter::OnlyBodies;

fn nested_visit_map(&mut self) -> Self::Map {
self.infcx.tcx.hir()
self.tecx.tcx.hir()
}

fn visit_local(&mut self, local: &'tcx Local<'tcx>) {
Expand Down Expand Up @@ -1163,7 +1160,7 @@ impl<'a, 'tcx> Visitor<'tcx> for FindInferSourceVisitor<'a, 'tcx> {

#[instrument(level = "debug", skip(self))]
fn visit_expr(&mut self, expr: &'tcx Expr<'tcx>) {
let tcx = self.infcx.tcx;
let tcx = self.tecx.tcx;
match expr.kind {
// When encountering `func(arg)` first look into `arg` and then `func`,
// as `arg` is "more specific".
Expand Down Expand Up @@ -1194,7 +1191,7 @@ impl<'a, 'tcx> Visitor<'tcx> for FindInferSourceVisitor<'a, 'tcx> {
if generics.parent.is_none() && generics.has_self {
argument_index += 1;
}
let args = self.infcx.resolve_vars_if_possible(args);
let args = self.tecx.resolve_vars_if_possible(args);
let generic_args =
&generics.own_args_no_defaults(tcx, args)[generics.own_counts().lifetimes..];
let span = match expr.kind {
Expand Down Expand Up @@ -1224,7 +1221,7 @@ impl<'a, 'tcx> Visitor<'tcx> for FindInferSourceVisitor<'a, 'tcx> {
{
let output = args.as_closure().sig().output().skip_binder();
if self.generic_arg_contains_target(output.into()) {
let body = self.infcx.tcx.hir().body(body);
let body = self.tecx.tcx.hir().body(body);
let should_wrap_expr = if matches!(body.value.kind, ExprKind::Block(..)) {
None
} else {
Expand Down Expand Up @@ -1252,12 +1249,12 @@ impl<'a, 'tcx> Visitor<'tcx> for FindInferSourceVisitor<'a, 'tcx> {
&& let Some(args) = self.node_args_opt(expr.hir_id)
&& args.iter().any(|arg| self.generic_arg_contains_target(arg))
&& let Some(def_id) = self.typeck_results.type_dependent_def_id(expr.hir_id)
&& self.infcx.tcx.trait_of_item(def_id).is_some()
&& self.tecx.tcx.trait_of_item(def_id).is_some()
&& !has_impl_trait(def_id)
{
let successor =
method_args.get(0).map_or_else(|| (")", span.hi()), |arg| (", ", arg.span.lo()));
let args = self.infcx.resolve_vars_if_possible(args);
let args = self.tecx.resolve_vars_if_possible(args);
self.update_infer_source(InferSource {
span: path.ident.span,
kind: InferSourceKind::FullyQualifiedMethodCall {
Expand Down
81 changes: 81 additions & 0 deletions compiler/rustc_infer/src/infer/error_reporting/sub_relations.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
use rustc_data_structures::fx::FxHashMap;
use rustc_data_structures::undo_log::NoUndo;
use rustc_data_structures::unify as ut;
use rustc_middle::ty;

use crate::infer::InferCtxt;

#[derive(Debug, Copy, Clone, PartialEq)]
struct SubId(u32);
impl ut::UnifyKey for SubId {
type Value = ();
#[inline]
fn index(&self) -> u32 {
self.0
}
#[inline]
fn from_index(i: u32) -> SubId {
SubId(i)
}
fn tag() -> &'static str {
"SubId"
}
}

/// When reporting ambiguity errors, we sometimes want to
/// treat all inference vars which are subtypes of each
/// others as if they are equal. For this case we compute
/// the transitive closure of our subtype obligations here.
///
/// E.g. when encountering ambiguity errors, we want to suggest
/// specifying some method argument or to add a type annotation
/// to a local variable. Because subtyping cannot change the
/// shape of a type, it's fine if the cause of the ambiguity error
/// is only related to the suggested variable via subtyping.
///
/// Even for something like `let x = returns_arg(); x.method();` the
/// type of `x` is only a supertype of the argument of `returns_arg`. We
/// still want to suggest specifying the type of the argument.
#[derive(Default)]
pub struct SubRelations {
compiler-errors marked this conversation as resolved.
Show resolved Hide resolved
map: FxHashMap<ty::TyVid, SubId>,
table: ut::UnificationTableStorage<SubId>,
}

impl SubRelations {
fn get_id<'tcx>(&mut self, infcx: &InferCtxt<'tcx>, vid: ty::TyVid) -> SubId {
let root_vid = infcx.root_var(vid);
*self.map.entry(root_vid).or_insert_with(|| self.table.with_log(&mut NoUndo).new_key(()))
}

pub fn add_constraints<'tcx>(
&mut self,
infcx: &InferCtxt<'tcx>,
obls: impl IntoIterator<Item = ty::Predicate<'tcx>>,
) {
for p in obls {
let (a, b) = match p.kind().skip_binder() {
ty::PredicateKind::Subtype(ty::SubtypePredicate { a_is_expected: _, a, b }) => {
(a, b)
}
ty::PredicateKind::Coerce(ty::CoercePredicate { a, b }) => (a, b),
_ => continue,
};

match (a.kind(), b.kind()) {
(&ty::Infer(ty::TyVar(a_vid)), &ty::Infer(ty::TyVar(b_vid))) => {
let a = self.get_id(infcx, a_vid);
let b = self.get_id(infcx, b_vid);
self.table.with_log(&mut NoUndo).unify_var_var(a, b).unwrap();
}
_ => continue,
}
}
}

pub fn unified<'tcx>(&mut self, infcx: &InferCtxt<'tcx>, a: ty::TyVid, b: ty::TyVid) -> bool {
let a = self.get_id(infcx, a);
let b = self.get_id(infcx, b);
self.table.with_log(&mut NoUndo).unioned(a, b)
}
}
Loading
Loading