Skip to content

Commit 8d7707f

Browse files
committed
Normalize associated types with bound vars
1 parent b03ccac commit 8d7707f

File tree

53 files changed

+716
-358
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+716
-358
lines changed

compiler/rustc_middle/src/ty/layout.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -2474,10 +2474,9 @@ impl<'tcx> ty::Instance<'tcx> {
24742474
// `src/test/ui/polymorphization/normalized_sig_types.rs`), and codegen not keeping
24752475
// track of a polymorphization `ParamEnv` to allow normalizing later.
24762476
let mut sig = match *ty.kind() {
2477-
ty::FnDef(def_id, substs) if tcx.sess.opts.debugging_opts.polymorphize => tcx
2477+
ty::FnDef(def_id, substs) => tcx
24782478
.normalize_erasing_regions(tcx.param_env(def_id), tcx.fn_sig(def_id))
24792479
.subst(tcx, substs),
2480-
ty::FnDef(def_id, substs) => tcx.fn_sig(def_id).subst(tcx, substs),
24812480
_ => unreachable!(),
24822481
};
24832482

compiler/rustc_mir/src/borrow_check/type_check/input_output.rs

+30-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
1010
use rustc_infer::infer::LateBoundRegionConversionTime;
1111
use rustc_middle::mir::*;
12-
use rustc_middle::ty::Ty;
12+
use rustc_middle::traits::ObligationCause;
13+
use rustc_middle::ty::{self, Ty};
14+
use rustc_trait_selection::traits::query::normalize::AtExt;
1315

1416
use rustc_index::vec::Idx;
1517
use rustc_span::Span;
@@ -80,6 +82,33 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
8082
let local = Local::new(argument_index + 1);
8183

8284
let mir_input_ty = body.local_decls[local].ty;
85+
// FIXME(jackh726): This is a hack. It's somewhat like
86+
// `rustc_traits::normalize_after_erasing_regions`. Ideally, we'd
87+
// like to normalize *before* inserting into `local_decls`, but I
88+
// couldn't figure out where the heck that was.
89+
let mir_input_ty = match self
90+
.infcx
91+
.at(&ObligationCause::dummy(), ty::ParamEnv::empty())
92+
.normalize(mir_input_ty)
93+
{
94+
Ok(n) => {
95+
debug!("equate_inputs_and_outputs: {:?}", n);
96+
if n.obligations.iter().all(|o| {
97+
matches!(
98+
o.predicate.kind().skip_binder(),
99+
ty::PredicateKind::RegionOutlives(_)
100+
)
101+
}) {
102+
n.value
103+
} else {
104+
mir_input_ty
105+
}
106+
}
107+
Err(_) => {
108+
debug!("equate_inputs_and_outputs: NoSolution");
109+
mir_input_ty
110+
}
111+
};
83112
let mir_input_span = body.local_decls[local].source_info.span;
84113
self.equate_normalized_input_or_output(
85114
normalized_input_ty,

compiler/rustc_mir/src/borrow_check/type_check/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1053,6 +1053,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
10531053
);
10541054
for user_annotation in self.user_type_annotations {
10551055
let CanonicalUserTypeAnnotation { span, ref user_ty, inferred_ty } = *user_annotation;
1056+
let inferred_ty = self.normalize(inferred_ty, Locations::All(span));
10561057
let annotation = self.instantiate_canonical_with_fresh_inference_vars(span, user_ty);
10571058
match annotation {
10581059
UserType::Ty(mut ty) => {

compiler/rustc_trait_selection/src/traits/project.rs

+32-39
Original file line numberDiff line numberDiff line change
@@ -362,25 +362,25 @@ impl<'a, 'b, 'tcx> TypeFolder<'tcx> for AssocTypeNormalizer<'a, 'b, 'tcx> {
362362
if !needs_normalization(&ty, self.param_env.reveal()) {
363363
return ty;
364364
}
365-
// We don't want to normalize associated types that occur inside of region
366-
// binders, because they may contain bound regions, and we can't cope with that.
367-
//
368-
// Example:
369-
//
370-
// for<'a> fn(<T as Foo<&'a>>::A)
371-
//
372-
// Instead of normalizing `<T as Foo<&'a>>::A` here, we'll
373-
// normalize it when we instantiate those bound regions (which
374-
// should occur eventually).
375-
376-
let ty = ty.super_fold_with(self);
365+
366+
// N.b. while we want to call `super_fold_with(self)` on `ty` before
367+
// normalization, we wait until we know whether we need to normalize the
368+
// current type. If we do, then we only fold the ty *after* replacing bound
369+
// vars with placeholders. This means that nested types don't need to replace
370+
// bound vars at the current binder level or above. A key assumption here is
371+
// that folding the type can't introduce new bound vars.
372+
377373
match *ty.kind() {
378-
ty::Opaque(def_id, substs) if !substs.has_escaping_bound_vars() => {
374+
ty::Opaque(def_id, substs) => {
379375
// Only normalize `impl Trait` after type-checking, usually in codegen.
380376
match self.param_env.reveal() {
381-
Reveal::UserFacing => ty,
377+
Reveal::UserFacing => ty.super_fold_with(self),
382378

383379
Reveal::All => {
380+
// N.b. there is an assumption here all this code can handle
381+
// escaping bound vars.
382+
383+
let substs = substs.super_fold_with(self);
384384
let recursion_limit = self.tcx().recursion_limit();
385385
if !recursion_limit.value_within_limit(self.depth) {
386386
let obligation = Obligation::with_depth(
@@ -403,18 +403,13 @@ impl<'a, 'b, 'tcx> TypeFolder<'tcx> for AssocTypeNormalizer<'a, 'b, 'tcx> {
403403
}
404404

405405
ty::Projection(data) if !data.has_escaping_bound_vars() => {
406-
// This is kind of hacky -- we need to be able to
407-
// handle normalization within binders because
408-
// otherwise we wind up a need to normalize when doing
409-
// trait matching (since you can have a trait
410-
// obligation like `for<'a> T::B: Fn(&'a i32)`), but
411-
// we can't normalize with bound regions in scope. So
412-
// far now we just ignore binders but only normalize
413-
// if all bound regions are gone (and then we still
414-
// have to renormalize whenever we instantiate a
415-
// binder). It would be better to normalize in a
416-
// binding-aware fashion.
406+
// This branch is *mostly* just an optimization: when we don't
407+
// have escaping bound vars, we don't need to replace them with
408+
// placeholders (see branch below). *Also*, we know that we can
409+
// register an obligation to *later* project, since we know
410+
// there won't be bound vars there.
417411

412+
let data = data.super_fold_with(self);
418413
let normalized_ty = normalize_projection_type(
419414
self.selcx,
420415
self.param_env,
@@ -433,22 +428,19 @@ impl<'a, 'b, 'tcx> TypeFolder<'tcx> for AssocTypeNormalizer<'a, 'b, 'tcx> {
433428
normalized_ty
434429
}
435430

436-
ty::Projection(data) if !data.trait_ref(self.tcx()).has_escaping_bound_vars() => {
437-
// Okay, so you thought the previous branch was hacky. Well, to
438-
// extend upon this, when the *trait ref* doesn't have escaping
439-
// bound vars, but the associated item *does* (can only occur
440-
// with GATs), then we might still be able to project the type.
441-
// For this, we temporarily replace the bound vars with
442-
// placeholders. Note though, that in the case that we still
443-
// can't project for whatever reason (e.g. self type isn't
444-
// known enough), we *can't* register an obligation and return
445-
// an inference variable (since then that obligation would have
446-
// bound vars and that's a can of worms). Instead, we just
447-
// give up and fall back to pretending like we never tried!
431+
ty::Projection(data) => {
432+
// If there are escaping bound vars, we temporarily replace the
433+
// bound vars with placeholders. Note though, that in the cas
434+
// that we still can't project for whatever reason (e.g. self
435+
// type isn't known enough), we *can't* register an obligation
436+
// and return an inference variable (since then that obligation
437+
// would have bound vars and that's a can of worms). Instead,
438+
// we just give up and fall back to pretending like we never tried!
448439

449440
let infcx = self.selcx.infcx();
450441
let (data, mapped_regions, mapped_types, mapped_consts) =
451442
BoundVarReplacer::replace_bound_vars(infcx, &mut self.universes, data);
443+
let data = data.super_fold_with(self);
452444
let normalized_ty = opt_normalize_projection_type(
453445
self.selcx,
454446
self.param_env,
@@ -459,7 +451,7 @@ impl<'a, 'b, 'tcx> TypeFolder<'tcx> for AssocTypeNormalizer<'a, 'b, 'tcx> {
459451
)
460452
.ok()
461453
.flatten()
462-
.unwrap_or_else(|| ty);
454+
.unwrap_or_else(|| ty.super_fold_with(self));
463455

464456
let normalized_ty = PlaceholderReplacer::replace_placeholders(
465457
infcx,
@@ -479,7 +471,7 @@ impl<'a, 'b, 'tcx> TypeFolder<'tcx> for AssocTypeNormalizer<'a, 'b, 'tcx> {
479471
normalized_ty
480472
}
481473

482-
_ => ty,
474+
_ => ty.super_fold_with(self),
483475
}
484476
}
485477

@@ -908,6 +900,7 @@ fn opt_normalize_projection_type<'a, 'b, 'tcx>(
908900
// an impl, where-clause etc) and hence we must
909901
// re-normalize it
910902

903+
let projected_ty = selcx.infcx().resolve_vars_if_possible(projected_ty);
911904
debug!(?projected_ty, ?depth, ?projected_obligations);
912905

913906
let result = if projected_ty.has_projections() {

compiler/rustc_trait_selection/src/traits/query/normalize.rs

+83-31
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ use rustc_infer::traits::Normalized;
1414
use rustc_middle::mir;
1515
use rustc_middle::ty::fold::{TypeFoldable, TypeFolder};
1616
use rustc_middle::ty::subst::Subst;
17-
use rustc_middle::ty::{self, Ty, TyCtxt};
17+
use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitor};
18+
19+
use std::ops::ControlFlow;
1820

1921
use super::NoSolution;
2022

@@ -65,6 +67,14 @@ impl<'cx, 'tcx> AtExt<'tcx> for At<'cx, 'tcx> {
6567
universes: vec![],
6668
};
6769

70+
if value.has_escaping_bound_vars() {
71+
let mut max_visitor =
72+
MaxEscapingBoundVarVisitor { outer_index: ty::INNERMOST, escaping: 0 };
73+
value.visit_with(&mut max_visitor);
74+
if max_visitor.escaping > 0 {
75+
normalizer.universes.extend((0..max_visitor.escaping).map(|_| None));
76+
}
77+
}
6878
let result = value.fold_with(&mut normalizer);
6979
info!(
7080
"normalize::<{}>: result={:?} with {} obligations",
@@ -85,6 +95,58 @@ impl<'cx, 'tcx> AtExt<'tcx> for At<'cx, 'tcx> {
8595
}
8696
}
8797

98+
/// Visitor to find the maximum escaping bound var
99+
struct MaxEscapingBoundVarVisitor {
100+
// The index which would count as escaping
101+
outer_index: ty::DebruijnIndex,
102+
escaping: usize,
103+
}
104+
105+
impl<'tcx> TypeVisitor<'tcx> for MaxEscapingBoundVarVisitor {
106+
fn visit_binder<T: TypeFoldable<'tcx>>(
107+
&mut self,
108+
t: &ty::Binder<'tcx, T>,
109+
) -> ControlFlow<Self::BreakTy> {
110+
self.outer_index.shift_in(1);
111+
let result = t.super_visit_with(self);
112+
self.outer_index.shift_out(1);
113+
result
114+
}
115+
116+
#[inline]
117+
fn visit_ty(&mut self, t: Ty<'tcx>) -> ControlFlow<Self::BreakTy> {
118+
if t.outer_exclusive_binder() > self.outer_index {
119+
self.escaping = self
120+
.escaping
121+
.max(t.outer_exclusive_binder().as_usize() - self.outer_index.as_usize());
122+
}
123+
ControlFlow::CONTINUE
124+
}
125+
126+
#[inline]
127+
fn visit_region(&mut self, r: ty::Region<'tcx>) -> ControlFlow<Self::BreakTy> {
128+
match *r {
129+
ty::ReLateBound(debruijn, _) if debruijn > self.outer_index => {
130+
self.escaping =
131+
self.escaping.max(debruijn.as_usize() - self.outer_index.as_usize());
132+
}
133+
_ => {}
134+
}
135+
ControlFlow::CONTINUE
136+
}
137+
138+
fn visit_const(&mut self, ct: &'tcx ty::Const<'tcx>) -> ControlFlow<Self::BreakTy> {
139+
match ct.val {
140+
ty::ConstKind::Bound(debruijn, _) if debruijn >= self.outer_index => {
141+
self.escaping =
142+
self.escaping.max(debruijn.as_usize() - self.outer_index.as_usize());
143+
ControlFlow::CONTINUE
144+
}
145+
_ => ct.super_visit_with(self),
146+
}
147+
}
148+
}
149+
88150
struct QueryNormalizer<'cx, 'tcx> {
89151
infcx: &'cx InferCtxt<'cx, 'tcx>,
90152
cause: &'cx ObligationCause<'tcx>,
@@ -121,14 +183,25 @@ impl<'cx, 'tcx> TypeFolder<'tcx> for QueryNormalizer<'cx, 'tcx> {
121183
return ty;
122184
}
123185

124-
let ty = ty.super_fold_with(self);
186+
// N.b. while we want to call `super_fold_with(self)` on `ty` before
187+
// normalization, we wait until we know whether we need to normalize the
188+
// current type. If we do, then we only fold the ty *after* replacing bound
189+
// vars with placeholders. This means that nested types don't need to replace
190+
// bound vars at the current binder level or above. A key assumption here is
191+
// that folding the type can't introduce new bound vars.
192+
193+
// Wrap this in a closure so we don't accidentally return from the outer function
125194
let res = (|| match *ty.kind() {
126-
ty::Opaque(def_id, substs) if !substs.has_escaping_bound_vars() => {
195+
ty::Opaque(def_id, substs) => {
127196
// Only normalize `impl Trait` after type-checking, usually in codegen.
128197
match self.param_env.reveal() {
129-
Reveal::UserFacing => ty,
198+
Reveal::UserFacing => ty.super_fold_with(self),
130199

131200
Reveal::All => {
201+
// N.b. there is an assumption here all this code can handle
202+
// escaping bound vars.
203+
204+
let substs = substs.super_fold_with(self);
132205
let recursion_limit = self.tcx().recursion_limit();
133206
if !recursion_limit.value_within_limit(self.anon_depth) {
134207
let obligation = Obligation::with_depth(
@@ -161,19 +234,11 @@ impl<'cx, 'tcx> TypeFolder<'tcx> for QueryNormalizer<'cx, 'tcx> {
161234
}
162235

163236
ty::Projection(data) if !data.has_escaping_bound_vars() => {
164-
// This is kind of hacky -- we need to be able to
165-
// handle normalization within binders because
166-
// otherwise we wind up a need to normalize when doing
167-
// trait matching (since you can have a trait
168-
// obligation like `for<'a> T::B: Fn(&'a i32)`), but
169-
// we can't normalize with bound regions in scope. So
170-
// far now we just ignore binders but only normalize
171-
// if all bound regions are gone (and then we still
172-
// have to renormalize whenever we instantiate a
173-
// binder). It would be better to normalize in a
174-
// binding-aware fashion.
237+
// This branch is just an optimization: when we don't have escaping bound vars,
238+
// we don't need to replace them with placeholders (see branch below).
175239

176240
let tcx = self.infcx.tcx;
241+
let data = data.super_fold_with(self);
177242

178243
let mut orig_values = OriginalQueryValues::default();
179244
// HACK(matthewjasper) `'static` is special-cased in selection,
@@ -217,22 +282,9 @@ impl<'cx, 'tcx> TypeFolder<'tcx> for QueryNormalizer<'cx, 'tcx> {
217282
}
218283
}
219284
}
220-
ty::Projection(data) if !data.trait_ref(self.infcx.tcx).has_escaping_bound_vars() => {
221-
// See note in `rustc_trait_selection::traits::project`
222-
223-
// One other point mentioning: In `traits::project`, if a
224-
// projection can't be normalized, we return an inference variable
225-
// and register an obligation to later resolve that. Here, the query
226-
// will just return ambiguity. In both cases, the effect is the same: we only want
227-
// to return `ty` because there are bound vars that we aren't yet handling in a more
228-
// complete way.
229285

230-
// `BoundVarReplacer` can't handle escaping bound vars. Ideally, we want this before even calling
231-
// `QueryNormalizer`, but some const-generics tests pass escaping bound vars.
232-
// Also, use `ty` so we get that sweet `outer_exclusive_binder` optimization
233-
assert!(!ty.has_vars_bound_at_or_above(ty::DebruijnIndex::from_usize(
234-
self.universes.len()
235-
)));
286+
ty::Projection(data) => {
287+
// See note in `rustc_trait_selection::traits::project`
236288

237289
let tcx = self.infcx.tcx;
238290
let infcx = self.infcx;
@@ -292,7 +344,7 @@ impl<'cx, 'tcx> TypeFolder<'tcx> for QueryNormalizer<'cx, 'tcx> {
292344
)
293345
}
294346

295-
_ => ty,
347+
_ => ty.super_fold_with(self),
296348
})();
297349
self.cache.insert(ty, res);
298350
res

compiler/rustc_typeck/src/check/coercion.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,8 @@ impl<'f, 'tcx> Coerce<'f, 'tcx> {
796796
//! into a closure or a `proc`.
797797
798798
let b = self.shallow_resolve(b);
799+
let InferOk { value: b, mut obligations } =
800+
self.normalize_associated_types_in_as_infer_ok(self.cause.span, b);
799801
debug!("coerce_from_fn_item(a={:?}, b={:?})", a, b);
800802

801803
match b.kind() {
@@ -815,8 +817,9 @@ impl<'f, 'tcx> Coerce<'f, 'tcx> {
815817
}
816818
}
817819

818-
let InferOk { value: a_sig, mut obligations } =
820+
let InferOk { value: a_sig, obligations: o1 } =
819821
self.normalize_associated_types_in_as_infer_ok(self.cause.span, a_sig);
822+
obligations.extend(o1);
820823

821824
let a_fn_pointer = self.tcx.mk_fn_ptr(a_sig);
822825
let InferOk { value, obligations: o2 } = self.coerce_from_safe_fn(

src/test/ui/associated-type-bounds/issue-83017.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
// check-pass
2+
13
#![feature(associated_type_bounds)]
24

35
trait TraitA<'a> {
@@ -34,6 +36,4 @@ where
3436

3537
fn main() {
3638
foo::<Z>();
37-
//~^ ERROR: the trait bound `for<'a, 'b> <Z as TraitA<'a>>::AsA: TraitB<'a, 'b>` is not satisfied
38-
//~| ERROR: the trait bound `for<'a, 'b, 'c> <<Z as TraitA<'a>>::AsA as TraitB<'a, 'b>>::AsB: TraitC<'a, 'b, 'c>` is not satisfied
3939
}

0 commit comments

Comments
 (0)