From ce8bf3b696004353ac4209b7f63b940a44e8d191 Mon Sep 17 00:00:00 2001 From: Ralf Jung Date: Thu, 18 Sep 2025 22:02:46 +0200 Subject: [PATCH 1/2] implement sqrt for f16 and f128 --- src/intrinsics/math.rs | 16 +++++++++++++++ tests/pass/float.rs | 44 ++++++++++++++++++++++++++++-------------- 2 files changed, 45 insertions(+), 15 deletions(-) diff --git a/src/intrinsics/math.rs b/src/intrinsics/math.rs index 21d4a92e7d..ca81454a98 100644 --- a/src/intrinsics/math.rs +++ b/src/intrinsics/math.rs @@ -20,6 +20,14 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { match intrinsic_name { // Operations we can do with soft-floats. + "sqrtf16" => { + let [f] = check_intrinsic_arg_count(args)?; + let f = this.read_scalar(f)?.to_f16()?; + // Sqrt is specified to be fully precise. + let res = math::sqrt(f); + let res = this.adjust_nan(res, &[f]); + this.write_scalar(res, dest)?; + } "sqrtf32" => { let [f] = check_intrinsic_arg_count(args)?; let f = this.read_scalar(f)?.to_f32()?; @@ -36,6 +44,14 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { let res = this.adjust_nan(res, &[f]); this.write_scalar(res, dest)?; } + "sqrtf128" => { + let [f] = check_intrinsic_arg_count(args)?; + let f = this.read_scalar(f)?.to_f128()?; + // Sqrt is specified to be fully precise. + let res = math::sqrt(f); + let res = this.adjust_nan(res, &[f]); + this.write_scalar(res, dest)?; + } "fmaf32" => { let [a, b, c] = check_intrinsic_arg_count(args)?; diff --git a/tests/pass/float.rs b/tests/pass/float.rs index 9f1b3f612b..3ce5ea8356 100644 --- a/tests/pass/float.rs +++ b/tests/pass/float.rs @@ -281,6 +281,35 @@ fn basic() { assert_eq!(34.2f64.abs(), 34.2f64); assert_eq!((-1.0f128).abs(), 1.0f128); assert_eq!(34.2f128.abs(), 34.2f128); + + assert_eq!(64_f16.sqrt(), 8_f16); + assert_eq!(64_f32.sqrt(), 8_f32); + assert_eq!(64_f64.sqrt(), 8_f64); + assert_eq!(64_f128.sqrt(), 8_f128); + assert_eq!(f16::INFINITY.sqrt(), f16::INFINITY); + assert_eq!(f32::INFINITY.sqrt(), f32::INFINITY); + assert_eq!(f64::INFINITY.sqrt(), f64::INFINITY); + assert_eq!(f128::INFINITY.sqrt(), f128::INFINITY); + assert_eq!(0.0_f16.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal); + assert_eq!(0.0_f32.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal); + assert_eq!(0.0_f64.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal); + assert_eq!(0.0_f128.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal); + assert_eq!((-0.0_f16).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal); + assert_eq!((-0.0_f32).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal); + assert_eq!((-0.0_f64).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal); + assert_eq!((-0.0_f128).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal); + assert!((-5.0_f16).sqrt().is_nan()); + assert!((-5.0_f32).sqrt().is_nan()); + assert!((-5.0_f64).sqrt().is_nan()); + assert!((-5.0_f128).sqrt().is_nan()); + assert!(f16::NEG_INFINITY.sqrt().is_nan()); + assert!(f32::NEG_INFINITY.sqrt().is_nan()); + assert!(f64::NEG_INFINITY.sqrt().is_nan()); + assert!(f128::NEG_INFINITY.sqrt().is_nan()); + assert!(f16::NAN.sqrt().is_nan()); + assert!(f32::NAN.sqrt().is_nan()); + assert!(f64::NAN.sqrt().is_nan()); + assert!(f128::NAN.sqrt().is_nan()); } /// Test casts from floats to ints and back @@ -1012,21 +1041,6 @@ pub fn libm() { unsafe { ldexp(a, b) } } - assert_eq!(64_f32.sqrt(), 8_f32); - assert_eq!(64_f64.sqrt(), 8_f64); - assert_eq!(f32::INFINITY.sqrt(), f32::INFINITY); - assert_eq!(f64::INFINITY.sqrt(), f64::INFINITY); - assert_eq!(0.0_f32.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal); - assert_eq!(0.0_f64.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal); - assert_eq!((-0.0_f32).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal); - assert_eq!((-0.0_f64).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal); - assert!((-5.0_f32).sqrt().is_nan()); - assert!((-5.0_f64).sqrt().is_nan()); - assert!(f32::NEG_INFINITY.sqrt().is_nan()); - assert!(f64::NEG_INFINITY.sqrt().is_nan()); - assert!(f32::NAN.sqrt().is_nan()); - assert!(f64::NAN.sqrt().is_nan()); - assert_approx_eq!(25f32.powi(-2), 0.0016f32); assert_approx_eq!(23.2f64.powi(2), 538.24f64); From d019108b23e27e15fb387fb9bd307a792e4205a5 Mon Sep 17 00:00:00 2001 From: Ralf Jung Date: Thu, 18 Sep 2025 22:13:15 +0200 Subject: [PATCH 2/2] share sqrt implemention across float types --- src/intrinsics/math.rs | 52 +++++++++++++++--------------------------- src/math.rs | 12 +++++----- 2 files changed, 25 insertions(+), 39 deletions(-) diff --git a/src/intrinsics/math.rs b/src/intrinsics/math.rs index ca81454a98..b9c99f2859 100644 --- a/src/intrinsics/math.rs +++ b/src/intrinsics/math.rs @@ -1,5 +1,5 @@ use rand::Rng; -use rustc_apfloat::{self, Float, Round}; +use rustc_apfloat::{self, Float, FloatConvert, Round}; use rustc_middle::mir; use rustc_middle::ty::{self, FloatTy}; @@ -7,6 +7,20 @@ use self::helpers::{ToHost, ToSoft}; use super::check_intrinsic_arg_count; use crate::*; +fn sqrt<'tcx, F: Float + FloatConvert + Into>( + this: &mut MiriInterpCx<'tcx>, + args: &[OpTy<'tcx>], + dest: &MPlaceTy<'tcx>, +) -> InterpResult<'tcx> { + let [f] = check_intrinsic_arg_count(args)?; + let f = this.read_scalar(f)?; + let f: F = f.to_float()?; + // Sqrt is specified to be fully precise. + let res = math::sqrt(f); + let res = this.adjust_nan(res, &[f]); + this.write_scalar(res, dest) +} + impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {} pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { fn emulate_math_intrinsic( @@ -20,38 +34,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { match intrinsic_name { // Operations we can do with soft-floats. - "sqrtf16" => { - let [f] = check_intrinsic_arg_count(args)?; - let f = this.read_scalar(f)?.to_f16()?; - // Sqrt is specified to be fully precise. - let res = math::sqrt(f); - let res = this.adjust_nan(res, &[f]); - this.write_scalar(res, dest)?; - } - "sqrtf32" => { - let [f] = check_intrinsic_arg_count(args)?; - let f = this.read_scalar(f)?.to_f32()?; - // Sqrt is specified to be fully precise. - let res = math::sqrt(f); - let res = this.adjust_nan(res, &[f]); - this.write_scalar(res, dest)?; - } - "sqrtf64" => { - let [f] = check_intrinsic_arg_count(args)?; - let f = this.read_scalar(f)?.to_f64()?; - // Sqrt is specified to be fully precise. - let res = math::sqrt(f); - let res = this.adjust_nan(res, &[f]); - this.write_scalar(res, dest)?; - } - "sqrtf128" => { - let [f] = check_intrinsic_arg_count(args)?; - let f = this.read_scalar(f)?.to_f128()?; - // Sqrt is specified to be fully precise. - let res = math::sqrt(f); - let res = this.adjust_nan(res, &[f]); - this.write_scalar(res, dest)?; - } + "sqrtf16" => sqrt::(this, args, dest)?, + "sqrtf32" => sqrt::(this, args, dest)?, + "sqrtf64" => sqrt::(this, args, dest)?, + "sqrtf128" => sqrt::(this, args, dest)?, "fmaf32" => { let [a, b, c] = check_intrinsic_arg_count(args)?; diff --git a/src/math.rs b/src/math.rs index 401e6dd7aa..50472ed363 100644 --- a/src/math.rs +++ b/src/math.rs @@ -2,7 +2,7 @@ use std::ops::Neg; use std::{f32, f64}; use rand::Rng as _; -use rustc_apfloat::Float as _; +use rustc_apfloat::Float; use rustc_apfloat::ieee::{DoubleS, IeeeFloat, Semantics, SingleS}; use rustc_middle::ty::{self, FloatTy, ScalarInt}; @@ -317,19 +317,19 @@ where } } -pub(crate) fn sqrt(x: IeeeFloat) -> IeeeFloat { +pub(crate) fn sqrt(x: F) -> F { match x.category() { // preserve zero sign rustc_apfloat::Category::Zero => x, // propagate NaN rustc_apfloat::Category::NaN => x, // sqrt of negative number is NaN - _ if x.is_negative() => IeeeFloat::NAN, + _ if x.is_negative() => F::NAN, // sqrt(∞) = ∞ - rustc_apfloat::Category::Infinity => IeeeFloat::INFINITY, + rustc_apfloat::Category::Infinity => F::INFINITY, rustc_apfloat::Category::Normal => { // Floating point precision, excluding the integer bit - let prec = i32::try_from(S::PRECISION).unwrap() - 1; + let prec = i32::try_from(F::PRECISION).unwrap() - 1; // x = 2^(exp - prec) * mant // where mant is an integer with prec+1 bits @@ -394,7 +394,7 @@ pub(crate) fn sqrt(x: IeeeFloat) -> IeeeFl res = (res + 1) >> 1; // Build resulting value with res as mantissa and exp/2 as exponent - IeeeFloat::from_u128(res).value.scalbn(exp / 2 - prec) + F::from_u128(res).value.scalbn(exp / 2 - prec) } } }