Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Uplift next trait solver to rustc_next_trait_solver #126614

Merged
merged 7 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions Cargo.lock
Original file line number Diff line number Diff line change
Expand Up @@ -4520,7 +4520,16 @@ dependencies = [
name = "rustc_next_trait_solver"
version = "0.0.0"
dependencies = [
"bitflags 2.5.0",
"derivative",
"rustc_ast_ir",
"rustc_data_structures",
"rustc_index",
"rustc_macros",
"rustc_serialize",
"rustc_type_ir",
"rustc_type_ir_macros",
"tracing",
]

[[package]]
Expand Down
151 changes: 6 additions & 145 deletions compiler/rustc_infer/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub use BoundRegionConversionTime::*;
pub use RegionVariableOrigin::*;
pub use SubregionOrigin::*;

use crate::infer::relate::{Relate, RelateResult};
use crate::infer::relate::RelateResult;
use crate::traits::{self, ObligationCause, ObligationInspector, PredicateObligation, TraitEngine};
use error_reporting::TypeErrCtxt;
use free_regions::RegionRelations;
Expand Down Expand Up @@ -44,7 +44,7 @@ use rustc_middle::ty::{ConstVid, EffectVid, FloatVid, IntVid, TyVid};
use rustc_middle::ty::{GenericArg, GenericArgKind, GenericArgs, GenericArgsRef};
use rustc_middle::{bug, span_bug};
use rustc_span::symbol::Symbol;
use rustc_span::{Span, DUMMY_SP};
use rustc_span::Span;
use snapshot::undo_log::InferCtxtUndoLogs;
use std::cell::{Cell, RefCell};
use std::fmt;
Expand Down Expand Up @@ -334,149 +334,6 @@ pub struct InferCtxt<'tcx> {
pub obligation_inspector: Cell<Option<ObligationInspector<'tcx>>>,
}

impl<'tcx> ty::InferCtxtLike for InferCtxt<'tcx> {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was moved to rustc_trait_selection/src/solve/infcx.rs

type Interner = TyCtxt<'tcx>;

fn interner(&self) -> TyCtxt<'tcx> {
self.tcx
}

fn universe_of_ty(&self, vid: TyVid) -> Option<ty::UniverseIndex> {
// FIXME(BoxyUwU): this is kind of jank and means that printing unresolved
// ty infers will give you the universe of the var it resolved to not the universe
// it actually had. It also means that if you have a `?0.1` and infer it to `u8` then
// try to print out `?0.1` it will just print `?0`.
match self.probe_ty_var(vid) {
Err(universe) => Some(universe),
Ok(_) => None,
}
}

fn universe_of_lt(&self, lt: ty::RegionVid) -> Option<ty::UniverseIndex> {
match self.inner.borrow_mut().unwrap_region_constraints().probe_value(lt) {
Err(universe) => Some(universe),
Ok(_) => None,
}
}

fn universe_of_ct(&self, ct: ConstVid) -> Option<ty::UniverseIndex> {
// Same issue as with `universe_of_ty`
match self.probe_const_var(ct) {
Err(universe) => Some(universe),
Ok(_) => None,
}
}

fn root_ty_var(&self, var: TyVid) -> TyVid {
self.root_var(var)
}

fn root_const_var(&self, var: ConstVid) -> ConstVid {
self.root_const_var(var)
}

fn opportunistic_resolve_ty_var(&self, vid: TyVid) -> Ty<'tcx> {
match self.probe_ty_var(vid) {
Ok(ty) => ty,
Err(_) => Ty::new_var(self.tcx, self.root_var(vid)),
}
}

fn opportunistic_resolve_int_var(&self, vid: IntVid) -> Ty<'tcx> {
self.opportunistic_resolve_int_var(vid)
}

fn opportunistic_resolve_float_var(&self, vid: FloatVid) -> Ty<'tcx> {
self.opportunistic_resolve_float_var(vid)
}

fn opportunistic_resolve_ct_var(&self, vid: ConstVid) -> ty::Const<'tcx> {
match self.probe_const_var(vid) {
Ok(ct) => ct,
Err(_) => ty::Const::new_var(self.tcx, self.root_const_var(vid)),
}
}

