Skip to content

Commit

Permalink
Auto merge of #130215 - RalfJung:interpret-simd, r=compiler-errors
Browse files Browse the repository at this point in the history
interpret: simplify SIMD type handling

This is possible as a follow-up to #129403
  • Loading branch information
bors committed Sep 13, 2024
2 parents 473ae00 + e2bc16c commit 0307e40
Show file tree
Hide file tree
Showing 15 changed files with 190 additions and 199 deletions.
6 changes: 3 additions & 3 deletions compiler/rustc_const_eval/src/interpret/intrinsics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,8 +384,8 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
sym::simd_insert => {
let index = u64::from(self.read_scalar(&args[1])?.to_u32()?);
let elem = &args[2];
let (input, input_len) = self.operand_to_simd(&args[0])?;
let (dest, dest_len) = self.mplace_to_simd(dest)?;
let (input, input_len) = self.project_to_simd(&args[0])?;
let (dest, dest_len) = self.project_to_simd(dest)?;
assert_eq!(input_len, dest_len, "Return vector length must match input length");
// Bounds are not checked by typeck so we have to do it ourselves.
if index >= input_len {
Expand All @@ -406,7 +406,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
}
sym::simd_extract => {
let index = u64::from(self.read_scalar(&args[1])?.to_u32()?);
let (input, input_len) = self.operand_to_simd(&args[0])?;
let (input, input_len) = self.project_to_simd(&args[0])?;
// Bounds are not checked by typeck so we have to do it ourselves.
if index >= input_len {
throw_ub_format!(
Expand Down
24 changes: 0 additions & 24 deletions compiler/rustc_const_eval/src/interpret/operand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -681,30 +681,6 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
Ok(str)
}

/// Converts a repr(simd) operand into an operand where `place_index` accesses the SIMD elements.
/// Also returns the number of elements.
///
/// Can (but does not always) trigger UB if `op` is uninitialized.
pub fn operand_to_simd(
&self,
op: &OpTy<'tcx, M::Provenance>,
) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::Provenance>, u64)> {
// Basically we just transmute this place into an array following simd_size_and_type.
// This only works in memory, but repr(simd) types should never be immediates anyway.
assert!(op.layout.ty.is_simd());
match op.as_mplace_or_imm() {
Left(mplace) => self.mplace_to_simd(&mplace),
Right(imm) => match *imm {
Immediate::Uninit => {
throw_ub!(InvalidUninitBytes(None))
}
Immediate::Scalar(..) | Immediate::ScalarPair(..) => {
bug!("arrays/slices can never have Scalar/ScalarPair layout")
}
},
}
}

/// Read from a local of the current frame.
/// Will not access memory, instead an indirect `Operand` is returned.
///
Expand Down
45 changes: 22 additions & 23 deletions compiler/rustc_const_eval/src/interpret/place.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,13 +377,15 @@ where
Prov: Provenance,
M: Machine<'tcx, Provenance = Prov>,
{
pub fn ptr_with_meta_to_mplace(
fn ptr_with_meta_to_mplace(
&self,
ptr: Pointer<Option<M::Provenance>>,
meta: MemPlaceMeta<M::Provenance>,
layout: TyAndLayout<'tcx>,
unaligned: bool,
) -> MPlaceTy<'tcx, M::Provenance> {
let misaligned = self.is_ptr_misaligned(ptr, layout.align.abi);
let misaligned =
if unaligned { None } else { self.is_ptr_misaligned(ptr, layout.align.abi) };
MPlaceTy { mplace: MemPlace { ptr, meta, misaligned }, layout }
}

Expand All @@ -393,7 +395,16 @@ where
layout: TyAndLayout<'tcx>,
) -> MPlaceTy<'tcx, M::Provenance> {
assert!(layout.is_sized());
self.ptr_with_meta_to_mplace(ptr, MemPlaceMeta::None, layout)
self.ptr_with_meta_to_mplace(ptr, MemPlaceMeta::None, layout, /*unaligned*/ false)
}

pub fn ptr_to_mplace_unaligned(
&self,
ptr: Pointer<Option<M::Provenance>>,
layout: TyAndLayout<'tcx>,
) -> MPlaceTy<'tcx, M::Provenance> {
assert!(layout.is_sized());
self.ptr_with_meta_to_mplace(ptr, MemPlaceMeta::None, layout, /*unaligned*/ true)
}

