Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions compiler/rustc_codegen_cranelift/src/intrinsics/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use cranelift_codegen::ir::immediates::Offset32;
use rustc_abi::Endian;
use rustc_middle::ty::SimdAlign;

use super::*;
use crate::prelude::*;
Expand Down Expand Up @@ -960,6 +961,15 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
let lane_clif_ty = fx.clif_type(val_lane_ty).unwrap();
let ptr_val = ptr.load_scalar(fx);

let alignment = generic_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
.unwrap_leaf()
.to_simd_alignment();

let memflags = match alignment {
SimdAlign::Unaligned => MemFlags::new().with_notrap(),
_ => MemFlags::trusted(),
};

for lane_idx in 0..val_lane_count {
let val_lane = val.value_lane(fx, lane_idx).load_scalar(fx);
let mask_lane = mask.value_lane(fx, lane_idx).load_scalar(fx);
Expand All @@ -972,7 +982,7 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(

fx.bcx.switch_to_block(if_enabled);
let offset = lane_idx as i32 * lane_clif_ty.bytes() as i32;
fx.bcx.ins().store(MemFlags::trusted(), val_lane, ptr_val, Offset32::new(offset));
fx.bcx.ins().store(memflags, val_lane, ptr_val, Offset32::new(offset));
fx.bcx.ins().jump(next, &[]);

fx.bcx.seal_block(next);
Expand All @@ -996,6 +1006,15 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
let lane_clif_ty = fx.clif_type(val_lane_ty).unwrap();
let ret_lane_layout = fx.layout_of(ret_lane_ty);

let alignment = generic_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
.unwrap_leaf()
.to_simd_alignment();

let memflags = match alignment {
SimdAlign::Unaligned => MemFlags::new().with_notrap(),
_ => MemFlags::trusted(),
};

for lane_idx in 0..ptr_lane_count {
let val_lane = val.value_lane(fx, lane_idx).load_scalar(fx);
let ptr_lane = ptr.value_lane(fx, lane_idx).load_scalar(fx);
Expand All @@ -1011,7 +1030,7 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
fx.bcx.seal_block(if_disabled);

fx.bcx.switch_to_block(if_enabled);
let res = fx.bcx.ins().load(lane_clif_ty, MemFlags::trusted(), ptr_lane, 0);
let res = fx.bcx.ins().load(lane_clif_ty, memflags, ptr_lane, 0);
fx.bcx.ins().jump(next, &[res.into()]);

fx.bcx.switch_to_block(if_disabled);
Expand Down
33 changes: 28 additions & 5 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use rustc_hir::def_id::LOCAL_CRATE;
use rustc_hir::{self as hir};
use rustc_middle::mir::BinOp;
use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf};
use rustc_middle::ty::{self, GenericArgsRef, Instance, Ty, TyCtxt, TypingEnv};
use rustc_middle::ty::{self, GenericArgsRef, Instance, SimdAlign, Ty, TyCtxt, TypingEnv};
use rustc_middle::{bug, span_bug};
use rustc_span::{Span, Symbol, sym};
use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_crate};
Expand Down Expand Up @@ -1826,15 +1826,34 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
));
}

/// Builds the constant `i32` alignment operand passed to the masked load/store
/// intrinsics.
///
/// The byte count depends on `alignment`:
/// * `Unaligned` yields `1`,
/// * `Element` yields the alignment of `element_ty`,
/// * `Vector` yields the alignment of `vector_ty`.
fn llvm_alignment<'ll, 'tcx>(
    bx: &mut Builder<'_, 'll, 'tcx>,
    alignment: SimdAlign,
    vector_ty: Ty<'tcx>,
    element_ty: Ty<'tcx>,
) -> &'ll Value {
    // Pick the type whose alignment is requested; an unaligned access needs no
    // type at all and is always byte-aligned.
    let aligned_ty = match alignment {
        SimdAlign::Unaligned => return bx.const_i32(1),
        SimdAlign::Element => element_ty,
        SimdAlign::Vector => vector_ty,
    };

    let align_bytes = bx.align_of(aligned_ty).bytes();
    bx.const_i32(align_bytes as i32)
}

