Skip to content

Commit

Permalink
miri: improve support for f16 and f128
Browse files Browse the repository at this point in the history
Rounding intrinsics are now implemented for `f16` and `f128` and tests for `is_infinite`, NaN, `abs`, `copysign`, `min`, `max`, rounding, `*_fast` and `*_algebraic` have been added.
  • Loading branch information
eduardosm committed Oct 17, 2024
1 parent f7ccac9 commit d7e91ba
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 9 deletions.
30 changes: 30 additions & 0 deletions src/tools/miri/src/intrinsics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,21 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
this.write_scalar(Scalar::from_bool(branch), dest)?;
}

"floorf16" | "ceilf16" | "truncf16" | "roundf16" | "rintf16" => {
let [f] = check_arg_count(args)?;
let f = this.read_scalar(f)?.to_f16()?;
let mode = match intrinsic_name {
"floorf16" => Round::TowardNegative,
"ceilf16" => Round::TowardPositive,
"truncf16" => Round::TowardZero,
"roundf16" => Round::NearestTiesToAway,
"rintf16" => Round::NearestTiesToEven,
_ => bug!(),
};
let res = f.round_to_integral(mode).value;
let res = this.adjust_nan(res, &[f]);
this.write_scalar(res, dest)?;
}
"floorf32" | "ceilf32" | "truncf32" | "roundf32" | "rintf32" => {
let [f] = check_arg_count(args)?;
let f = this.read_scalar(f)?.to_f32()?;
Expand Down Expand Up @@ -175,6 +190,21 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let res = this.adjust_nan(res, &[f]);
this.write_scalar(res, dest)?;
}
"floorf128" | "ceilf128" | "truncf128" | "roundf128" | "rintf128" => {
let [f] = check_arg_count(args)?;
let f = this.read_scalar(f)?.to_f128()?;
let mode = match intrinsic_name {
"floorf128" => Round::TowardNegative,
"ceilf128" => Round::TowardPositive,
"truncf128" => Round::TowardZero,
"roundf128" => Round::NearestTiesToAway,
"rintf128" => Round::NearestTiesToEven,
_ => bug!(),
};
let res = f.round_to_integral(mode).value;
let res = this.adjust_nan(res, &[f]);
this.write_scalar(res, dest)?;
}