/// Take a value, which represents a (thin or wide) reference, and make it a place.
Expand All @@ -414,7 +425,7 @@ where
// `ref_to_mplace` is called on raw pointers even if they don't actually get dereferenced;
// we hence can't call `size_and_align_of` since that asserts more validity than we want.
let ptr = ptr.to_pointer(self)?;
Ok(self.ptr_with_meta_to_mplace(ptr, meta, layout))
Ok(self.ptr_with_meta_to_mplace(ptr, meta, layout, /*unaligned*/ false))
}

/// Turn a mplace into a (thin or wide) mutable raw pointer, pointing to the same space.
Expand Down Expand Up @@ -484,23 +495,6 @@ where
Ok(a)
}

/// Converts a repr(simd) place into a place where `place_index` accesses the SIMD elements.
/// Also returns the number of elements.
pub fn mplace_to_simd(
&self,
mplace: &MPlaceTy<'tcx, M::Provenance>,
) -> InterpResult<'tcx, (MPlaceTy<'tcx, M::Provenance>, u64)> {
// Basically we want to transmute this place into an array following simd_size_and_type.
let (len, e_ty) = mplace.layout.ty.simd_size_and_type(*self.tcx);
// Some SIMD types have padding, so `len` many `e_ty` does not cover the entire place.
// Therefore we cannot transmute, and instead we project at offset 0, which side-steps
// the size check.
let array_layout = self.layout_of(Ty::new_array(self.tcx.tcx, e_ty, len))?;
assert!(array_layout.size <= mplace.layout.size);
let mplace = mplace.offset(Size::ZERO, array_layout, self)?;
Ok((mplace, len))
}

/// Turn a local in the current frame into a place.
pub fn local_to_place(
&self,
Expand Down Expand Up @@ -986,7 +980,7 @@ where
span_bug!(self.cur_span(), "cannot allocate space for `extern` type, size is not known")
};
let ptr = self.allocate_ptr(size, align, kind)?;
Ok(self.ptr_with_meta_to_mplace(ptr.into(), meta, layout))
Ok(self.ptr_with_meta_to_mplace(ptr.into(), meta, layout, /*unaligned*/ false))
}

pub fn allocate(
Expand Down Expand Up @@ -1021,7 +1015,12 @@ where
};
let meta = Scalar::from_target_usize(u64::try_from(str.len()).unwrap(), self);
let layout = self.layout_of(self.tcx.types.str_).unwrap();
Ok(self.ptr_with_meta_to_mplace(ptr.into(), MemPlaceMeta::Meta(meta), layout))
Ok(self.ptr_with_meta_to_mplace(
ptr.into(),
MemPlaceMeta::Meta(meta),
layout,
/*unaligned*/ false,
))
}

pub fn raw_const_to_mplace(
Expand Down
13 changes: 13 additions & 0 deletions compiler/rustc_const_eval/src/interpret/projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,19 @@ where
base.offset(offset, field_layout, self)
}

/// Converts a repr(simd) value into an array of the right size, such that `project_index`
/// accesses the SIMD elements. Also returns the number of elements.
pub fn project_to_simd<P: Projectable<'tcx, M::Provenance>>(
&self,
base: &P,
) -> InterpResult<'tcx, (P, u64)> {
assert!(base.layout().ty.ty_adt_def().unwrap().repr().simd());
// SIMD types must be newtypes around arrays, so all we have to do is project to their only field.
let array = self.project_field(base, 0)?;
let len = array.len(self)?;
Ok((array, len))
}

fn project_constant_index<P: Projectable<'tcx, M::Provenance>>(
&self,
base: &P,
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_const_eval/src/interpret/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use rustc_middle::ty::{
};
use tracing::debug;

use super::{throw_inval, InterpCx, MPlaceTy, MemPlaceMeta, MemoryKind};
use super::{throw_inval, InterpCx, MPlaceTy, MemoryKind};
use crate::const_eval::{CompileTimeInterpCx, CompileTimeMachine, InterpretationResult};