if name == sym::simd_masked_load {
// simd_masked_load(mask: <N x i{M}>, pointer: *_ T, values: <N x T>) -> <N x T>
// simd_masked_load<_, _, _, const ALIGN: SimdAlign>(mask: <N x i{M}>, pointer: *_ T, values: <N x T>) -> <N x T>
// * N: number of elements in the input vectors
// * T: type of the element to load
// * M: any integer width is supported, will be truncated to i1
// Loads contiguous elements from memory behind `pointer`, but only for
// those lanes whose `mask` bit is enabled.
// The memory addresses corresponding to the “off” lanes are not accessed.

let alignment = fn_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
.unwrap_leaf()
.to_simd_alignment();

// The element type of the "mask" argument must be a signed integer type of any width
let mask_ty = in_ty;
let (mask_len, mask_elem) = (in_len, in_elem);
Expand Down Expand Up @@ -1891,7 +1910,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);

// Alignment requested via `ALIGN`, must be a constant integer value:
let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);
let alignment = llvm_alignment(bx, alignment, values_ty, values_elem);

let llvm_pointer = bx.type_ptr();

Expand All @@ -1906,14 +1925,18 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
}

if name == sym::simd_masked_store {
// simd_masked_store(mask: <N x i{M}>, pointer: *mut T, values: <N x T>) -> ()
// simd_masked_store<_, _, _, const ALIGN: SimdAlign>(mask: <N x i{M}>, pointer: *mut T, values: <N x T>) -> ()
// * N: number of elements in the input vectors
// * T: type of the element to store
// * M: any integer width is supported, will be truncated to i1
// Stores contiguous elements to memory behind `pointer`, but only for
// those lanes whose `mask` bit is enabled.
// The memory addresses corresponding to the “off” lanes are not accessed.

let alignment = fn_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
.unwrap_leaf()
.to_simd_alignment();

// The element type of the "mask" argument must be a signed integer type of any width
let mask_ty = in_ty;
let (mask_len, mask_elem) = (in_len, in_elem);
Expand Down Expand Up @@ -1964,7 +1987,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);

// Alignment requested via `ALIGN`, must be a constant integer value:
let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);
let alignment = llvm_alignment(bx, alignment, values_ty, values_elem);

let llvm_pointer = bx.type_ptr();

Expand Down
57 changes: 50 additions & 7 deletions compiler/rustc_const_eval/src/interpret/intrinsics/simd.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
use either::Either;
use rustc_abi::Endian;
use rustc_abi::{BackendRepr, Endian};
use rustc_apfloat::{Float, Round};
use rustc_middle::mir::interpret::{InterpErrorKind, UndefinedBehaviorInfo};
use rustc_middle::ty::FloatTy;
use rustc_middle::mir::interpret::{InterpErrorKind, Pointer, UndefinedBehaviorInfo};
use rustc_middle::ty::{FloatTy, SimdAlign};
use rustc_middle::{bug, err_ub_format, mir, span_bug, throw_unsup_format, ty};
use rustc_span::{Symbol, sym};
use tracing::trace;

use super::{
ImmTy, InterpCx, InterpResult, Machine, OpTy, PlaceTy, Provenance, Scalar, Size, interp_ok,
throw_ub_format,
ImmTy, InterpCx, InterpResult, Machine, OpTy, PlaceTy, Provenance, Scalar, Size, TyAndLayout,
assert_matches, interp_ok, throw_ub_format,
};
use crate::interpret::Writeable;

