Skip to content

Commit a7c9e79

Browse files
author
Lukas Markeffsky
committed
interpret: adjust vtable validity check for higher-ranked types
1 parent 1543bb4 commit a7c9e79

File tree

6 files changed

+88
-56
lines changed

6 files changed

+88
-56
lines changed

compiler/rustc_const_eval/src/interpret/cast.rs

+6-4
Original file line numberDiff line numberDiff line change
@@ -430,10 +430,12 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
430430
};
431431
let erased_trait_ref =
432432
ty::ExistentialTraitRef::erase_self_ty(*self.tcx, upcast_trait_ref);
433-
assert!(data_b.principal().is_some_and(|b| self.eq_in_param_env(
434-
erased_trait_ref,
435-
self.tcx.instantiate_bound_regions_with_erased(b)
436-
)));
433+
assert_eq!(
434+
data_b.principal().map(|b| {
435+
self.tcx.normalize_erasing_late_bound_regions(self.typing_env, b)
436+
}),
437+
Some(erased_trait_ref),
438+
);
437439
} else {
438440
// In this case codegen would keep using the old vtable. We don't want to do
439441
// that as it has the wrong trait. The reason codegen can do this is that

compiler/rustc_const_eval/src/interpret/eval_context.rs

+1-39
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@ use either::{Left, Right};
44
use rustc_abi::{Align, HasDataLayout, Size, TargetDataLayout};
55
use rustc_errors::DiagCtxtHandle;
66
use rustc_hir::def_id::DefId;
7-
use rustc_infer::infer::TyCtxtInferExt;
8-
use rustc_infer::infer::at::ToTrace;
9-
use rustc_infer::traits::ObligationCause;
107
use rustc_middle::mir::interpret::{ErrorHandled, InvalidMetaKind, ReportedErrorInfo};
118
use rustc_middle::query::TyCtxtAt;
129
use rustc_middle::ty::layout::{
@@ -17,8 +14,7 @@ use rustc_middle::{mir, span_bug};
1714
use rustc_session::Limit;
1815
use rustc_span::Span;
1916
use rustc_target::callconv::FnAbi;
20-
use rustc_trait_selection::traits::ObligationCtxt;
21-
use tracing::{debug, instrument, trace};
17+
use tracing::{debug, trace};
2218

2319
use super::{
2420
Frame, FrameInfo, GlobalId, InterpErrorInfo, InterpErrorKind, InterpResult, MPlaceTy, Machine,
@@ -323,40 +319,6 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
323319
}
324320
}
325321

326-
/// Check if the two things are equal in the current param_env, using an infcx to get proper
327-
/// equality checks.
328-
#[instrument(level = "trace", skip(self), ret)]
329-
pub(super) fn eq_in_param_env<T>(&self, a: T, b: T) -> bool
330-
where
331-
T: PartialEq + TypeFoldable<TyCtxt<'tcx>> + ToTrace<'tcx>,
332-
{
333-
// Fast path: compare directly.
334-
if a == b {
335-
return true;
336-
}
337-
// Slow path: spin up an inference context to check if these traits are sufficiently equal.
338-
let (infcx, param_env) = self.tcx.infer_ctxt().build_with_typing_env(self.typing_env);
339-
let ocx = ObligationCtxt::new(&infcx);
340-
let cause = ObligationCause::dummy_with_span(self.cur_span());
341-
// equate the two trait refs after normalization
342-
let a = ocx.normalize(&cause, param_env, a);
343-
let b = ocx.normalize(&cause, param_env, b);
344-
345-
if let Err(terr) = ocx.eq(&cause, param_env, a, b) {
346-
trace!(?terr);
347-
return false;
348-
}
349-
350-
let errors = ocx.select_all_or_error();
351-
if !errors.is_empty() {
352-
trace!(?errors);
353-
return false;
354-
}
355-
356-
// All good.
357-
true
358-
}
359-
360322
/// Walks up the callstack from the intrinsic's callsite, searching for the first callsite in a
361323
/// frame which is not `#[track_caller]`. This matches the `caller_location` intrinsic,
362324
/// and is primarily intended for the panic machinery.

compiler/rustc_const_eval/src/interpret/traits.rs

+7-13
Original file line numberDiff line numberDiff line change
@@ -86,21 +86,15 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
8686
throw_ub!(InvalidVTableTrait { vtable_dyn_type, expected_dyn_type });
8787
}
8888

