Skip to content

Commit

Permalink
Auto merge of #121223 - RalfJung:simd-intrinsics, r=Amanieu
Browse files Browse the repository at this point in the history
intrinsics::simd: add missing functions, avoid UB-triggering fast-math

Turns out stdarch declares a bunch more SIMD intrinsics that are still missing from libcore.
I hope I got the docs and in particular the safety requirements right for these "unordered" and "nanless" intrinsics.

Many of these are unused even in stdarch, but they are implemented in the codegen backend, so we may as well list them here.

r? `@Amanieu`
Cc `@calebzulawski` `@workingjubilee`
  • Loading branch information
bors committed Feb 21, 2024
2 parents f8131a4 + 64efe80 commit 975c7bf
Show file tree
Hide file tree
Showing 11 changed files with 97 additions and 45 deletions.
4 changes: 2 additions & 2 deletions compiler/rustc_codegen_cranelift/src/intrinsics/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,7 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
simd_reduce(fx, v, None, ret, &|fx, _ty, a, b| fx.bcx.ins().bxor(a, b));
}

sym::simd_reduce_min | sym::simd_reduce_min_nanless => {
sym::simd_reduce_min => {
intrinsic_args!(fx, args => (v); intrinsic);

if !v.layout().ty.is_simd() {
Expand All @@ -762,7 +762,7 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
});
}

