Skip to content

Commit 713cdcd

Browse files
Rollup merge of #121062 - RustyYato:f32-midpoint, r=the8472
Change f32::midpoint to upcast to f64. This has been verified by Kani as a correct optimization; see #110840 (comment). The new implementation is branchless and only differs in which NaN values are produced (if any are produced at all), which is fine to change. Aside from NaN handling, this implementation produces bitwise identical results to the original implementation. Question: do we need a codegen test for this? I didn't add one, since the original PR #92048 didn't have any codegen tests.
2 parents eda9d7f + 849c525 commit 713cdcd

File tree

2 files changed

+62
-22
lines changed

2 files changed

+62
-22
lines changed

library/core/src/num/f32.rs

+36-19
Original file line numberDiff line numberDiff line change
@@ -1030,25 +1030,42 @@ impl f32 {
10301030
/// ```
10311031
#[unstable(feature = "num_midpoint", issue = "110840")]
10321032
pub fn midpoint(self, other: f32) -> f32 {
1033-
const LO: f32 = f32::MIN_POSITIVE * 2.;
1034-
const HI: f32 = f32::MAX / 2.;
1035-
1036-
let (a, b) = (self, other);
1037-
let abs_a = a.abs_private();
1038-
let abs_b = b.abs_private();
1039-
1040-
if abs_a <= HI && abs_b <= HI {
1041-
// Overflow is impossible
1042-
(a + b) / 2.
1043-
} else if abs_a < LO {
1044-
// Not safe to halve a
1045-
a + (b / 2.)
1046-
} else if abs_b < LO {
1047-
// Not safe to halve b
1048-
(a / 2.) + b
1049-
} else {
1050-
// Not safe to halve a and b
1051-
(a / 2.) + (b / 2.)
1033+
cfg_if! {
1034+
if #[cfg(any(
1035+
target_arch = "x86_64",
1036+
target_arch = "aarch64",
1037+
all(any(target_arch="riscv32", target_arch= "riscv64"), target_feature="d"),
1038+
all(target_arch = "arm", target_feature="vfp2"),
1039+
target_arch = "wasm32",
1040+
target_arch = "wasm64",
1041+
))] {
1042+
// whitelist the faster implementation to targets that have known good 64-bit float
1043+
// implementations. Falling back to the branchy code on targets that don't have
1044+
// 64-bit hardware floats or that have buggy implementations.
1045+
// see: https://github.com/rust-lang/rust/pull/121062#issuecomment-2123408114
1046+
((f64::from(self) + f64::from(other)) / 2.0) as f32
1047+
} else {
1048+
const LO: f32 = f32::MIN_POSITIVE * 2.;
1049+
const HI: f32 = f32::MAX / 2.;
1050+
1051+
let (a, b) = (self, other);
1052+
let abs_a = a.abs_private();
1053+
let abs_b = b.abs_private();
1054+
1055+
if abs_a <= HI && abs_b <= HI {
1056+
// Overflow is impossible
1057+
(a + b) / 2.
1058+
} else if abs_a < LO {
1059+
// Not safe to halve a
1060+
a + (b / 2.)
1061+
} else if abs_b < LO {
1062+
// Not safe to halve b
1063+
(a / 2.) + b
1064+
} else {
1065+
// Not safe to halve a and b
1066+
(a / 2.) + (b / 2.)
1067+
}
1068+
}
10521069
}
10531070
}
10541071

library/core/tests/num/mod.rs

+26-3
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ assume_usize_width! {
729729
}
730730

731731
macro_rules! test_float {
732-
($modname: ident, $fty: ty, $inf: expr, $neginf: expr, $nan: expr, $min: expr, $max: expr, $min_pos: expr) => {
732+
($modname: ident, $fty: ty, $inf: expr, $neginf: expr, $nan: expr, $min: expr, $max: expr, $min_pos: expr, $max_exp:expr) => {
733733
mod $modname {
734734
#[test]
735735
fn min() {
@@ -880,6 +880,27 @@ macro_rules! test_float {
880880
assert!(($nan as $fty).midpoint(1.0).is_nan());
881881
assert!((1.0 as $fty).midpoint($nan).is_nan());
882882
assert!(($nan as $fty).midpoint($nan).is_nan());
883+
884+
// test if large differences in magnitude are still correctly computed.
885+
// NOTE: because of how small x and y are, x + y can never overflow
886+
// so (x + y) / 2.0 is always correct
887+
// in particular, `2.pow(i)` will never be at the max exponent, so it could
888+
// be safely doubled, while j is significantly smaller.
889+
for i in $max_exp.saturating_sub(64)..$max_exp {
890+
for j in 0..64u8 {
891+
let large = <$fty>::from(2.0f32).powi(i);
892+
// a much smaller number, such that there is no chance of overflow to test
893+
// potential double rounding in midpoint's implementation.
894+
let small = <$fty>::from(2.0f32).powi($max_exp - 1)
895+
* <$fty>::EPSILON
896+
* <$fty>::from(j);
897+
898+
let naive = (large + small) / 2.0;
899+
let midpoint = large.midpoint(small);
900+
901+
assert_eq!(naive, midpoint);
902+
}
903+
}
883904
}
884905
#[test]
885906
fn rem_euclid() {
@@ -912,7 +933,8 @@ test_float!(
912933
f32::NAN,
913934
f32::MIN,
914935
f32::MAX,
915-
f32::MIN_POSITIVE
936+
f32::MIN_POSITIVE,
937+
f32::MAX_EXP
916938
);
917939
test_float!(
918940
f64,
@@ -922,5 +944,6 @@ test_float!(
922944
f64::NAN,
923945
f64::MIN,
924946
f64::MAX,
925-
f64::MIN_POSITIVE
947+
f64::MIN_POSITIVE,
948+
f64::MAX_EXP
926949
);

0 commit comments

Comments
 (0)