#[rustfmt::skip]
| "sinf32"
Expand Down
160 changes: 151 additions & 9 deletions src/tools/miri/tests/pass/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,18 @@ fn basic() {
assert_eq(-{ 5.0_f128 }, -5.0_f128);

// infinities, NaN
// FIXME(f16_f128): add when constants and `is_infinite` are available
assert!((5.0_f16 / 0.0).is_infinite());
assert_ne!({ 5.0_f16 / 0.0 }, { -5.0_f16 / 0.0 });
assert!((5.0_f32 / 0.0).is_infinite());
assert_ne!({ 5.0_f32 / 0.0 }, { -5.0_f32 / 0.0 });
assert!((5.0_f64 / 0.0).is_infinite());
assert_ne!({ 5.0_f64 / 0.0 }, { 5.0_f64 / -0.0 });
assert!((5.0_f128 / 0.0).is_infinite());
assert_ne!({ 5.0_f128 / 0.0 }, { 5.0_f128 / -0.0 });
assert_ne!(f16::NAN, f16::NAN);
assert_ne!(f32::NAN, f32::NAN);
assert_ne!(f64::NAN, f64::NAN);
assert_ne!(f128::NAN, f128::NAN);

// negative zero
let posz = 0.0f16;
Expand Down Expand Up @@ -215,9 +220,14 @@ fn basic() {
assert!((black_box(-1.0f128) % 1.0).is_sign_negative());
assert!((black_box(-1.0f128) % -1.0).is_sign_negative());

// FIXME(f16_f128): add when `abs` is available
assert_eq!((-1.0f16).abs(), 1.0f16);
assert_eq!(34.2f16.abs(), 34.2f16);
assert_eq!((-1.0f32).abs(), 1.0f32);
assert_eq!(34.2f32.abs(), 34.2f32);
assert_eq!((-1.0f64).abs(), 1.0f64);
assert_eq!(34.2f64.abs(), 34.2f64);
assert_eq!((-1.0f128).abs(), 1.0f128);
assert_eq!(34.2f128.abs(), 34.2f128);
}

/// Test casts from floats to ints and back
Expand Down Expand Up @@ -654,6 +664,14 @@ fn casts() {
}

fn ops() {
// f16 min/max
assert_eq((1.0_f16).max(-1.0), 1.0);
assert_eq((1.0_f16).min(-1.0), -1.0);
assert_eq(f16::NAN.min(9.0), 9.0);
assert_eq(f16::NAN.max(-9.0), -9.0);
assert_eq((9.0_f16).min(f16::NAN), 9.0);
assert_eq((-9.0_f16).max(f16::NAN), -9.0);

// f32 min/max
assert_eq((1.0 as f32).max(-1.0), 1.0);
assert_eq((1.0 as f32).min(-1.0), -1.0);
Expand All @@ -670,6 +688,21 @@ fn ops() {
assert_eq((9.0 as f64).min(f64::NAN), 9.0);
assert_eq((-9.0 as f64).max(f64::NAN), -9.0);

// f128 min/max
assert_eq((1.0_f128).max(-1.0), 1.0);
assert_eq((1.0_f128).min(-1.0), -1.0);
assert_eq(f128::NAN.min(9.0), 9.0);
assert_eq(f128::NAN.max(-9.0), -9.0);
assert_eq((9.0_f128).min(f128::NAN), 9.0);
assert_eq((-9.0_f128).max(f128::NAN), -9.0);

// f16 copysign
assert_eq(3.5_f16.copysign(0.42), 3.5_f16);
assert_eq(3.5_f16.copysign(-0.42), -3.5_f16);
assert_eq((-3.5_f16).copysign(0.42), 3.5_f16);
assert_eq((-3.5_f16).copysign(-0.42), -3.5_f16);
assert!(f16::NAN.copysign(1.0).is_nan());

// f32 copysign
assert_eq(3.5_f32.copysign(0.42), 3.5_f32);
assert_eq(3.5_f32.copysign(-0.42), -3.5_f32);
Expand All @@ -683,6 +716,13 @@ fn ops() {
assert_eq((-3.5_f64).copysign(0.42), 3.5_f64);
assert_eq((-3.5_f64).copysign(-0.42), -3.5_f64);
assert!(f64::NAN.copysign(1.0).is_nan());

// f128 copysign
assert_eq(3.5_f128.copysign(0.42), 3.5_f128);
assert_eq(3.5_f128.copysign(-0.42), -3.5_f128);
assert_eq((-3.5_f128).copysign(0.42), 3.5_f128);
assert_eq((-3.5_f128).copysign(-0.42), -3.5_f128);
assert!(f128::NAN.copysign(1.0).is_nan());
}

/// Tests taken from rustc test suite.
Expand Down Expand Up @@ -807,6 +847,18 @@ fn nan_casts() {

fn rounding() {
// Test cases taken from the library's tests for this feature
// f16
assert_eq(2.5f16.round_ties_even(), 2.0f16);
assert_eq(1.0f16.round_ties_even(), 1.0f16);
assert_eq(1.3f16.round_ties_even(), 1.0f16);
assert_eq(1.5f16.round_ties_even(), 2.0f16);
assert_eq(1.7f16.round_ties_even(), 2.0f16);
assert_eq(0.0f16.round_ties_even(), 0.0f16);
assert_eq((-0.0f16).round_ties_even(), -0.0f16);
assert_eq((-1.0f16).round_ties_even(), -1.0f16);
assert_eq((-1.3f16).round_ties_even(), -1.0f16);
assert_eq((-1.5f16).round_ties_even(), -2.0f16);
assert_eq((-1.7f16).round_ties_even(), -2.0f16);
// f32
assert_eq(2.5f32.round_ties_even(), 2.0f32);
assert_eq(1.0f32.round_ties_even(), 1.0f32);
Expand All @@ -831,23 +883,59 @@ fn rounding() {
assert_eq((-1.3f64).round_ties_even(), -1.0f64);
assert_eq((-1.5f64).round_ties_even(), -2.0f64);
assert_eq((-1.7f64).round_ties_even(), -2.0f64);

// f128
assert_eq(2.5f128.round_ties_even(), 2.0f128);
assert_eq(1.0f128.round_ties_even(), 1.0f128);
assert_eq(1.3f128.round_ties_even(), 1.0f128);
assert_eq(1.5f128.round_ties_even(), 2.0f128);
assert_eq(1.7f128.round_ties_even(), 2.0f128);
assert_eq(0.0f128.round_ties_even(), 0.0f128);
assert_eq((-0.0f128).round_ties_even(), -0.0f128);
assert_eq((-1.0f128).round_ties_even(), -1.0f128);
assert_eq((-1.3f128).round_ties_even(), -1.0f128);
assert_eq((-1.5f128).round_ties_even(), -2.0f128);
assert_eq((-1.7f128).round_ties_even(), -2.0f128);

assert_eq!(3.8f16.floor(), 3.0f16);
assert_eq!((-1.1f16).floor(), -2.0f16);
assert_eq!(3.8f32.floor(), 3.0f32);
assert_eq!((-1.1f32).floor(), -2.0f32);
assert_eq!(3.8f64.floor(), 3.0f64);
assert_eq!((-1.1f64).floor(), -2.0f64);
assert_eq!(3.8f128.floor(), 3.0f128);
assert_eq!((-1.1f128).floor(), -2.0f128);

assert_eq!(3.8f16.ceil(), 4.0f16);
assert_eq!((-2.3f16).ceil(), -2.0f16);
assert_eq!(3.8f32.ceil(), 4.0f32);
assert_eq!((-2.3f32).ceil(), -2.0f32);
assert_eq!(3.8f64.ceil(), 4.0f64);
assert_eq!((-2.3f64).ceil(), -2.0f64);
assert_eq!(3.8f128.ceil(), 4.0f128);
assert_eq!((-2.3f128).ceil(), -2.0f128);

assert_eq!(0.1f16.trunc(), 0.0f16);
assert_eq!((-0.1f16).trunc(), 0.0f16);
assert_eq!(0.1f32.trunc(), 0.0f32);
assert_eq!((-0.1f32).trunc(), 0.0f32);
assert_eq!(0.1f64.trunc(), 0.0f64);
assert_eq!((-0.1f64).trunc(), 0.0f64);
assert_eq!(0.1f128.trunc(), 0.0f128);
assert_eq!((-0.1f128).trunc(), 0.0f128);

assert_eq!(3.3_f16.round(), 3.0);
assert_eq!(2.5_f16.round(), 3.0);
assert_eq!(3.3_f32.round(), 3.0);
assert_eq!(2.5_f32.round(), 3.0);
assert_eq!(3.9_f64.round(), 4.0);
assert_eq!(2.5_f64.round(), 3.0);
assert_eq!(3.9_f128.round(), 4.0);
assert_eq!(2.5_f128.round(), 3.0);
}

fn mul_add() {
// FIXME(f16_f128): add when supported

assert_eq!(3.0f32.mul_add(2.0f32, 5.0f32), 11.0);
assert_eq!(0.0f32.mul_add(-2.0, f32::consts::E), f32::consts::E);
assert_eq!(3.0f64.mul_add(2.0, 5.0), 11.0);
Expand Down Expand Up @@ -983,7 +1071,7 @@ fn test_fast() {
use std::intrinsics::{fadd_fast, fdiv_fast, fmul_fast, frem_fast, fsub_fast};

#[inline(never)]
pub fn test_operations_f64(a: f64, b: f64) {
pub fn test_operations_f16(a: f16, b: f16) {
// make sure they all map to the correct operation
unsafe {
assert_eq!(fadd_fast(a, b), a + b);
Expand All @@ -1006,10 +1094,38 @@ fn test_fast() {
}
}

test_operations_f64(1., 2.);
test_operations_f64(10., 5.);
#[inline(never)]
pub fn test_operations_f64(a: f64, b: f64) {
// make sure they all map to the correct operation
unsafe {
assert_eq!(fadd_fast(a, b), a + b);
assert_eq!(fsub_fast(a, b), a - b);
assert_eq!(fmul_fast(a, b), a * b);
assert_eq!(fdiv_fast(a, b), a / b);
assert_eq!(frem_fast(a, b), a % b);
}
}

#[inline(never)]
pub fn test_operations_f128(a: f128, b: f128) {
// make sure they all map to the correct operation
unsafe {
assert_eq!(fadd_fast(a, b), a + b);
assert_eq!(fsub_fast(a, b), a - b);
assert_eq!(fmul_fast(a, b), a * b);
assert_eq!(fdiv_fast(a, b), a / b);
assert_eq!(frem_fast(a, b), a % b);
}
}

test_operations_f16(11., 2.);
test_operations_f16(10., 15.);
test_operations_f32(11., 2.);
test_operations_f32(10., 15.);
test_operations_f64(1., 2.);
test_operations_f64(10., 5.);
test_operations_f128(1., 2.);
test_operations_f128(10., 5.);
}

fn test_algebraic() {
Expand All @@ -1018,7 +1134,7 @@ fn test_algebraic() {
};

#[inline(never)]
pub fn test_operations_f64(a: f64, b: f64) {
pub fn test_operations_f16(a: f16, b: f16) {
// make sure they all map to the correct operation
assert_eq!(fadd_algebraic(a, b), a + b);
assert_eq!(fsub_algebraic(a, b), a - b);
Expand All @@ -1037,15 +1153,41 @@ fn test_algebraic() {
assert_eq!(frem_algebraic(a, b), a % b);
}

test_operations_f64(1., 2.);
test_operations_f64(10., 5.);
#[inline(never)]
pub fn test_operations_f64(a: f64, b: f64) {
// make sure they all map to the correct operation
assert_eq!(fadd_algebraic(a, b), a + b);
assert_eq!(fsub_algebraic(a, b), a - b);
assert_eq!(fmul_algebraic(a, b), a * b);
assert_eq!(fdiv_algebraic(a, b), a / b);
assert_eq!(frem_algebraic(a, b), a % b);
}

#[inline(never)]
pub fn test_operations_f128(a: f128, b: f128) {
// make sure they all map to the correct operation
assert_eq!(fadd_algebraic(a, b), a + b);
assert_eq!(fsub_algebraic(a, b), a - b);
assert_eq!(fmul_algebraic(a, b), a * b);
assert_eq!(fdiv_algebraic(a, b), a / b);
assert_eq!(frem_algebraic(a, b), a % b);
}

test_operations_f16(11., 2.);
test_operations_f16(10., 15.);
test_operations_f32(11., 2.);
test_operations_f32(10., 15.);
test_operations_f64(1., 2.);
test_operations_f64(10., 5.);
test_operations_f128(1., 2.);
test_operations_f128(10., 5.);
}

fn test_fmuladd() {
use std::intrinsics::{fmuladdf32, fmuladdf64};

// FIXME(f16_f128): add when supported

#[inline(never)]
pub fn test_operations_f32(a: f32, b: f32, c: f32) {
assert_approx_eq!(unsafe { fmuladdf32(a, b, c) }, a * b + c);
Expand Down

0 comments on commit d7e91ba

Please sign in to comment.