diff --git a/library/core/src/num/f32.rs b/library/core/src/num/f32.rs index 047cb64ce5069..7551a50949201 100644 --- a/library/core/src/num/f32.rs +++ b/library/core/src/num/f32.rs @@ -1016,25 +1016,42 @@ impl f32 { /// ``` #[unstable(feature = "num_midpoint", issue = "110840")] pub fn midpoint(self, other: f32) -> f32 { - const LO: f32 = f32::MIN_POSITIVE * 2.; - const HI: f32 = f32::MAX / 2.; - - let (a, b) = (self, other); - let abs_a = a.abs_private(); - let abs_b = b.abs_private(); - - if abs_a <= HI && abs_b <= HI { - // Overflow is impossible - (a + b) / 2. - } else if abs_a < LO { - // Not safe to halve a - a + (b / 2.) - } else if abs_b < LO { - // Not safe to halve b - (a / 2.) + b - } else { - // Not safe to halve a and b - (a / 2.) + (b / 2.) + cfg_if! { + if #[cfg(any( + target_arch = "x86_64", + target_arch = "aarch64", + all(any(target_arch="riscv32", target_arch= "riscv64"), target_feature="d"), + all(target_arch = "arm", target_feature="vfp2"), + target_arch = "wasm32", + target_arch = "wasm64", + ))] { + // whitelist the faster implementation to targets that have known good 64-bit float + // implementations. Falling back to the branchy code on targets that don't have + // 64-bit hardware floats or buggy implementations. + // see: https://github.com/rust-lang/rust/pull/121062#issuecomment-2123408114 + ((f64::from(self) + f64::from(other)) / 2.0) as f32 + } else { + const LO: f32 = f32::MIN_POSITIVE * 2.; + const HI: f32 = f32::MAX / 2.; + + let (a, b) = (self, other); + let abs_a = a.abs_private(); + let abs_b = b.abs_private(); + + if abs_a <= HI && abs_b <= HI { + // Overflow is impossible + (a + b) / 2. + } else if abs_a < LO { + // Not safe to halve a + a + (b / 2.) + } else if abs_b < LO { + // Not safe to halve b + (a / 2.) + b + } else { + // Not safe to halve a and b + (a / 2.) + (b / 2.) + } + } } } diff --git a/library/core/tests/num/mod.rs b/library/core/tests/num/mod.rs index 863da9b18a289..4b41db72f9bd3 100644 --- a/library/core/tests/num/mod.rs +++ b/library/core/tests/num/mod.rs @@ -719,7 +719,7 @@ assume_usize_width! { } macro_rules! test_float { - ($modname: ident, $fty: ty, $inf: expr, $neginf: expr, $nan: expr, $min: expr, $max: expr, $min_pos: expr) => { + ($modname: ident, $fty: ty, $inf: expr, $neginf: expr, $nan: expr, $min: expr, $max: expr, $min_pos: expr, $max_exp:expr) => { mod $modname { #[test] fn min() { @@ -870,6 +870,27 @@ macro_rules! test_float { assert!(($nan as $fty).midpoint(1.0).is_nan()); assert!((1.0 as $fty).midpoint($nan).is_nan()); assert!(($nan as $fty).midpoint($nan).is_nan()); + + // test if large differences in magnitude are still correctly computed. + // NOTE: that because of how small x and y are, x + y can never overflow + // so (x + y) / 2.0 is always correct + // in particular, `2.pow(i)` will never be at the max exponent, so it could + // be safely doubled, while j is significantly smaller. + for i in $max_exp.saturating_sub(64)..$max_exp { + for j in 0..64u8 { + let large = <$fty>::from(2.0f32).powi(i); + // a much smaller number, such that there is no chance of overflow to test + // potential double rounding in midpoint's implementation. + let small = <$fty>::from(2.0f32).powi($max_exp - 1) + * <$fty>::EPSILON + * <$fty>::from(j); + + let naive = (large + small) / 2.0; + let midpoint = large.midpoint(small); + + assert_eq!(naive, midpoint); + } + } } #[test] fn rem_euclid() { @@ -902,7 +923,8 @@ test_float!( f32::NAN, f32::MIN, f32::MAX, - f32::MIN_POSITIVE + f32::MIN_POSITIVE, + f32::MAX_EXP ); test_float!( f64, @@ -912,5 +934,6 @@ test_float!( f64::NAN, f64::MIN, f64::MAX, - f64::MIN_POSITIVE + f64::MIN_POSITIVE, + f64::MAX_EXP );