fn opportunistic_resolve_effect_var(&self, vid: EffectVid) -> ty::Const<'tcx> {
match self.probe_effect_var(vid) {
Some(ct) => ct,
None => {
ty::Const::new_infer(self.tcx, InferConst::EffectVar(self.root_effect_var(vid)))
}
}
}

fn opportunistic_resolve_lt_var(&self, vid: ty::RegionVid) -> ty::Region<'tcx> {
self.inner.borrow_mut().unwrap_region_constraints().opportunistic_resolve_var(self.tcx, vid)
}

fn defining_opaque_types(&self) -> &'tcx ty::List<LocalDefId> {
self.defining_opaque_types
}

fn next_ty_infer(&self) -> Ty<'tcx> {
self.next_ty_var(DUMMY_SP)
}

fn next_const_infer(&self) -> ty::Const<'tcx> {
self.next_const_var(DUMMY_SP)
}

fn fresh_args_for_item(&self, def_id: DefId) -> ty::GenericArgsRef<'tcx> {
self.fresh_args_for_item(DUMMY_SP, def_id)
}

fn instantiate_binder_with_infer<T: TypeFoldable<Self::Interner> + Copy>(
&self,
value: ty::Binder<'tcx, T>,
) -> T {
self.instantiate_binder_with_fresh_vars(
DUMMY_SP,
BoundRegionConversionTime::HigherRankedType,
value,
)
}

fn enter_forall<T: TypeFoldable<TyCtxt<'tcx>> + Copy, U>(
&self,
value: ty::Binder<'tcx, T>,
f: impl FnOnce(T) -> U,
) -> U {
self.enter_forall(value, f)
}

fn relate<T: Relate<TyCtxt<'tcx>>>(
&self,
param_env: ty::ParamEnv<'tcx>,
lhs: T,
variance: ty::Variance,
rhs: T,
) -> Result<Vec<Goal<'tcx, ty::Predicate<'tcx>>>, NoSolution> {
self.at(&ObligationCause::dummy(), param_env).relate_no_trace(lhs, variance, rhs)
}

fn eq_structurally_relating_aliases<T: Relate<TyCtxt<'tcx>>>(
&self,
param_env: ty::ParamEnv<'tcx>,
lhs: T,
rhs: T,
) -> Result<Vec<Goal<'tcx, ty::Predicate<'tcx>>>, NoSolution> {
self.at(&ObligationCause::dummy(), param_env)
.eq_structurally_relating_aliases_no_trace(lhs, rhs)
}

fn resolve_vars_if_possible<T>(&self, value: T) -> T
where
T: TypeFoldable<TyCtxt<'tcx>>,
{
self.resolve_vars_if_possible(value)
}

fn probe<T>(&self, probe: impl FnOnce() -> T) -> T {
self.probe(|_| probe())
}
}

/// See the `error_reporting` module for more details.
#[derive(Clone, Copy, Debug, PartialEq, Eq, TypeFoldable, TypeVisitable)]
pub enum ValuePairs<'tcx> {
Expand Down Expand Up @@ -830,6 +687,10 @@ impl<'tcx> InferCtxt<'tcx> {
self.tcx.dcx()
}

pub fn defining_opaque_types(&self) -> &'tcx ty::List<LocalDefId> {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since defining_opaque_types is a private field on InferCtxt, and I've chosen to expose only one trait (SolverDelegate) and not two (InferCtxtLike), we gotta make it possible to reference this in the impl in rustc_trait_solver.

self.defining_opaque_types
}

