Skip to content

Commit d98270b

Browse files
committedJan 31, 2025·
miri: make float min/max non-deterministic
1 parent 25a1657 commit d98270b

File tree

5 files changed

+64
-6
lines changed

5 files changed

+64
-6
lines changed
 

Diff for: ‎compiler/rustc_const_eval/src/interpret/intrinsics.rs

+14-2
Original file line numberDiff line numberDiff line change
@@ -747,7 +747,13 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
747747
{
748748
let a: F = self.read_scalar(&args[0])?.to_float()?;
749749
let b: F = self.read_scalar(&args[1])?.to_float()?;
750-
let res = self.adjust_nan(a.min(b), &[a, b]);
750+
let res = if a == b {
751+
// They are definitely not NaN (those are never equal), but they could be `+0` and `-0`.
752+
// Let the machine decide which one to return.
753+
M::equal_float_min_max(self, a, b)
754+
} else {
755+
self.adjust_nan(a.min(b), &[a, b])
756+
};
751757
self.write_scalar(res, dest)?;
752758
interp_ok(())
753759
}
@@ -762,7 +768,13 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
762768
{
763769
let a: F = self.read_scalar(&args[0])?.to_float()?;
764770
let b: F = self.read_scalar(&args[1])?.to_float()?;
765-
let res = self.adjust_nan(a.max(b), &[a, b]);
771+
let res = if a == b {
772+
// They are definitely not NaN (those are never equal), but they could be `+0` and `-0`.
773+
// Let the machine decide which one to return.
774+
M::equal_float_min_max(self, a, b)
775+
} else {
776+
self.adjust_nan(a.max(b), &[a, b])
777+
};
766778
self.write_scalar(res, dest)?;
767779
interp_ok(())
768780
}

Diff for: ‎compiler/rustc_const_eval/src/interpret/machine.rs

+6
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,12 @@ pub trait Machine<'tcx>: Sized {
278278
F2::NAN
279279
}
280280

281+
/// Determines the result of `min`/`max` on floats when the arguments are equal.
282+
fn equal_float_min_max<F: Float>(_ecx: &InterpCx<'tcx, Self>, a: F, _b: F) -> F {
283+
// By default, we pick the left argument.
284+
a
285+
}
286+
281287
/// Called before a basic block terminator is executed.
282288
#[inline]
283289
fn before_terminator(_ecx: &mut InterpCx<'tcx, Self>) -> InterpResult<'tcx> {

Diff for: ‎src/tools/miri/src/machine.rs

+9-4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use std::{fmt, process};
1111
use rand::rngs::StdRng;
1212
use rand::{Rng, SeedableRng};
1313
use rustc_abi::{Align, ExternAbi, Size};
14+
use rustc_apfloat::{Float, FloatConvert};
1415
use rustc_attr_parsing::InlineAttr;
1516
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
1617
#[allow(unused)]
@@ -1129,20 +1130,24 @@ impl<'tcx> Machine<'tcx> for MiriMachine<'tcx> {
11291130
}
11301131

11311132
#[inline(always)]
1132-
fn generate_nan<
1133-
F1: rustc_apfloat::Float + rustc_apfloat::FloatConvert<F2>,
1134-
F2: rustc_apfloat::Float,
1135-
>(
1133+
fn generate_nan<F1: Float + FloatConvert<F2>, F2: Float>(
11361134
ecx: &InterpCx<'tcx, Self>,
11371135
inputs: &[F1],
11381136
) -> F2 {
11391137
ecx.generate_nan(inputs)
11401138
}
11411139

1140+
#[inline(always)]
1141+
fn equal_float_min_max<F: Float>(ecx: &MiriInterpCx<'tcx>, a: F, b: F) -> F {
1142+
ecx.equal_float_min_max(a, b)
1143+
}
1144+
1145+
#[inline(always)]
11421146
fn ub_checks(ecx: &InterpCx<'tcx, Self>) -> InterpResult<'tcx, bool> {
11431147
interp_ok(ecx.tcx.sess.ub_checks())
11441148
}
11451149

1150+
#[inline(always)]
11461151
fn thread_local_static_pointer(
11471152
ecx: &mut MiriInterpCx<'tcx>,
11481153
def_id: DefId,

Diff for: ‎src/tools/miri/src/operator.rs

+7
Original file line numberDiff line numberDiff line change
@@ -115,4 +115,11 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
115115
nan
116116
}
117117
}
118+
119+
fn equal_float_min_max<F: Float>(&self, a: F, b: F) -> F {
120+
let this = self.eval_context_ref();
121+
// Return one side non-deterministically.
122+
let mut rand = this.machine.rng.borrow_mut();
123+
if rand.gen() { a } else { b }
124+
}
118125
}

Diff for: ‎src/tools/miri/tests/pass/float.rs

+28
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ fn main() {
3131
test_fast();
3232
test_algebraic();
3333
test_fmuladd();
34+
test_min_max_nondet();
3435
}
3536

3637
trait Float: Copy + PartialEq + Debug {
@@ -1211,3 +1212,30 @@ fn test_fmuladd() {
12111212
test_operations_f32(0.1, 0.2, 0.3);
12121213
test_operations_f64(1.1, 1.2, 1.3);
12131214
}
1215+
1216+
/// `min` and `max` on equal arguments are non-deterministic.
1217+
fn test_min_max_nondet() {
1218+
/// Ensure that if we call the closure often enough, we see both `true` and `false.`
1219+
#[track_caller]
1220+
fn ensure_both(f: impl Fn() -> bool) {
1221+
let rounds = 16;
1222+
let first = f();
1223+
for _ in 1..rounds {
1224+
if f() != first {
1225+
// We saw two different values!
1226+
return;
1227+
}
1228+
}
1229+
// We saw the same thing N times.
1230+
panic!("expected non-determinism, got {rounds} times the same result: {first:?}");
1231+
}
1232+
1233+
ensure_both(|| f16::min(0.0, -0.0).is_sign_positive());
1234+
ensure_both(|| f16::max(0.0, -0.0).is_sign_positive());
1235+
ensure_both(|| f32::min(0.0, -0.0).is_sign_positive());
1236+
ensure_both(|| f32::max(0.0, -0.0).is_sign_positive());
1237+
ensure_both(|| f64::min(0.0, -0.0).is_sign_positive());
1238+
ensure_both(|| f64::max(0.0, -0.0).is_sign_positive());
1239+
ensure_both(|| f128::min(0.0, -0.0).is_sign_positive());
1240+
ensure_both(|| f128::max(0.0, -0.0).is_sign_positive());
1241+
}

0 commit comments

Comments
 (0)
Please sign in to comment.