/// Checks whether a type contains generic parameters which must be instantiated.
Expand Down Expand Up @@ -103,5 +103,5 @@ pub(crate) fn create_static_alloc<'tcx>(
assert_eq!(ecx.machine.static_root_ids, None);
ecx.machine.static_root_ids = Some((alloc_id, static_def_id));
assert!(ecx.memory.alloc_map.insert(alloc_id, (MemoryKind::Stack, alloc)).is_none());
Ok(ecx.ptr_with_meta_to_mplace(Pointer::from(alloc_id).into(), MemPlaceMeta::None, layout))
Ok(ecx.ptr_to_mplace(Pointer::from(alloc_id).into(), layout))
}
82 changes: 41 additions & 41 deletions src/tools/miri/src/intrinsics/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
| "bitreverse"
=> {
let [op] = check_arg_count(args)?;
let (op, op_len) = this.operand_to_simd(op)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (op, op_len) = this.project_to_simd(op)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, op_len);

Expand Down Expand Up @@ -200,9 +200,9 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
use mir::BinOp;

let [left, right] = check_arg_count(args)?;
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (left, left_len) = this.project_to_simd(left)?;
let (right, right_len) = this.project_to_simd(right)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, left_len);
assert_eq!(dest_len, right_len);
Expand Down Expand Up @@ -291,10 +291,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"fma" => {
let [a, b, c] = check_arg_count(args)?;
let (a, a_len) = this.operand_to_simd(a)?;
let (b, b_len) = this.operand_to_simd(b)?;
let (c, c_len) = this.operand_to_simd(c)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (a, a_len) = this.project_to_simd(a)?;
let (b, b_len) = this.project_to_simd(b)?;
let (c, c_len) = this.project_to_simd(c)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, a_len);
assert_eq!(dest_len, b_len);
Expand Down Expand Up @@ -345,7 +345,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
use mir::BinOp;

let [op] = check_arg_count(args)?;
let (op, op_len) = this.operand_to_simd(op)?;
let (op, op_len) = this.project_to_simd(op)?;

let imm_from_bool =
|b| ImmTy::from_scalar(Scalar::from_bool(b), this.machine.layouts.bool);
Expand Down Expand Up @@ -408,7 +408,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
use mir::BinOp;

let [op, init] = check_arg_count(args)?;
let (op, op_len) = this.operand_to_simd(op)?;
let (op, op_len) = this.project_to_simd(op)?;
let init = this.read_immediate(init)?;