Expand Down Expand Up @@ -658,6 +658,8 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
}
}
sym::simd_masked_load => {
let dest_layout = dest.layout;

let (mask, mask_len) = self.project_to_simd(&args[0])?;
let ptr = self.read_pointer(&args[1])?;
let (default, default_len) = self.project_to_simd(&args[2])?;
Expand All @@ -666,6 +668,14 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
assert_eq!(dest_len, mask_len);
assert_eq!(dest_len, default_len);

self.check_simd_ptr_alignment(
ptr,
dest_layout,
generic_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
.unwrap_leaf()
.to_simd_alignment(),
)?;

for i in 0..dest_len {
let mask = self.read_immediate(&self.project_index(&mask, i)?)?;
let default = self.read_immediate(&self.project_index(&default, i)?)?;
Expand All @@ -674,7 +684,8 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
let val = if simd_element_to_bool(mask)? {
// Size * u64 is implemented as always checked
let ptr = ptr.wrapping_offset(dest.layout.size * i, self);
let place = self.ptr_to_mplace(ptr, dest.layout);
// we have already checked the alignment requirements
let place = self.ptr_to_mplace_unaligned(ptr, dest.layout);
self.read_immediate(&place)?
} else {
default
Expand All @@ -689,14 +700,23 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {

assert_eq!(mask_len, vals_len);

self.check_simd_ptr_alignment(
ptr,
args[2].layout,
generic_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
.unwrap_leaf()
.to_simd_alignment(),
)?;

for i in 0..vals_len {
let mask = self.read_immediate(&self.project_index(&mask, i)?)?;
let val = self.read_immediate(&self.project_index(&vals, i)?)?;

if simd_element_to_bool(mask)? {
// Size * u64 is implemented as always checked
let ptr = ptr.wrapping_offset(val.layout.size * i, self);
let place = self.ptr_to_mplace(ptr, val.layout);
// we have already checked the alignment requirements
let place = self.ptr_to_mplace_unaligned(ptr, val.layout);
self.write_immediate(*val, &place)?
};
}
Expand Down Expand Up @@ -748,6 +768,29 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
FloatTy::F128 => unimplemented!("f16_f128"),
})
}

fn check_simd_ptr_alignment(
&self,
ptr: Pointer<Option<M::Provenance>>,
vector_layout: TyAndLayout<'tcx>,
alignment: SimdAlign,
) -> InterpResult<'tcx> {
assert_matches!(vector_layout.backend_repr, BackendRepr::SimdVector { .. });

match alignment {
ty::SimdAlign::Unaligned => {
// the pointer is supposed to be unaligned, so no check is required
interp_ok(())
}
ty::SimdAlign::Element => {
// take the alignment of the only field, which is an array and therefore has the same alignment
// as the element type.
let elem_align = vector_layout.field(self, 0).align.abi;
self.check_ptr_align(ptr, elem_align)
}
ty::SimdAlign::Vector => self.check_ptr_align(ptr, vector_layout.align.abi),
}
}
}