sym::simd_reduce_max | sym::simd_reduce_max_nanless => {
sym::simd_reduce_max => {
intrinsic_args!(fx, args => (v); intrinsic);

if !v.layout().ty.is_simd() {
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_codegen_gcc/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1752,7 +1752,7 @@ impl<'a, 'gcc, 'tcx> Builder<'a, 'gcc, 'tcx> {
self.vector_reduce(src, |a, b, context| context.new_binary_op(None, op, a.get_type(), a, b))
}

pub fn vector_reduce_fadd_fast(&mut self, _acc: RValue<'gcc>, _src: RValue<'gcc>) -> RValue<'gcc> {
pub fn vector_reduce_fadd_reassoc(&mut self, _acc: RValue<'gcc>, _src: RValue<'gcc>) -> RValue<'gcc> {
unimplemented!();
}

Expand All @@ -1772,7 +1772,7 @@ impl<'a, 'gcc, 'tcx> Builder<'a, 'gcc, 'tcx> {
unimplemented!();
}

pub fn vector_reduce_fmul_fast(&mut self, _acc: RValue<'gcc>, _src: RValue<'gcc>) -> RValue<'gcc> {
pub fn vector_reduce_fmul_reassoc(&mut self, _acc: RValue<'gcc>, _src: RValue<'gcc>) -> RValue<'gcc> {
unimplemented!();
}

Expand Down
7 changes: 2 additions & 5 deletions compiler/rustc_codegen_gcc/src/intrinsic/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -989,14 +989,14 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(

arith_red!(
simd_reduce_add_unordered: BinaryOp::Plus,
vector_reduce_fadd_fast,
vector_reduce_fadd_reassoc,
false,
add,
0.0 // TODO: Use this argument.
);
arith_red!(
simd_reduce_mul_unordered: BinaryOp::Mult,
vector_reduce_fmul_fast,
vector_reduce_fmul_reassoc,
false,
mul,
1.0
Expand Down Expand Up @@ -1041,9 +1041,6 @@ pub fn generic_simd_intrinsic<'a, 'gcc, 'tcx>(

minmax_red!(simd_reduce_min: vector_reduce_min, vector_reduce_fmin);
minmax_red!(simd_reduce_max: vector_reduce_max, vector_reduce_fmax);
// TODO(sadlerap): revisit these intrinsics to generate more optimal reductions
minmax_red!(simd_reduce_min_nanless: vector_reduce_min, vector_reduce_fmin);
minmax_red!(simd_reduce_max_nanless: vector_reduce_max, vector_reduce_fmax);

macro_rules! bitwise_red {
($name:ident : $op:expr, $boolean:expr) => {
Expand Down
24 changes: 4 additions & 20 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1367,17 +1367,17 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
pub fn vector_reduce_fmul(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
unsafe { llvm::LLVMRustBuildVectorReduceFMul(self.llbuilder, acc, src) }
}
pub fn vector_reduce_fadd_algebraic(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
pub fn vector_reduce_fadd_reassoc(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
unsafe {
let instr = llvm::LLVMRustBuildVectorReduceFAdd(self.llbuilder, acc, src);
llvm::LLVMRustSetAlgebraicMath(instr);
llvm::LLVMRustSetAllowReassoc(instr);
instr
}
}
pub fn vector_reduce_fmul_algebraic(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
pub fn vector_reduce_fmul_reassoc(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
unsafe {
let instr = llvm::LLVMRustBuildVectorReduceFMul(self.llbuilder, acc, src);
llvm::LLVMRustSetAlgebraicMath(instr);
llvm::LLVMRustSetAllowReassoc(instr);
instr
}
}
Expand Down Expand Up @@ -1406,22 +1406,6 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
llvm::LLVMRustBuildVectorReduceFMax(self.llbuilder, src, /*NoNaNs:*/ false)
}
}
pub fn vector_reduce_fmin_fast(&mut self, src: &'ll Value) -> &'ll Value {
unsafe {
let instr =
llvm::LLVMRustBuildVectorReduceFMin(self.llbuilder, src, /*NoNaNs:*/ true);
llvm::LLVMRustSetFastMath(instr);
instr
}
}
pub fn vector_reduce_fmax_fast(&mut self, src: &'ll Value) -> &'ll Value {
unsafe {
let instr =
llvm::LLVMRustBuildVectorReduceFMax(self.llbuilder, src, /*NoNaNs:*/ true);
llvm::LLVMRustSetFastMath(instr);
instr
}
}
pub fn vector_reduce_min(&mut self, src: &'ll Value, is_signed: bool) -> &'ll Value {
unsafe { llvm::LLVMRustBuildVectorReduceMin(self.llbuilder, src, is_signed) }
}
Expand Down
7 changes: 2 additions & 5 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1880,14 +1880,14 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
arith_red!(simd_reduce_mul_ordered: vector_reduce_mul, vector_reduce_fmul, true, mul, 1.0);
arith_red!(
simd_reduce_add_unordered: vector_reduce_add,
vector_reduce_fadd_algebraic,
vector_reduce_fadd_reassoc,
false,
add,
0.0
);
arith_red!(
simd_reduce_mul_unordered: vector_reduce_mul,
vector_reduce_fmul_algebraic,
vector_reduce_fmul_reassoc,
false,
mul,
1.0
Expand Down Expand Up @@ -1920,9 +1920,6 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
minmax_red!(simd_reduce_min: vector_reduce_min, vector_reduce_fmin);
minmax_red!(simd_reduce_max: vector_reduce_max, vector_reduce_fmax);

minmax_red!(simd_reduce_min_nanless: vector_reduce_min, vector_reduce_fmin_fast);
minmax_red!(simd_reduce_max_nanless: vector_reduce_max, vector_reduce_fmax_fast);

macro_rules! bitwise_red {
($name:ident : $red:ident, $boolean:expr) => {
if name == sym::$name {
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_codegen_llvm/src/llvm/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1619,6 +1619,7 @@ extern "C" {

pub fn LLVMRustSetFastMath(Instr: &Value);
pub fn LLVMRustSetAlgebraicMath(Instr: &Value);
pub fn LLVMRustSetAllowReassoc(Instr: &Value);

// Miscellaneous instructions
pub fn LLVMRustGetInstrProfIncrementIntrinsic(M: &Module) -> &Value;
Expand Down
4 changes: 1 addition & 3 deletions compiler/rustc_hir_analysis/src/check/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -606,9 +606,7 @@ pub fn check_platform_intrinsic_type(
| sym::simd_reduce_or
| sym::simd_reduce_xor
| sym::simd_reduce_min
| sym::simd_reduce_max
| sym::simd_reduce_min_nanless
| sym::simd_reduce_max_nanless => (2, 0, vec![param(0)], param(1)),
| sym::simd_reduce_max => (2, 0, vec![param(0)], param(1)),
sym::simd_shuffle => (3, 0, vec![param(0), param(0), param(1)], param(2)),
sym::simd_shuffle_generic => (2, 1, vec![param(0), param(0)], param(1)),
_ => {
Expand Down
14 changes: 14 additions & 0 deletions compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,20 @@ extern "C" void LLVMRustSetAlgebraicMath(LLVMValueRef V) {
}
}

// Enable the reassoc fast-math flag, allowing transformations that pretend
// floating-point addition and multiplication are associative.
//
// Note that this does NOT enable any flags which can cause a floating-point operation on
// well-defined inputs to return poison, and therefore this function can be used to build
// safe Rust intrinsics (such as fadd_algebraic).
//
// https://llvm.org/docs/LangRef.html#fast-math-flags
extern "C" void LLVMRustSetAllowReassoc(LLVMValueRef V) {
if (auto I = dyn_cast<Instruction>(unwrap<Value>(V))) {
I->setHasAllowReassoc(true);
}
}

extern "C" LLVMValueRef
LLVMRustBuildAtomicLoad(LLVMBuilderRef B, LLVMTypeRef Ty, LLVMValueRef Source,
const char *Name, LLVMAtomicOrdering Order) {
Expand Down
2 changes: 0 additions & 2 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1553,9 +1553,7 @@ symbols! {
simd_reduce_and,
simd_reduce_any,
simd_reduce_max,
simd_reduce_max_nanless,
simd_reduce_min,
simd_reduce_min_nanless,
simd_reduce_mul_ordered,
simd_reduce_mul_unordered,
simd_reduce_or,
Expand Down
69 changes: 69 additions & 0 deletions library/core/src/intrinsics/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,24 @@
//! In this module, a "vector" is any `repr(simd)` type.
extern "platform-intrinsic" {
/// Insert an element into a vector, returning the updated vector.
///
/// `T` must be a vector with element type `U`.
///
/// # Safety
///
/// `idx` must be in-bounds of the vector.
pub fn simd_insert<T, U>(x: T, idx: u32, val: U) -> T;

/// Extract an element from a vector.
///
/// `T` must be a vector with element type `U`.
///
/// # Safety
///
/// `idx` must be in-bounds of the vector.
pub fn simd_extract<T, U>(x: T, idx: u32) -> U;

/// Add two simd vectors elementwise.
///
/// `T` must be a vector of integer or floating point primitive types.
Expand Down Expand Up @@ -317,6 +335,14 @@ extern "platform-intrinsic" {
/// Starting with the value `y`, add the elements of `x` and accumulate.
pub fn simd_reduce_add_ordered<T, U>(x: T, y: U) -> U;

/// Add elements within a vector in arbitrary order. May also be re-associated with
/// unordered additions on the inputs/outputs.
///
/// `T` must be a vector of integer or floating-point primitive types.
///
/// `U` must be the element type of `T`.
pub fn simd_reduce_add_unordered<T, U>(x: T) -> U;

/// Multiply elements within a vector from left to right.
///
/// `T` must be a vector of integer or floating-point primitive types.
Expand All @@ -326,6 +352,14 @@ extern "platform-intrinsic" {
/// Starting with the value `y`, multiply the elements of `x` and accumulate.
pub fn simd_reduce_mul_ordered<T, U>(x: T, y: U) -> U;

/// Add elements within a vector in arbitrary order. May also be re-associated with
/// unordered additions on the inputs/outputs.
///
/// `T` must be a vector of integer or floating-point primitive types.
///
/// `U` must be the element type of `T`.
pub fn simd_reduce_mul_unordered<T, U>(x: T) -> U;

/// Check if all mask values are true.
///
/// `T` must be a vector of integer primitive types.
Expand Down Expand Up @@ -518,4 +552,39 @@ extern "platform-intrinsic" {
///
/// `T` must be a vector of floats.
pub fn simd_fma<T>(x: T, y: T, z: T) -> T;

// Computes the sine of each element.
///
/// `T` must be a vector of floats.
pub fn simd_fsin<T>(a: T) -> T;

// Computes the cosine of each element.
///
/// `T` must be a vector of floats.
pub fn simd_fcos<T>(a: T) -> T;

// Computes the exponential function of each element.
///
/// `T` must be a vector of floats.
pub fn simd_fexp<T>(a: T) -> T;

// Computes 2 raised to the power of each element.
///
/// `T` must be a vector of floats.
pub fn simd_fexp2<T>(a: T) -> T;

// Computes the base 10 logarithm of each element.
///
/// `T` must be a vector of floats.
pub fn simd_flog10<T>(a: T) -> T;

// Computes the base 2 logarithm of each element.
///
/// `T` must be a vector of floats.
pub fn simd_flog2<T>(a: T) -> T;

// Computes the natural logarithm of each element.
///
/// `T` must be a vector of floats.
pub fn simd_flog<T>(a: T) -> T;
}
6 changes: 0 additions & 6 deletions tests/ui/simd/intrinsic/generic-reduction-pass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@ extern "platform-intrinsic" {
fn simd_reduce_mul_ordered<T, U>(x: T, acc: U) -> U;
fn simd_reduce_min<T, U>(x: T) -> U;
fn simd_reduce_max<T, U>(x: T) -> U;
fn simd_reduce_min_nanless<T, U>(x: T) -> U;
fn simd_reduce_max_nanless<T, U>(x: T) -> U;
fn simd_reduce_and<T, U>(x: T) -> U;
fn simd_reduce_or<T, U>(x: T) -> U;
fn simd_reduce_xor<T, U>(x: T) -> U;
Expand Down Expand Up @@ -127,10 +125,6 @@ fn main() {
assert_eq!(r, -2_f32);
let r: f32 = simd_reduce_max(x);
assert_eq!(r, 4_f32);
let r: f32 = simd_reduce_min_nanless(x);
assert_eq!(r, -2_f32);
let r: f32 = simd_reduce_max_nanless(x);
assert_eq!(r, 4_f32);
}

unsafe {
Expand Down

0 comments on commit 975c7bf

Please sign in to comment.