let mir_op = match intrinsic_name {
Expand All @@ -426,10 +426,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"select" => {
let [mask, yes, no] = check_arg_count(args)?;
let (mask, mask_len) = this.operand_to_simd(mask)?;
let (yes, yes_len) = this.operand_to_simd(yes)?;
let (no, no_len) = this.operand_to_simd(no)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (mask, mask_len) = this.project_to_simd(mask)?;
let (yes, yes_len) = this.project_to_simd(yes)?;
let (no, no_len) = this.project_to_simd(no)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, mask_len);
assert_eq!(dest_len, yes_len);
Expand All @@ -448,9 +448,9 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
// Variant of `select` that takes a bitmask rather than a "vector of bool".
"select_bitmask" => {
let [mask, yes, no] = check_arg_count(args)?;
let (yes, yes_len) = this.operand_to_simd(yes)?;
let (no, no_len) = this.operand_to_simd(no)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (yes, yes_len) = this.project_to_simd(yes)?;
let (no, no_len) = this.project_to_simd(no)?;
let (dest, dest_len) = this.project_to_simd(dest)?;
let bitmask_len = dest_len.next_multiple_of(8);
if bitmask_len > 64 {
throw_unsup_format!(
Expand Down Expand Up @@ -522,7 +522,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
// Converts a "vector of bool" into a bitmask.
"bitmask" => {
let [op] = check_arg_count(args)?;
let (op, op_len) = this.operand_to_simd(op)?;
let (op, op_len) = this.project_to_simd(op)?;
let bitmask_len = op_len.next_multiple_of(8);
if bitmask_len > 64 {
throw_unsup_format!(
Expand Down Expand Up @@ -570,8 +570,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"cast" | "as" | "cast_ptr" | "expose_provenance" | "with_exposed_provenance" => {
let [op] = check_arg_count(args)?;
let (op, op_len) = this.operand_to_simd(op)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (op, op_len) = this.project_to_simd(op)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, op_len);

Expand Down Expand Up @@ -627,9 +627,9 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"shuffle_generic" => {
let [left, right] = check_arg_count(args)?;
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (left, left_len) = this.project_to_simd(left)?;
let (right, right_len) = this.project_to_simd(right)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

let index = generic_args[2]
.expect_const()
Expand Down Expand Up @@ -662,15 +662,15 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"shuffle" => {
let [left, right, index] = check_arg_count(args)?;
let (left, left_len) = this.operand_to_simd(left)?;
let (right, right_len) = this.operand_to_simd(right)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (left, left_len) = this.project_to_simd(left)?;
let (right, right_len) = this.project_to_simd(right)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

// `index` is an array or a SIMD type
let (index, index_len) = match index.layout.ty.kind() {
// FIXME: remove this once `index` must always be a SIMD vector.
ty::Array(..) => (index.assert_mem_place(), index.len(this)?),
_ => this.operand_to_simd(index)?,
ty::Array(..) => (index.clone(), index.len(this)?),
_ => this.project_to_simd(index)?,
};

assert_eq!(left_len, right_len);
Expand Down Expand Up @@ -699,10 +699,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"gather" => {
let [passthru, ptrs, mask] = check_arg_count(args)?;
let (passthru, passthru_len) = this.operand_to_simd(passthru)?;
let (ptrs, ptrs_len) = this.operand_to_simd(ptrs)?;
let (mask, mask_len) = this.operand_to_simd(mask)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (passthru, passthru_len) = this.project_to_simd(passthru)?;
let (ptrs, ptrs_len) = this.project_to_simd(ptrs)?;
let (mask, mask_len) = this.project_to_simd(mask)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, passthru_len);
assert_eq!(dest_len, ptrs_len);
Expand All @@ -725,9 +725,9 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"scatter" => {
let [value, ptrs, mask] = check_arg_count(args)?;
let (value, value_len) = this.operand_to_simd(value)?;
let (ptrs, ptrs_len) = this.operand_to_simd(ptrs)?;
let (mask, mask_len) = this.operand_to_simd(mask)?;
let (value, value_len) = this.project_to_simd(value)?;
let (ptrs, ptrs_len) = this.project_to_simd(ptrs)?;
let (mask, mask_len) = this.project_to_simd(mask)?;

assert_eq!(ptrs_len, value_len);
assert_eq!(ptrs_len, mask_len);
Expand All @@ -745,10 +745,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"masked_load" => {
let [mask, ptr, default] = check_arg_count(args)?;
let (mask, mask_len) = this.operand_to_simd(mask)?;
let (mask, mask_len) = this.project_to_simd(mask)?;
let ptr = this.read_pointer(ptr)?;
let (default, default_len) = this.operand_to_simd(default)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (default, default_len) = this.project_to_simd(default)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, mask_len);
assert_eq!(dest_len, default_len);
Expand All @@ -772,9 +772,9 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
"masked_store" => {
let [mask, ptr, vals] = check_arg_count(args)?;
let (mask, mask_len) = this.operand_to_simd(mask)?;
let (mask, mask_len) = this.project_to_simd(mask)?;
let ptr = this.read_pointer(ptr)?;
let (vals, vals_len) = this.operand_to_simd(vals)?;
let (vals, vals_len) = this.project_to_simd(vals)?;

assert_eq!(mask_len, vals_len);

Expand Down
4 changes: 2 additions & 2 deletions src/tools/miri/src/shims/foreign_items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -903,8 +903,8 @@ trait EvalContextExtPriv<'tcx>: crate::MiriInterpCxExt<'tcx> {
name if name.starts_with("llvm.ctpop.v") => {
let [op] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;

let (op, op_len) = this.operand_to_simd(op)?;
let (dest, dest_len) = this.mplace_to_simd(dest)?;
let (op, op_len) = this.project_to_simd(op)?;
let (dest, dest_len) = this.project_to_simd(dest)?;

assert_eq!(dest_len, op_len);

Expand Down
Loading

0 comments on commit 0307e40

Please sign in to comment.