89+
// This checks whether there is a subtyping relation between the predicates in either direction.
90+
// For example:
91+
// - casting between `dyn for<'a> Trait<fn(&'a u8)>` and `dyn Trait<fn(&'static u8)>` is OK
92+
// - casting between `dyn Trait<for<'a> fn(&'a u8)>` and either of the above is UB
8993
for (a_pred, b_pred) in std::iter::zip(sorted_vtable, sorted_expected) {
90-
let is_eq = match (a_pred.skip_binder(), b_pred.skip_binder()) {
91-
(
92-
ty::ExistentialPredicate::Trait(a_data),
93-
ty::ExistentialPredicate::Trait(b_data),
94-
) => self.eq_in_param_env(a_pred.rebind(a_data), b_pred.rebind(b_data)),
94+
let a_pred = self.tcx.normalize_erasing_late_bound_regions(self.typing_env, a_pred);
95+
let b_pred = self.tcx.normalize_erasing_late_bound_regions(self.typing_env, b_pred);
9596

96-
(
97-
ty::ExistentialPredicate::Projection(a_data),
98-
ty::ExistentialPredicate::Projection(b_data),
99-
) => self.eq_in_param_env(a_pred.rebind(a_data), b_pred.rebind(b_data)),
100-
101-
_ => false,
102-
};
103-
if !is_eq {
97+
if a_pred != b_pred {
10498
throw_ub!(InvalidVTableTrait { vtable_dyn_type, expected_dyn_type });
10599
}
106100
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Test that transmuting from `&dyn Dyn<fn(&'static ())>` to `&dyn Dyn<for<'a> fn(&'a ())>` is UB.
2+
//
3+
// The vtable of `() as Dyn<fn(&'static ())>` and `() as Dyn<for<'a> fn(&'a ())>` can have
4+
// different entries and, because in the former the entry for `foo` is vacant, this test will
5+
// segfault at runtime.
6+
7+
trait Dyn<U> {
8+
fn foo(&self)
9+
where
10+
U: HigherRanked,
11+
{
12+
}
13+
}
14+
impl<T, U> Dyn<U> for T {}
15+
16+
trait HigherRanked {}
17+
impl HigherRanked for for<'a> fn(&'a ()) {}
18+
19+
// 2nd candidate is required so that selecting `(): Dyn<fn(&'static ())>` will
20+
// evaluate the candidates and fail the leak check instead of returning the
21+
// only applicable candidate.
22+
trait Unsatisfied {}
23+
impl<T: Unsatisfied> HigherRanked for T {}
24+
25+
fn main() {
26+
let x: &dyn Dyn<fn(&'static ())> = &();
27+
let y: &dyn Dyn<for<'a> fn(&'a ())> = unsafe { std::mem::transmute(x) };
28+
//~^ ERROR: wrong trait in wide pointer vtable
29+
y.foo();
30+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
error: Undefined Behavior: constructing invalid value: wrong trait in wide pointer vtable: expected `Dyn<for<'a> fn(&'a ())>`, but encountered `Dyn<fn(&())>`
2+
--> tests/fail/validity/dyn-trait-leak-check.rs:LL:CC
3+
|
4+
LL | let y: &dyn Dyn<for<'a> fn(&'a ())> = unsafe { std::mem::transmute(x) };
5+
| ^^^^^^^^^^^^^^^^^^^^^^ constructing invalid value: wrong trait in wide pointer vtable: expected `Dyn<for<'a> fn(&'a ())>`, but encountered `Dyn<fn(&())>`
6+
|
7+
= help: this indicates a bug in the program: it performed an invalid operation, and caused Undefined Behavior
8+
= help: see https://doc.rust-lang.org/nightly/reference/behavior-considered-undefined.html for further information
9+
= note: BACKTRACE:
10+
= note: inside `main` at tests/fail/validity/dyn-trait-leak-check.rs:LL:CC
11+
12+
note: some details are omitted, run with `MIRIFLAGS=-Zmiri-backtrace=full` for a verbose backtrace
13+
14+
error: aborting due to 1 previous error
15+

src/tools/miri/tests/pass/dyn-upcast.rs

+29
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ fn main() {
1212
drop_principal();
1313
modulo_binder();
1414
modulo_assoc();
15+
bidirectional_subtyping();
1516
}
1617

1718
fn vtable_nop_cast() {
@@ -532,3 +533,31 @@ fn modulo_assoc() {
532533

533534
(&() as &dyn Trait as &dyn Middle<()>).say_hello(&0);
534535
}
536+
537+
fn bidirectional_subtyping() {
538+
// Test that transmuting between subtypes of dyn traits is fine, even in the
539+
// "wrong direction", i.e. going from a lower-ranked to a higher-ranked dyn trait.
540+
// Note that compared to the `dyn-trait-leak-check` test, the `for` is on the *outside* here!
541+
542+
trait Trait<U: ?Sized> {}
543+
impl<T, U: ?Sized> Trait<U> for T {}
544+
545+
struct Wrapper<T: ?Sized>(T);
546+
547+
let x: &dyn Trait<fn(&'static ())> = &();
548+
let _y: &dyn for<'a> Trait<fn(&'a ())> = unsafe { std::mem::transmute(x) };
549+
550+
let x: &dyn for<'a> Trait<fn(&'a ())> = &();
551+
let _y: &dyn Trait<fn(&'static ())> = unsafe { std::mem::transmute(x) };
552+
553+
let x: &dyn Trait<dyn Trait<fn(&'static ())>> = &();
554+
let _y: &dyn for<'a> Trait<dyn Trait<fn(&'a ())>> = unsafe { std::mem::transmute(x) };
555+
556+
let x: &dyn for<'a> Trait<dyn Trait<fn(&'a ())>> = &();
557+
let _y: &dyn Trait<dyn Trait<fn(&'static ())>> = unsafe { std::mem::transmute(x) };
558+
559+
// This lowers to a ptr-to-ptr cast (which behaves like a transmute)
560+
// and not an unsizing coercion:
561+
let x: *const dyn for<'a> Trait<&'a ()> = &();
562+
let _y: *const Wrapper<dyn Trait<&'static ()>> = x as _;
563+
}

0 commit comments

Comments
 (0)