Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 19 additions & 17 deletions src/intrinsics/math.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,26 @@
use rand::Rng;
use rustc_apfloat::{self, Float, Round};
use rustc_apfloat::{self, Float, FloatConvert, Round};
use rustc_middle::mir;
use rustc_middle::ty::{self, FloatTy};

use self::helpers::{ToHost, ToSoft};
use super::check_intrinsic_arg_count;
use crate::*;

fn sqrt<'tcx, F: Float + FloatConvert<F> + Into<Scalar>>(
this: &mut MiriInterpCx<'tcx>,
args: &[OpTy<'tcx>],
dest: &MPlaceTy<'tcx>,
) -> InterpResult<'tcx> {
let [f] = check_intrinsic_arg_count(args)?;
let f = this.read_scalar(f)?;
let f: F = f.to_float()?;
// Sqrt is specified to be fully precise.
let res = math::sqrt(f);
let res = this.adjust_nan(res, &[f]);
this.write_scalar(res, dest)
}

impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
fn emulate_math_intrinsic(
Expand All @@ -20,22 +34,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {

match intrinsic_name {
// Operations we can do with soft-floats.
"sqrtf32" => {
let [f] = check_intrinsic_arg_count(args)?;
let f = this.read_scalar(f)?.to_f32()?;
// Sqrt is specified to be fully precise.
let res = math::sqrt(f);
let res = this.adjust_nan(res, &[f]);
this.write_scalar(res, dest)?;
}
"sqrtf64" => {
let [f] = check_intrinsic_arg_count(args)?;
let f = this.read_scalar(f)?.to_f64()?;
// Sqrt is specified to be fully precise.
let res = math::sqrt(f);
let res = this.adjust_nan(res, &[f]);
this.write_scalar(res, dest)?;
}
"sqrtf16" => sqrt::<rustc_apfloat::ieee::Half>(this, args, dest)?,
"sqrtf32" => sqrt::<rustc_apfloat::ieee::Single>(this, args, dest)?,
"sqrtf64" => sqrt::<rustc_apfloat::ieee::Double>(this, args, dest)?,
"sqrtf128" => sqrt::<rustc_apfloat::ieee::Quad>(this, args, dest)?,

"fmaf32" => {
let [a, b, c] = check_intrinsic_arg_count(args)?;
Expand Down
12 changes: 6 additions & 6 deletions src/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::ops::Neg;
use std::{f32, f64};

use rand::Rng as _;
use rustc_apfloat::Float as _;
use rustc_apfloat::Float;
use rustc_apfloat::ieee::{DoubleS, IeeeFloat, Semantics, SingleS};
use rustc_middle::ty::{self, FloatTy, ScalarInt};

Expand Down Expand Up @@ -317,19 +317,19 @@ where
}
}

pub(crate) fn sqrt<S: rustc_apfloat::ieee::Semantics>(x: IeeeFloat<S>) -> IeeeFloat<S> {
pub(crate) fn sqrt<F: Float>(x: F) -> F {
match x.category() {
// preserve zero sign
rustc_apfloat::Category::Zero => x,
// propagate NaN
rustc_apfloat::Category::NaN => x,
// sqrt of negative number is NaN
_ if x.is_negative() => IeeeFloat::NAN,
_ if x.is_negative() => F::NAN,
// sqrt(∞) = ∞
rustc_apfloat::Category::Infinity => IeeeFloat::INFINITY,
rustc_apfloat::Category::Infinity => F::INFINITY,
rustc_apfloat::Category::Normal => {
// Floating point precision, excluding the integer bit
let prec = i32::try_from(S::PRECISION).unwrap() - 1;
let prec = i32::try_from(F::PRECISION).unwrap() - 1;

// x = 2^(exp - prec) * mant
// where mant is an integer with prec+1 bits
Expand Down Expand Up @@ -394,7 +394,7 @@ pub(crate) fn sqrt<S: rustc_apfloat::ieee::Semantics>(x: IeeeFloat<S>) -> IeeeFl
res = (res + 1) >> 1;

// Build resulting value with res as mantissa and exp/2 as exponent
IeeeFloat::from_u128(res).value.scalbn(exp / 2 - prec)
F::from_u128(res).value.scalbn(exp / 2 - prec)
}
}
}
Expand Down
44 changes: 29 additions & 15 deletions tests/pass/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,35 @@ fn basic() {
assert_eq!(34.2f64.abs(), 34.2f64);
assert_eq!((-1.0f128).abs(), 1.0f128);
assert_eq!(34.2f128.abs(), 34.2f128);

assert_eq!(64_f16.sqrt(), 8_f16);
assert_eq!(64_f32.sqrt(), 8_f32);
assert_eq!(64_f64.sqrt(), 8_f64);
assert_eq!(64_f128.sqrt(), 8_f128);
assert_eq!(f16::INFINITY.sqrt(), f16::INFINITY);
assert_eq!(f32::INFINITY.sqrt(), f32::INFINITY);
assert_eq!(f64::INFINITY.sqrt(), f64::INFINITY);
assert_eq!(f128::INFINITY.sqrt(), f128::INFINITY);
assert_eq!(0.0_f16.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal);
assert_eq!(0.0_f32.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal);
assert_eq!(0.0_f64.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal);
assert_eq!(0.0_f128.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal);
assert_eq!((-0.0_f16).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal);
assert_eq!((-0.0_f32).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal);
assert_eq!((-0.0_f64).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal);
assert_eq!((-0.0_f128).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal);
assert!((-5.0_f16).sqrt().is_nan());
assert!((-5.0_f32).sqrt().is_nan());
assert!((-5.0_f64).sqrt().is_nan());
assert!((-5.0_f128).sqrt().is_nan());
assert!(f16::NEG_INFINITY.sqrt().is_nan());
assert!(f32::NEG_INFINITY.sqrt().is_nan());
assert!(f64::NEG_INFINITY.sqrt().is_nan());
assert!(f128::NEG_INFINITY.sqrt().is_nan());
assert!(f16::NAN.sqrt().is_nan());
assert!(f32::NAN.sqrt().is_nan());
assert!(f64::NAN.sqrt().is_nan());
assert!(f128::NAN.sqrt().is_nan());
}

/// Test casts from floats to ints and back
Expand Down Expand Up @@ -1012,21 +1041,6 @@ pub fn libm() {
unsafe { ldexp(a, b) }
}

assert_eq!(64_f32.sqrt(), 8_f32);
assert_eq!(64_f64.sqrt(), 8_f64);
assert_eq!(f32::INFINITY.sqrt(), f32::INFINITY);
assert_eq!(f64::INFINITY.sqrt(), f64::INFINITY);
assert_eq!(0.0_f32.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal);
assert_eq!(0.0_f64.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal);
assert_eq!((-0.0_f32).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal);
assert_eq!((-0.0_f64).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal);
assert!((-5.0_f32).sqrt().is_nan());
assert!((-5.0_f64).sqrt().is_nan());
assert!(f32::NEG_INFINITY.sqrt().is_nan());
assert!(f64::NEG_INFINITY.sqrt().is_nan());
assert!(f32::NAN.sqrt().is_nan());
assert!(f64::NAN.sqrt().is_nan());

assert_approx_eq!(25f32.powi(-2), 0.0016f32);
assert_approx_eq!(23.2f64.powi(2), 538.24f64);

Expand Down
Loading