fn simd_bitmask_index(idx: u32, vec_len: u32, endianness: Endian) -> u32 {
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_hir_analysis/src/check/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -695,8 +695,8 @@ pub(crate) fn check_intrinsic_type(
(1, 0, vec![param(0), param(0), param(0)], param(0))
}
sym::simd_gather => (3, 0, vec![param(0), param(1), param(2)], param(0)),
sym::simd_masked_load => (3, 0, vec![param(0), param(1), param(2)], param(2)),
sym::simd_masked_store => (3, 0, vec![param(0), param(1), param(2)], tcx.types.unit),
sym::simd_masked_load => (3, 1, vec![param(0), param(1), param(2)], param(2)),
sym::simd_masked_store => (3, 1, vec![param(0), param(1), param(2)], tcx.types.unit),
sym::simd_scatter => (3, 0, vec![param(0), param(1), param(2)], tcx.types.unit),
sym::simd_insert | sym::simd_insert_dyn => {
(2, 0, vec![param(0), tcx.types.u32, param(1)], param(0))
Expand Down
24 changes: 24 additions & 0 deletions compiler/rustc_middle/src/ty/consts/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ pub enum AtomicOrdering {
SeqCst = 4,
}

/// An enum to represent the compiler-side view of `intrinsics::simd::SimdAlign`.
///
/// It describes the alignment that the pointer of a masked SIMD load/store is
/// required to have. A `ScalarInt` holding the discriminant is decoded into this
/// enum by `ScalarInt::to_simd_alignment`.
#[derive(Debug, Copy, Clone)]
pub enum SimdAlign {
    // These values must match `intrinsics::simd::SimdAlign`!
    /// No alignment requirements on the pointer.
    Unaligned = 0,
    /// The pointer must be aligned to the element type of the SIMD vector.
    Element = 1,
    /// The pointer must be aligned to the SIMD vector type.
    Vector = 2,
}

impl std::fmt::Debug for ConstInt {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Self { int, signed, is_ptr_sized_integral } = *self;
Expand Down Expand Up @@ -350,6 +359,21 @@ impl ScalarInt {
}
}

#[inline]
pub fn to_simd_alignment(self) -> SimdAlign {
use SimdAlign::*;
let val = self.to_u32();
if val == Unaligned as u32 {
Unaligned
} else if val == Element as u32 {
Element
} else if val == Vector as u32 {
Vector
} else {
panic!("not a valid simd alignment")
}
}

/// Converts the `ScalarInt` to `bool`.
/// Panics if the `size` of the `ScalarInt` is not equal to 1 byte.
/// Errors if it is not a valid `bool`.
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/ty/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ pub use self::closure::{
};
pub use self::consts::{
AnonConstKind, AtomicOrdering, Const, ConstInt, ConstKind, ConstToValTreeResult, Expr,
ExprKind, ScalarInt, UnevaluatedConst, ValTree, ValTreeKind, Value,
ExprKind, ScalarInt, SimdAlign, UnevaluatedConst, ValTree, ValTreeKind, Value,
};
pub use self::context::{
CtxtInterners, CurrentGcx, DeducedParamAttrs, Feed, FreeRegionInfo, GlobalCtxt, Lift, TyCtxt,
Expand Down
25 changes: 19 additions & 6 deletions library/core/src/intrinsics/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
//!
//! In this module, a "vector" is any `repr(simd)` type.

use crate::marker::ConstParamTy;

/// Inserts an element into a vector, returning the updated vector.
///
/// `T` must be a vector with element type `U`, and `idx` must be `const`.
Expand Down Expand Up @@ -377,6 +379,19 @@ pub unsafe fn simd_gather<T, U, V>(val: T, ptr: U, mask: V) -> T;
#[rustc_nounwind]
pub unsafe fn simd_scatter<T, U, V>(val: T, ptr: U, mask: V);

/// A type for alignment options for SIMD masked load/store intrinsics.
///
/// It is passed as the `ALIGN` const generic parameter of [`simd_masked_load`]
/// and [`simd_masked_store`] and states which alignment the caller guarantees
/// for the pointer argument.
#[derive(Debug, ConstParamTy, PartialEq, Eq)]
pub enum SimdAlign {
    // These values must match the compiler's `SimdAlign` defined in
    // `rustc_middle/src/ty/consts/int.rs`!
    /// No alignment requirements on the pointer.
    Unaligned = 0,
    /// The pointer must be aligned to the alignment of the SIMD vector's element type.
    Element = 1,
    /// The pointer must be aligned to the alignment of the SIMD vector type itself.
    Vector = 2,
}

/// Reads a vector of pointers.
///
/// `T` must be a vector.
Expand All @@ -392,13 +407,12 @@ pub unsafe fn simd_scatter<T, U, V>(val: T, ptr: U, mask: V);
/// `val`.
///
/// # Safety
/// Unmasked values in `T` must be readable as if by `<ptr>::read` (e.g. aligned to the element
/// type).
/// `ptr` must be aligned according to the `ALIGN` parameter, see [`SimdAlign`] for details.
///
/// `mask` must only contain `0` or `!0` values.
#[rustc_intrinsic]
#[rustc_nounwind]
pub unsafe fn simd_masked_load<V, U, T>(mask: V, ptr: U, val: T) -> T;
pub unsafe fn simd_masked_load<V, U, T, const ALIGN: SimdAlign>(mask: V, ptr: U, val: T) -> T;

/// Writes to a vector of pointers.
///
Expand All @@ -414,13 +428,12 @@ pub unsafe fn simd_masked_load<V, U, T>(mask: V, ptr: U, val: T) -> T;
/// Otherwise if the corresponding value in `mask` is `0`, do nothing.
///
/// # Safety
/// Unmasked values in `T` must be writeable as if by `<ptr>::write` (e.g. aligned to the element
/// type).
/// `ptr` must be aligned according to the `ALIGN` parameter, see [`SimdAlign`] for details.
///
/// `mask` must only contain `0` or `!0` values.
#[rustc_intrinsic]
#[rustc_nounwind]
pub unsafe fn simd_masked_store<V, U, T>(mask: V, ptr: U, val: T);
pub unsafe fn simd_masked_store<V, U, T, const ALIGN: SimdAlign>(mask: V, ptr: U, val: T);

/// Adds two simd vectors elementwise, with saturation.
///
Expand Down
Loading
Loading