pub fn next_trait_solver(&self) -> bool {
self.next_trait_solver
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/arena.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ macro_rules! arena_types {
rustc_middle::ty::EarlyBinder<'tcx, rustc_middle::ty::Ty<'tcx>>
>,
[] external_constraints: rustc_middle::traits::solve::ExternalConstraintsData<rustc_middle::ty::TyCtxt<'tcx>>,
[] predefined_opaques_in_body: rustc_middle::traits::solve::PredefinedOpaquesData<'tcx>,
[] predefined_opaques_in_body: rustc_middle::traits::solve::PredefinedOpaquesData<rustc_middle::ty::TyCtxt<'tcx>>,
[decode] doc_link_resolutions: rustc_hir::def::DocLinkResMap,
[] stripped_cfg_items: rustc_ast::expand::StrippedCfgItem,
[] mod_child: rustc_middle::metadata::ModChild,
Expand Down
14 changes: 4 additions & 10 deletions compiler/rustc_middle/src/traits/solve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ use rustc_type_ir as ir;
pub use rustc_type_ir::solve::*;

use crate::ty::{
self, FallibleTypeFolder, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeVisitable, TypeVisitor,
self, FallibleTypeFolder, TyCtxt, TypeFoldable, TypeFolder, TypeVisitable, TypeVisitor,
};

mod cache;

pub use cache::{CacheData, EvaluationCache};
pub use cache::EvaluationCache;

pub type Goal<'tcx, P> = ir::solve::Goal<TyCtxt<'tcx>, P>;
pub type QueryInput<'tcx, P> = ir::solve::QueryInput<TyCtxt<'tcx>, P>;
Expand All @@ -19,17 +19,11 @@ pub type CandidateSource<'tcx> = ir::solve::CandidateSource<TyCtxt<'tcx>>;
pub type CanonicalInput<'tcx, P = ty::Predicate<'tcx>> = ir::solve::CanonicalInput<TyCtxt<'tcx>, P>;
pub type CanonicalResponse<'tcx> = ir::solve::CanonicalResponse<TyCtxt<'tcx>>;

/// Additional constraints returned on success.
#[derive(Debug, PartialEq, Eq, Clone, Hash, HashStable, Default)]
pub struct PredefinedOpaquesData<'tcx> {
pub opaque_types: Vec<(ty::OpaqueTypeKey<'tcx>, Ty<'tcx>)>,
}

#[derive(Debug, PartialEq, Eq, Copy, Clone, Hash, HashStable)]
pub struct PredefinedOpaques<'tcx>(pub(crate) Interned<'tcx, PredefinedOpaquesData<'tcx>>);
pub struct PredefinedOpaques<'tcx>(pub(crate) Interned<'tcx, PredefinedOpaquesData<TyCtxt<'tcx>>>);

impl<'tcx> std::ops::Deref for PredefinedOpaques<'tcx> {
type Target = PredefinedOpaquesData<'tcx>;
type Target = PredefinedOpaquesData<TyCtxt<'tcx>>;

fn deref(&self) -> &Self::Target {
&self.0
Expand Down
28 changes: 11 additions & 17 deletions compiler/rustc_middle/src/traits/solve/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use rustc_data_structures::sync::Lock;
use rustc_query_system::cache::WithDepNode;
use rustc_query_system::dep_graph::DepNodeIndex;
use rustc_session::Limit;
use rustc_type_ir::solve::CacheData;

/// The trait solver cache used by `-Znext-solver`.
///
/// FIXME(@lcnr): link to some official documentation of how
Expand All @@ -14,17 +16,9 @@ pub struct EvaluationCache<'tcx> {
map: Lock<FxHashMap<CanonicalInput<'tcx>, CacheEntry<'tcx>>>,
}

#[derive(Debug, PartialEq, Eq)]
pub struct CacheData<'tcx> {
pub result: QueryResult<'tcx>,
pub proof_tree: Option<&'tcx inspect::CanonicalGoalEvaluationStep<TyCtxt<'tcx>>>,
pub additional_depth: usize,
pub encountered_overflow: bool,
}

impl<'tcx> EvaluationCache<'tcx> {
impl<'tcx> rustc_type_ir::inherent::EvaluationCache<TyCtxt<'tcx>> for &'tcx EvaluationCache<'tcx> {
/// Insert a final result into the global cache.
pub fn insert(
fn insert(
&self,
tcx: TyCtxt<'tcx>,
key: CanonicalInput<'tcx>,
Expand All @@ -48,7 +42,7 @@ impl<'tcx> EvaluationCache<'tcx> {
if cfg!(debug_assertions) {
drop(map);
let expected = CacheData { result, proof_tree, additional_depth, encountered_overflow };
let actual = self.get(tcx, key, [], Limit(additional_depth));
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Limit is defined in rustc_session, and I was kinda too lazy to find a good ancestor crate to put Limit in...

Could fix it as a follow-up or fix it in this PR if necessary.

let actual = self.get(tcx, key, [], additional_depth);
if !actual.as_ref().is_some_and(|actual| expected == *actual) {
bug!("failed to lookup inserted element for {key:?}: {expected:?} != {actual:?}");
}
Expand All @@ -59,13 +53,13 @@ impl<'tcx> EvaluationCache<'tcx> {
/// and handling root goals of coinductive cycles.
///
/// If this returns `Some` the cache result can be used.
pub fn get(
fn get(
&self,
tcx: TyCtxt<'tcx>,
key: CanonicalInput<'tcx>,
stack_entries: impl IntoIterator<Item = CanonicalInput<'tcx>>,
available_depth: Limit,
) -> Option<CacheData<'tcx>> {
available_depth: usize,
) -> Option<CacheData<TyCtxt<'tcx>>> {
let map = self.map.borrow();
let entry = map.get(&key)?;

Expand All @@ -76,7 +70,7 @@ impl<'tcx> EvaluationCache<'tcx> {
}

if let Some(ref success) = entry.success {
if available_depth.value_within_limit(success.additional_depth) {
if Limit(available_depth).value_within_limit(success.additional_depth) {
let QueryData { result, proof_tree } = success.data.get(tcx);
return Some(CacheData {
result,
Expand All @@ -87,12 +81,12 @@ impl<'tcx> EvaluationCache<'tcx> {
}
}

entry.with_overflow.get(&available_depth.0).map(|e| {
entry.with_overflow.get(&available_depth).map(|e| {
let QueryData { result, proof_tree } = e.get(tcx);
CacheData {
result,
proof_tree,
additional_depth: available_depth.0,
additional_depth: available_depth,
encountered_overflow: true,
}
})
Expand Down
10 changes: 9 additions & 1 deletion compiler/rustc_middle/src/ty/adt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,22 @@ impl<'tcx> rustc_type_ir::inherent::AdtDef<TyCtxt<'tcx>> for AdtDef<'tcx> {
self.did()
}

fn is_struct(self) -> bool {
self.is_struct()
}

fn struct_tail_ty(self, interner: TyCtxt<'tcx>) -> Option<ty::EarlyBinder<'tcx, Ty<'tcx>>> {
Some(interner.type_of(self.non_enum_variant().tail_opt()?.did))
}

fn is_phantom_data(self) -> bool {
self.is_phantom_data()
}

fn all_field_tys(
self,
tcx: TyCtxt<'tcx>,
) -> ty::EarlyBinder<'tcx, impl Iterator<Item = Ty<'tcx>>> {
) -> ty::EarlyBinder<'tcx, impl IntoIterator<Item = Ty<'tcx>>> {
ty::EarlyBinder::bind(
self.all_fields().map(move |field| tcx.type_of(field.did).skip_binder()),
)
Expand Down
6 changes: 5 additions & 1 deletion compiler/rustc_middle/src/ty/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ mod valtree;

pub use int::*;
pub use kind::*;
use rustc_span::Span;
use rustc_span::DUMMY_SP;
use rustc_span::{ErrorGuaranteed, Span};
pub use valtree::*;

pub type ConstKind<'tcx> = ir::ConstKind<TyCtxt<'tcx>>;
Expand Down Expand Up @@ -176,6 +176,10 @@ impl<'tcx> rustc_type_ir::inherent::Const<TyCtxt<'tcx>> for Const<'tcx> {
fn new_expr(interner: TyCtxt<'tcx>, expr: ty::Expr<'tcx>) -> Self {
Const::new_expr(interner, expr)
}

fn new_error(interner: TyCtxt<'tcx>, guar: ErrorGuaranteed) -> Self {
Const::new_error(interner, guar)
}
}

impl<'tcx> Const<'tcx> {
Expand Down
Loading
Loading