Skip to content

Commit

Permalink
add support for missing SIMD float intrinsics
Browse files Browse the repository at this point in the history
  • Loading branch information
RalfJung committed Mar 23, 2024
1 parent 51ff5cd commit 9509f21
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 7 deletions.
44 changes: 37 additions & 7 deletions src/shims/intrinsics/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
| "round"
| "trunc"
| "fsqrt"
| "fsin"
| "fcos"
| "fexp"
| "fexp2"
| "flog"
| "flog2"
| "flog10"
| "ctlz"
| "cttz"
| "bswap"
Expand All @@ -45,17 +52,16 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
assert_eq!(dest_len, op_len);

#[derive(Copy, Clone)]
enum Op {
enum Op<'a> {
MirOp(mir::UnOp),
Abs,
Sqrt,
Round(rustc_apfloat::Round),
Numeric(Symbol),
HostOp(&'a str),
}
let which = match intrinsic_name {
"neg" => Op::MirOp(mir::UnOp::Neg),
"fabs" => Op::Abs,
"fsqrt" => Op::Sqrt,
"ceil" => Op::Round(rustc_apfloat::Round::TowardPositive),
"floor" => Op::Round(rustc_apfloat::Round::TowardNegative),
"round" => Op::Round(rustc_apfloat::Round::NearestTiesToAway),
Expand All @@ -64,7 +70,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
"cttz" => Op::Numeric(sym::cttz),
"bswap" => Op::Numeric(sym::bswap),
"bitreverse" => Op::Numeric(sym::bitreverse),
_ => unreachable!(),
_ => Op::HostOp(intrinsic_name),
};

for i in 0..dest_len {
Expand All @@ -89,7 +95,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
FloatTy::F128 => unimplemented!("f16_f128"),
}
}
Op::Sqrt => {
Op::HostOp(host_op) => {
let ty::Float(float_ty) = op.layout.ty.kind() else {
span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
};
Expand All @@ -98,13 +104,37 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
FloatTy::F16 => unimplemented!("f16_f128"),
FloatTy::F32 => {
let f = op.to_scalar().to_f32()?;
let res = f.to_host().sqrt().to_soft();
let f_host = f.to_host();
let res = match host_op {
"fsqrt" => f_host.sqrt(),
"fsin" => f_host.sin(),
"fcos" => f_host.cos(),
"fexp" => f_host.exp(),
"fexp2" => f_host.exp2(),
"flog" => f_host.ln(),
"flog2" => f_host.log2(),
"flog10" => f_host.log10(),
_ => bug!(),
};
let res = res.to_soft();
let res = this.adjust_nan(res, &[f]);
Scalar::from(res)
}
FloatTy::F64 => {
let f = op.to_scalar().to_f64()?;
let res = f.to_host().sqrt().to_soft();
let f_host = f.to_host();
let res = match host_op {
"fsqrt" => f_host.sqrt(),
"fsin" => f_host.sin(),
"fcos" => f_host.cos(),
"fexp" => f_host.exp(),
"fexp2" => f_host.exp2(),
"flog" => f_host.ln(),
"flog2" => f_host.log2(),
"flog10" => f_host.log10(),
_ => bug!(),
};
let res = res.to_soft();
let res = this.adjust_nan(res, &[f]);
Scalar::from(res)
}
Expand Down
18 changes: 18 additions & 0 deletions tests/pass/portable-simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,23 @@ fn simd_intrinsics() {
}
}

/// Smoke test for the element-wise float SIMD intrinsics.
///
/// Every intrinsic is invoked once on the same `f32x4` input and its result is
/// thrown away: the point is only to ensure the intrinsics can be called, not
/// to check the values they produce.
fn simd_float_intrinsics() {
    use intrinsics::*;

    // One shared input vector with all four lanes set to 10.0.
    let input = f32x4::splat(10.0);

    // NOTE(review): presumably sound because each intrinsic is handed a float
    // vector; confirm against the intrinsics' documented requirements.
    unsafe {
        // Square root.
        let _ = simd_fsqrt(input);
        // Trigonometric functions.
        let _ = simd_fsin(input);
        let _ = simd_fcos(input);
        // Exponentials.
        let _ = simd_fexp(input);
        let _ = simd_fexp2(input);
        // Logarithms.
        let _ = simd_flog(input);
        let _ = simd_flog2(input);
        let _ = simd_flog10(input);
    }
}

fn simd_masked_loadstore() {
// The buffer is deliberately too short, so reading the last element would be UB.
let buf = [3i32; 3];
Expand Down Expand Up @@ -559,5 +576,6 @@ fn main() {
simd_gather_scatter();
simd_round();
simd_intrinsics();
simd_float_intrinsics();
simd_masked_loadstore();
}

0 comments on commit 9509f21

Please sign in to comment.