Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

interpret: dyn trait metadata check: equate traits in a proper way #126232

Merged
merged 3 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions compiler/rustc_const_eval/src/interpret/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,6 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
match (&src_pointee_ty.kind(), &dest_pointee_ty.kind()) {
(&ty::Array(_, length), &ty::Slice(_)) => {
let ptr = self.read_pointer(src)?;
// u64 cast is from usize to u64, which is always good
let val = Immediate::new_slice(
ptr,
length.eval_target_usize(*self.tcx, self.param_env),
Expand All @@ -401,13 +400,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
let (old_data, old_vptr) = val.to_scalar_pair();
let old_data = old_data.to_pointer(self)?;
let old_vptr = old_vptr.to_pointer(self)?;
let (ty, old_trait) = self.get_ptr_vtable(old_vptr)?;
if old_trait != data_a.principal() {
throw_ub!(InvalidVTableTrait {
expected_trait: data_a,
vtable_trait: old_trait,
});
}
let ty = self.get_ptr_vtable_ty(old_vptr, Some(data_a))?;
let new_vptr = self.get_vtable_ptr(ty, data_b.principal())?;
self.write_immediate(Immediate::new_dyn_trait(old_data, new_vptr, self), dest)
}
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_const_eval/src/interpret/eval_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -765,10 +765,10 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
}
Ok(Some((full_size, full_align)))
}
ty::Dynamic(_, _, ty::Dyn) => {
ty::Dynamic(expected_trait, _, ty::Dyn) => {
let vtable = metadata.unwrap_meta().to_pointer(self)?;
// Read size and align from vtable (already checks size).
Ok(Some(self.get_vtable_size_and_align(vtable)?))
Ok(Some(self.get_vtable_size_and_align(vtable, Some(expected_trait))?))
}

ty::Slice(_) | ty::Str => {
Expand Down
6 changes: 4 additions & 2 deletions compiler/rustc_const_eval/src/interpret/intrinsics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -432,12 +432,14 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {

sym::vtable_size => {
let ptr = self.read_pointer(&args[0])?;
let (size, _align) = self.get_vtable_size_and_align(ptr)?;
// `None` because we don't know which trait to expect here; any vtable is okay.
let (size, _align) = self.get_vtable_size_and_align(ptr, None)?;
self.write_scalar(Scalar::from_target_usize(size.bytes(), self), dest)?;
}
sym::vtable_align => {
let ptr = self.read_pointer(&args[0])?;
let (_size, align) = self.get_vtable_size_and_align(ptr)?;
// `None` because we don't know which trait to expect here; any vtable is okay.
let (_size, align) = self.get_vtable_size_and_align(ptr, None)?;
self.write_scalar(Scalar::from_target_usize(align.bytes(), self), dest)?;
}

Expand Down
17 changes: 12 additions & 5 deletions compiler/rustc_const_eval/src/interpret/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -867,19 +867,26 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
.ok_or_else(|| err_ub!(InvalidFunctionPointer(Pointer::new(alloc_id, offset))).into())
}

pub fn get_ptr_vtable(
/// Get the dynamic type of the given vtable pointer.
/// If `expected_trait` is `Some`, it must be a vtable for the given trait.
pub fn get_ptr_vtable_ty(
&self,
ptr: Pointer<Option<M::Provenance>>,
) -> InterpResult<'tcx, (Ty<'tcx>, Option<ty::PolyExistentialTraitRef<'tcx>>)> {
expected_trait: Option<&'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>>,
) -> InterpResult<'tcx, Ty<'tcx>> {
trace!("get_ptr_vtable({:?})", ptr);
let (alloc_id, offset, _tag) = self.ptr_get_alloc_id(ptr)?;
if offset.bytes() != 0 {
throw_ub!(InvalidVTablePointer(Pointer::new(alloc_id, offset)))
}
match self.tcx.try_get_global_alloc(alloc_id) {
Some(GlobalAlloc::VTable(ty, trait_ref)) => Ok((ty, trait_ref)),
_ => throw_ub!(InvalidVTablePointer(Pointer::new(alloc_id, offset))),
let Some(GlobalAlloc::VTable(ty, vtable_trait)) = self.tcx.try_get_global_alloc(alloc_id)
else {
throw_ub!(InvalidVTablePointer(Pointer::new(alloc_id, offset)))
};
if let Some(expected_trait) = expected_trait {
self.check_vtable_for_type(vtable_trait, expected_trait)?;
}
Ok(ty)
}

pub fn alloc_mark_immutable(&mut self, id: AllocId) -> InterpResult<'tcx> {
Expand Down
49 changes: 0 additions & 49 deletions compiler/rustc_const_eval/src/interpret/place.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use tracing::{instrument, trace};

use rustc_ast::Mutability;
use rustc_middle::mir;
use rustc_middle::ty;
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
use rustc_middle::ty::Ty;
use rustc_middle::{bug, span_bug};
Expand Down Expand Up @@ -1017,54 +1016,6 @@ where
let layout = self.layout_of(raw.ty)?;
Ok(self.ptr_to_mplace(ptr.into(), layout))
}

/// Turn a place with a `dyn Trait` type into a place with the actual dynamic type.
/// Aso returns the vtable.
pub(super) fn unpack_dyn_trait(
&self,
mplace: &MPlaceTy<'tcx, M::Provenance>,
expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::Provenance>, Pointer<Option<M::Provenance>>)> {
assert!(
matches!(mplace.layout.ty.kind(), ty::Dynamic(_, _, ty::Dyn)),
"`unpack_dyn_trait` only makes sense on `dyn*` types"
);
let vtable = mplace.meta().unwrap_meta().to_pointer(self)?;
let (ty, vtable_trait) = self.get_ptr_vtable(vtable)?;
if expected_trait.principal() != vtable_trait {
throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
}
// This is a kind of transmute, from a place with unsized type and metadata to
// a place with sized type and no metadata.
let layout = self.layout_of(ty)?;
let mplace =
MPlaceTy { mplace: MemPlace { meta: MemPlaceMeta::None, ..mplace.mplace }, layout };
Ok((mplace, vtable))
}

/// Turn a `dyn* Trait` type into an value with the actual dynamic type.
/// Also returns the vtable.
pub(super) fn unpack_dyn_star<P: Projectable<'tcx, M::Provenance>>(
&self,
val: &P,
expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
) -> InterpResult<'tcx, (P, Pointer<Option<M::Provenance>>)> {
assert!(
matches!(val.layout().ty.kind(), ty::Dynamic(_, _, ty::DynStar)),
"`unpack_dyn_star` only makes sense on `dyn*` types"
);
let data = self.project_field(val, 0)?;
let vtable = self.project_field(val, 1)?;
let vtable = self.read_pointer(&vtable.to_op(self)?)?;
let (ty, vtable_trait) = self.get_ptr_vtable(vtable)?;
if expected_trait.principal() != vtable_trait {
throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
}
// `data` is already the right thing but has the wrong type. So we transmute it.
let layout = self.layout_of(ty)?;
let data = data.transmute(layout, self)?;
Ok((data, vtable))
}
}

