From d7922fbbda1f32f905d9372028520f869394ba0d Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Fri, 7 Jul 2023 20:42:37 +0000 Subject: [PATCH] Structurally normalize in selection --- .../src/solve/eval_ctxt/select.rs | 64 +++++++++++++------ .../trait-upcast-lhs-needs-normalization.rs | 18 ++++++ 2 files changed, 63 insertions(+), 19 deletions(-) create mode 100644 tests/ui/traits/new-solver/trait-upcast-lhs-needs-normalization.rs diff --git a/compiler/rustc_trait_selection/src/solve/eval_ctxt/select.rs b/compiler/rustc_trait_selection/src/solve/eval_ctxt/select.rs index 141369b03370b..086244785d60a 100644 --- a/compiler/rustc_trait_selection/src/solve/eval_ctxt/select.rs +++ b/compiler/rustc_trait_selection/src/solve/eval_ctxt/select.rs @@ -5,7 +5,7 @@ use rustc_hir::def_id::DefId; use rustc_infer::infer::{DefineOpaqueTypes, InferCtxt, InferOk}; use rustc_infer::traits::util::supertraits; use rustc_infer::traits::{ - Obligation, PolyTraitObligation, PredicateObligation, Selection, SelectionResult, + Obligation, PolyTraitObligation, PredicateObligation, Selection, SelectionResult, TraitEngine, }; use rustc_middle::traits::solve::{CanonicalInput, Certainty, Goal}; use rustc_middle::traits::{ @@ -20,6 +20,8 @@ use crate::solve::eval_ctxt::{EvalCtxt, GenerateProofTree}; use crate::solve::inspect::ProofTreeBuilder; use crate::solve::search_graph::OverflowHandler; use crate::traits::vtable::{count_own_vtable_entries, prepare_vtable_segments, VtblSegment}; +use crate::traits::StructurallyNormalizeExt; +use crate::traits::TraitEngineExt; pub trait InferCtxtSelectExt<'tcx> { fn select_in_new_trait_solver( @@ -227,25 +229,30 @@ fn rematch_object<'tcx>( goal: Goal<'tcx, ty::TraitPredicate<'tcx>>, mut nested: Vec>, ) -> SelectionResult<'tcx, Selection<'tcx>> { - let self_ty = goal.predicate.self_ty(); - let ty::Dynamic(data, _, source_kind) = *self_ty.kind() else { bug!() }; - let source_trait_ref = data.principal().unwrap().with_self_ty(infcx.tcx, self_ty); - - let (is_upcasting, target_trait_ref_unnormalized) = if Some(goal.predicate.def_id()) - == infcx.tcx.lang_items().unsize_trait() - { - assert_eq!(source_kind, ty::Dyn, "cannot upcast dyn*"); - if let ty::Dynamic(data, _, ty::Dyn) = goal.predicate.trait_ref.substs.type_at(1).kind() { - // FIXME: We also need to ensure that the source lifetime outlives the - // target lifetime. This doesn't matter for codegen, though, and only - // *really* matters if the goal's certainty is ambiguous. - (true, data.principal().unwrap().with_self_ty(infcx.tcx, self_ty)) + let a_ty = structurally_normalize(goal.predicate.self_ty(), infcx, goal.param_env, &mut nested); + let ty::Dynamic(data, _, source_kind) = *a_ty.kind() else { bug!() }; + let source_trait_ref = data.principal().unwrap().with_self_ty(infcx.tcx, a_ty); + + let (is_upcasting, target_trait_ref_unnormalized) = + if Some(goal.predicate.def_id()) == infcx.tcx.lang_items().unsize_trait() { + assert_eq!(source_kind, ty::Dyn, "cannot upcast dyn*"); + let b_ty = structurally_normalize( + goal.predicate.trait_ref.substs.type_at(1), + infcx, + goal.param_env, + &mut nested, + ); + if let ty::Dynamic(data, _, ty::Dyn) = *b_ty.kind() { + // FIXME: We also need to ensure that the source lifetime outlives the + // target lifetime. This doesn't matter for codegen, though, and only + // *really* matters if the goal's certainty is ambiguous. + (true, data.principal().unwrap().with_self_ty(infcx.tcx, a_ty)) + } else { + bug!() + } } else { - bug!() - } - } else { - (false, ty::Binder::dummy(goal.predicate.trait_ref)) - }; + (false, ty::Binder::dummy(goal.predicate.trait_ref)) + }; let mut target_trait_ref = None; for candidate_trait_ref in supertraits(infcx.tcx, source_trait_ref) { @@ -445,3 +452,22 @@ fn rematch_unsize<'tcx>( Ok(Some(ImplSource::Builtin(nested))) } + +fn structurally_normalize<'tcx>( + ty: Ty<'tcx>, + infcx: &InferCtxt<'tcx>, + param_env: ty::ParamEnv<'tcx>, + nested: &mut Vec>, +) -> Ty<'tcx> { + if matches!(ty.kind(), ty::Alias(..)) { + let mut engine = >::new(infcx); + let normalized_ty = infcx + .at(&ObligationCause::dummy(), param_env) + .structurally_normalize(ty, &mut *engine) + .expect("normalization shouldn't fail if we got to here"); + nested.extend(engine.pending_obligations()); + normalized_ty + } else { + ty + } +} diff --git a/tests/ui/traits/new-solver/trait-upcast-lhs-needs-normalization.rs b/tests/ui/traits/new-solver/trait-upcast-lhs-needs-normalization.rs new file mode 100644 index 0000000000000..79114b93b78df --- /dev/null +++ b/tests/ui/traits/new-solver/trait-upcast-lhs-needs-normalization.rs @@ -0,0 +1,18 @@ +// check-pass +// compile-flags: -Ztrait-solver=next + +pub trait A {} +pub trait B: A {} + +pub trait Mirror { + type Assoc: ?Sized; +} +impl Mirror for T { + type Assoc = T; +} + +pub fn foo<'a>(x: &'a ::Assoc) -> &'a (dyn A + 'static) { + x +} + +fn main() {}