// Some nodes are used a lot. Make sure they don't unintentionally get bigger.
Expand Down
37 changes: 18 additions & 19 deletions compiler/rustc_const_eval/src/interpret/terminator.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::borrow::Cow;

use either::Either;
use rustc_middle::ty::TyCtxt;
use tracing::trace;

use rustc_middle::span_bug;
Expand Down Expand Up @@ -827,20 +828,19 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
};

// Obtain the underlying trait we are working on, and the adjusted receiver argument.
let (vptr, dyn_ty, adjusted_receiver) = if let ty::Dynamic(data, _, ty::DynStar) =
let (dyn_trait, dyn_ty, adjusted_recv) = if let ty::Dynamic(data, _, ty::DynStar) =
receiver_place.layout.ty.kind()
{
let (recv, vptr) = self.unpack_dyn_star(&receiver_place, data)?;
let (dyn_ty, _dyn_trait) = self.get_ptr_vtable(vptr)?;
let recv = self.unpack_dyn_star(&receiver_place, data)?;

(vptr, dyn_ty, recv.ptr())
(data.principal(), recv.layout.ty, recv.ptr())
} else {
// Doesn't have to be a `dyn Trait`, but the unsized tail must be `dyn Trait`.
// (For that reason we also cannot use `unpack_dyn_trait`.)
let receiver_tail = self
.tcx
.struct_tail_erasing_lifetimes(receiver_place.layout.ty, self.param_env);
let ty::Dynamic(data, _, ty::Dyn) = receiver_tail.kind() else {
let ty::Dynamic(receiver_trait, _, ty::Dyn) = receiver_tail.kind() else {
span_bug!(
self.cur_span(),
"dynamic call on non-`dyn` type {}",
Expand All @@ -851,25 +851,24 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {

// Get the required information from the vtable.
let vptr = receiver_place.meta().unwrap_meta().to_pointer(self)?;
let (dyn_ty, dyn_trait) = self.get_ptr_vtable(vptr)?;
if dyn_trait != data.principal() {
throw_ub!(InvalidVTableTrait {
expected_trait: data,
vtable_trait: dyn_trait,
});
}
let dyn_ty = self.get_ptr_vtable_ty(vptr, Some(receiver_trait))?;

// It might be surprising that we use a pointer as the receiver even if this
// is a by-val case; this works because by-val passing of an unsized `dyn
// Trait` to a function is actually desugared to a pointer.
(vptr, dyn_ty, receiver_place.ptr())
(receiver_trait.principal(), dyn_ty, receiver_place.ptr())
};

// Now determine the actual method to call. We can do that in two different ways and
// compare them to ensure everything fits.
let Some(ty::VtblEntry::Method(fn_inst)) =
self.get_vtable_entries(vptr)?.get(idx).copied()
else {
let vtable_entries = if let Some(dyn_trait) = dyn_trait {
let trait_ref = dyn_trait.with_self_ty(*self.tcx, dyn_ty);
let trait_ref = self.tcx.erase_regions(trait_ref);
self.tcx.vtable_entries(trait_ref)
} else {
TyCtxt::COMMON_VTABLE_ENTRIES
};
let Some(ty::VtblEntry::Method(fn_inst)) = vtable_entries.get(idx).copied() else {
// FIXME(fee1-dead) these could be variants of the UB info enum instead of this
throw_ub_custom!(fluent::const_eval_dyn_call_not_a_method);
};
Expand Down Expand Up @@ -898,7 +897,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
let receiver_ty = Ty::new_mut_ptr(self.tcx.tcx, dyn_ty);
args[0] = FnArg::Copy(
ImmTy::from_immediate(
Scalar::from_maybe_pointer(adjusted_receiver, self).into(),
Scalar::from_maybe_pointer(adjusted_recv, self).into(),
self.layout_of(receiver_ty)?,
)
.into(),
Expand Down Expand Up @@ -974,11 +973,11 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
let place = match place.layout.ty.kind() {
ty::Dynamic(data, _, ty::Dyn) => {
// Dropping a trait object. Need to find actual drop fn.
self.unpack_dyn_trait(&place, data)?.0
self.unpack_dyn_trait(&place, data)?
}
ty::Dynamic(data, _, ty::DynStar) => {
// Dropping a `dyn*`. Need to find actual drop fn.
self.unpack_dyn_star(&place, data)?.0
self.unpack_dyn_star(&place, data)?
}
_ => {
debug_assert_eq!(
Expand Down
101 changes: 83 additions & 18 deletions compiler/rustc_const_eval/src/interpret/traits.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use rustc_infer::infer::TyCtxtInferExt;
use rustc_infer::traits::ObligationCause;
use rustc_middle::mir::interpret::{InterpResult, Pointer};
use rustc_middle::ty::layout::LayoutOf;
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_middle::ty::{self, Ty};
use rustc_target::abi::{Align, Size};
use rustc_trait_selection::traits::ObligationCtxt;
use tracing::trace;

use super::util::ensure_monomorphic_enough;
use super::{InterpCx, Machine};
use super::{throw_ub, InterpCx, MPlaceTy, Machine, MemPlaceMeta, OffsetMode, Projectable};

impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
/// Creates a dynamic vtable for the given type and vtable origin. This is used only for
Expand Down Expand Up @@ -33,28 +36,90 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
Ok(vtable_ptr.into())
}

/// Returns a high-level representation of the entries of the given vtable.
pub fn get_vtable_entries(
&self,
vtable: Pointer<Option<M::Provenance>>,
) -> InterpResult<'tcx, &'tcx [ty::VtblEntry<'tcx>]> {
let (ty, poly_trait_ref) = self.get_ptr_vtable(vtable)?;
Ok(if let Some(poly_trait_ref) = poly_trait_ref {
let trait_ref = poly_trait_ref.with_self_ty(*self.tcx, ty);
let trait_ref = self.tcx.erase_regions(trait_ref);
self.tcx.vtable_entries(trait_ref)
} else {
TyCtxt::COMMON_VTABLE_ENTRIES
})
}

pub fn get_vtable_size_and_align(
&self,
vtable: Pointer<Option<M::Provenance>>,
expected_trait: Option<&'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>>,
) -> InterpResult<'tcx, (Size, Align)> {
let (ty, _trait_ref) = self.get_ptr_vtable(vtable)?;
let ty = self.get_ptr_vtable_ty(vtable, expected_trait)?;
let layout = self.layout_of(ty)?;
assert!(layout.is_sized(), "there are no vtables for unsized types");
Ok((layout.size, layout.align.abi))
}

/// Check that the given vtable trait is valid for a pointer/reference/place with the given
/// expected trait type.
pub(super) fn check_vtable_for_type(
&self,
vtable_trait: Option<ty::PolyExistentialTraitRef<'tcx>>,
expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
) -> InterpResult<'tcx> {
// Fast path: if they are equal, it's all fine.
if expected_trait.principal() == vtable_trait {
return Ok(());
}
if let (Some(expected_trait), Some(vtable_trait)) =
(expected_trait.principal(), vtable_trait)
{
// Slow path: spin up an inference context to check if these traits are sufficiently equal.
let infcx = self.tcx.infer_ctxt().build();
let ocx = ObligationCtxt::new(&infcx);
let cause = ObligationCause::dummy_with_span(self.cur_span());
// equate the two trait refs after normalization
let expected_trait = ocx.normalize(&cause, self.param_env, expected_trait);
let vtable_trait = ocx.normalize(&cause, self.param_env, vtable_trait);
if ocx.eq(&cause, self.param_env, expected_trait, vtable_trait).is_ok() {
if ocx.select_all_or_error().is_empty() {
// All good.
return Ok(());
}
}
}
throw_ub!(InvalidVTableTrait { expected_trait, vtable_trait });
}

/// Turn a place with a `dyn Trait` type into a place with the actual dynamic type.
pub(super) fn unpack_dyn_trait(
&self,
mplace: &MPlaceTy<'tcx, M::Provenance>,
expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
) -> InterpResult<'tcx, MPlaceTy<'tcx, M::Provenance>> {
assert!(
matches!(mplace.layout.ty.kind(), ty::Dynamic(_, _, ty::Dyn)),
"`unpack_dyn_trait` only makes sense on `dyn*` types"
);
let vtable = mplace.meta().unwrap_meta().to_pointer(self)?;
let ty = self.get_ptr_vtable_ty(vtable, Some(expected_trait))?;
// This is a kind of transmute, from a place with unsized type and metadata to
// a place with sized type and no metadata.
let layout = self.layout_of(ty)?;
let mplace = mplace.offset_with_meta(
Size::ZERO,
OffsetMode::Wrapping,
MemPlaceMeta::None,
layout,
self,
)?;
Ok(mplace)
}

/// Turn a `dyn* Trait` type into an value with the actual dynamic type.
pub(super) fn unpack_dyn_star<P: Projectable<'tcx, M::Provenance>>(
&self,
val: &P,
expected_trait: &'tcx ty::List<ty::PolyExistentialPredicate<'tcx>>,
) -> InterpResult<'tcx, P> {
assert!(
matches!(val.layout().ty.kind(), ty::Dynamic(_, _, ty::DynStar)),
"`unpack_dyn_star` only makes sense on `dyn*` types"
);
let data = self.project_field(val, 0)?;
let vtable = self.project_field(val, 1)?;
let vtable = self.read_pointer(&vtable.to_op(self)?)?;
let ty = self.get_ptr_vtable_ty(vtable, Some(expected_trait))?;
// `data` is already the right thing but has the wrong type. So we transmute it.
let layout = self.layout_of(ty)?;
let data = data.transmute(layout, self)?;
Ok(data)
}
}
18 changes: 7 additions & 11 deletions compiler/rustc_const_eval/src/interpret/validity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,20 +343,16 @@ impl<'rt, 'tcx, M: Machine<'tcx>> ValidityVisitor<'rt, 'tcx, M> {
match tail.kind() {
ty::Dynamic(data, _, ty::Dyn) => {
let vtable = meta.unwrap_meta().to_pointer(self.ecx)?;
// Make sure it is a genuine vtable pointer.
let (_dyn_ty, dyn_trait) = try_validation!(
self.ecx.get_ptr_vtable(vtable),
// Make sure it is a genuine vtable pointer for the right trait.
try_validation!(
self.ecx.get_ptr_vtable_ty(vtable, Some(data)),
self.path,
Ub(DanglingIntPointer(..) | InvalidVTablePointer(..)) =>
InvalidVTablePtr { value: format!("{vtable}") }
InvalidVTablePtr { value: format!("{vtable}") },
Ub(InvalidVTableTrait { expected_trait, vtable_trait }) => {
InvalidMetaWrongTrait { expected_trait, vtable_trait: *vtable_trait }
},
);
// Make sure it is for the right trait.
if dyn_trait != data.principal() {
throw_validation_failure!(
self.path,
InvalidMetaWrongTrait { expected_trait: data, vtable_trait: dyn_trait }
);
}
}
ty::Slice(..) | ty::Str => {
let _len = meta.unwrap_meta().to_target_usize(self.ecx)?;
Expand Down
Loading
Loading