Skip to content

Commit bf9961c

Browse files
committed
Add f16 and f128 support to Miri
1 parent 836df19 commit bf9961c

File tree

3 files changed

+84
-117
lines changed

3 files changed

+84
-117
lines changed

compiler/rustc_const_eval/src/interpret/cast.rs

+10-17
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::assert_matches::assert_matches;
22

3-
use rustc_apfloat::ieee::{Double, Half, Quad, Single};
3+
use rustc_apfloat::ieee::{Double, Single};
44
use rustc_apfloat::{Float, FloatConvert};
55
use rustc_middle::mir::interpret::{InterpResult, PointerArithmetic, Scalar};
66
use rustc_middle::mir::CastKind;
@@ -189,10 +189,10 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
189189
bug!("FloatToFloat/FloatToInt cast: source type {} is not a float type", src.layout.ty)
190190
};
191191
let val = match fty {
192-
FloatTy::F16 => self.cast_from_float(src.to_scalar().to_f16()?, cast_to.ty),
192+
FloatTy::F16 => unimplemented!("f16_f128"),
193193
FloatTy::F32 => self.cast_from_float(src.to_scalar().to_f32()?, cast_to.ty),
194194
FloatTy::F64 => self.cast_from_float(src.to_scalar().to_f64()?, cast_to.ty),
195-
FloatTy::F128 => self.cast_from_float(src.to_scalar().to_f128()?, cast_to.ty),
195+
FloatTy::F128 => unimplemented!("f16_f128"),
196196
};
197197
Ok(ImmTy::from_scalar(val, cast_to))
198198
}
@@ -298,18 +298,18 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
298298
Float(fty) if signed => {
299299
let v = v as i128;
300300
match fty {
301-
FloatTy::F16 => Scalar::from_f16(Half::from_i128(v).value),
301+
FloatTy::F16 => unimplemented!("f16_f128"),
302302
FloatTy::F32 => Scalar::from_f32(Single::from_i128(v).value),
303303
FloatTy::F64 => Scalar::from_f64(Double::from_i128(v).value),
304-
FloatTy::F128 => Scalar::from_f128(Quad::from_i128(v).value),
304+
FloatTy::F128 => unimplemented!("f16_f128"),
305305
}
306306
}
307307
// unsigned int -> float
308308
Float(fty) => match fty {
309-
FloatTy::F16 => Scalar::from_f16(Half::from_u128(v).value),
309+
FloatTy::F16 => unimplemented!("f16_f128"),
310310
FloatTy::F32 => Scalar::from_f32(Single::from_u128(v).value),
311311
FloatTy::F64 => Scalar::from_f64(Double::from_u128(v).value),
312-
FloatTy::F128 => Scalar::from_f128(Quad::from_u128(v).value),
312+
FloatTy::F128 => unimplemented!("f16_f128"),
313313
},
314314

315315
// u8 -> char
@@ -323,12 +323,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
323323
/// Low-level cast helper function. Converts an apfloat `f` into int or float types.
324324
fn cast_from_float<F>(&self, f: F, dest_ty: Ty<'tcx>) -> Scalar<M::Provenance>
325325
where
326-
F: Float
327-
+ Into<Scalar<M::Provenance>>
328-
+ FloatConvert<Half>
329-
+ FloatConvert<Single>
330-
+ FloatConvert<Double>
331-
+ FloatConvert<Quad>,
326+
F: Float + Into<Scalar<M::Provenance>> + FloatConvert<Single> + FloatConvert<Double>,
332327
{
333328
use rustc_type_ir::TyKind::*;
334329

@@ -365,12 +360,10 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
365360
}
366361
// float -> float
367362
Float(fty) => match fty {
368-
FloatTy::F16 => Scalar::from_f16(adjust_nan(self, f, f.convert(&mut false).value)),
363+
FloatTy::F16 => unimplemented!("f16_f128"),
369364
FloatTy::F32 => Scalar::from_f32(adjust_nan(self, f, f.convert(&mut false).value)),
370365
FloatTy::F64 => Scalar::from_f64(adjust_nan(self, f, f.convert(&mut false).value)),
371-
FloatTy::F128 => {
372-
Scalar::from_f128(adjust_nan(self, f, f.convert(&mut false).value))
373-
}
366+
FloatTy::F128 => unimplemented!("f16_f128"),
374367
},
375368
// That's it.
376369
_ => span_bug!(self.cur_span(), "invalid float to {} cast", dest_ty),

src/tools/miri/src/helpers.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use std::time::Duration;
77

88
use rand::RngCore;
99

10-
use rustc_apfloat::ieee::{Double, Single};
10+
use rustc_apfloat::ieee::{Double, Half, Quad, Single};
1111
use rustc_apfloat::Float;
1212
use rustc_hir::{
1313
def::{DefKind, Namespace},
@@ -1201,12 +1201,14 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
12011201
};
12021202

12031203
let (val, status) = match fty {
1204-
FloatTy::F16 => unimplemented!("f16_f128"),
1204+
FloatTy::F16 =>
1205+
float_to_int_inner::<Half>(this, src.to_scalar().to_f16()?, cast_to, round),
12051206
FloatTy::F32 =>
12061207
float_to_int_inner::<Single>(this, src.to_scalar().to_f32()?, cast_to, round),
12071208
FloatTy::F64 =>
12081209
float_to_int_inner::<Double>(this, src.to_scalar().to_f64()?, cast_to, round),
1209-
FloatTy::F128 => unimplemented!("f16_f128"),
1210+
FloatTy::F128 =>
1211+
float_to_int_inner::<Quad>(this, src.to_scalar().to_f128()?, cast_to, round),
12101212
};
12111213

12121214
if status.intersects(

src/tools/miri/tests/pass/float.rs

+69-97
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#![feature(stmt_expr_attributes)]
22
#![feature(float_gamma)]
33
#![feature(core_intrinsics)]
4+
#![feature(f128)]
5+
#![feature(f16)]
46
#![allow(arithmetic_overflow)]
57

68
use std::fmt::Debug;
@@ -41,104 +43,43 @@ trait FloatToInt<Int>: Copy {
4143
unsafe fn cast_unchecked(self) -> Int;
4244
}
4345

44-
impl FloatToInt<i8> for f32 {
45-
fn cast(self) -> i8 {
46-
self as _
47-
}
48-
unsafe fn cast_unchecked(self) -> i8 {
49-
self.to_int_unchecked()
50-
}
51-
}
52-
impl FloatToInt<i32> for f32 {
53-
fn cast(self) -> i32 {
54-
self as _
55-
}
56-
unsafe fn cast_unchecked(self) -> i32 {
57-
self.to_int_unchecked()
58-
}
59-
}
60-
impl FloatToInt<u32> for f32 {
61-
fn cast(self) -> u32 {
62-
self as _
63-
}
64-
unsafe fn cast_unchecked(self) -> u32 {
65-
self.to_int_unchecked()
66-
}
67-
}
68-
impl FloatToInt<i64> for f32 {
69-
fn cast(self) -> i64 {
70-
self as _
71-
}
72-
unsafe fn cast_unchecked(self) -> i64 {
73-
self.to_int_unchecked()
74-
}
75-
}
76-
impl FloatToInt<u64> for f32 {
77-
fn cast(self) -> u64 {
78-
self as _
79-
}
80-
unsafe fn cast_unchecked(self) -> u64 {
81-
self.to_int_unchecked()
82-
}
46+
macro_rules! float_to_int {
47+
($fty:ty => $($ity:ty),+ $(,)?) => {
48+
$(
49+
impl FloatToInt<$ity> for $fty {
50+
fn cast(self) -> $ity {
51+
self as _
52+
}
53+
unsafe fn cast_unchecked(self) -> $ity {
54+
self.to_int_unchecked()
55+
}
56+
}
57+
)*
58+
};
8359
}
8460

85-
impl FloatToInt<i8> for f64 {
86-
fn cast(self) -> i8 {
87-
self as _
88-
}
89-
unsafe fn cast_unchecked(self) -> i8 {
90-
self.to_int_unchecked()
91-
}
92-
}
93-
impl FloatToInt<i32> for f64 {
94-
fn cast(self) -> i32 {
95-
self as _
96-
}
97-
unsafe fn cast_unchecked(self) -> i32 {
98-
self.to_int_unchecked()
99-
}
100-
}
101-
impl FloatToInt<u32> for f64 {
102-
fn cast(self) -> u32 {
103-
self as _
104-
}
105-
unsafe fn cast_unchecked(self) -> u32 {
106-
self.to_int_unchecked()
107-
}
108-
}
109-
impl FloatToInt<i64> for f64 {
110-
fn cast(self) -> i64 {
111-
self as _
112-
}
113-
unsafe fn cast_unchecked(self) -> i64 {
114-
self.to_int_unchecked()
115-
}
116-
}
117-
impl FloatToInt<u64> for f64 {
118-
fn cast(self) -> u64 {
119-
self as _
120-
}
121-
unsafe fn cast_unchecked(self) -> u64 {
122-
self.to_int_unchecked()
123-
}
124-
}
125-
impl FloatToInt<i128> for f64 {
126-
fn cast(self) -> i128 {
127-
self as _
128-
}
129-
unsafe fn cast_unchecked(self) -> i128 {
130-
self.to_int_unchecked()
131-
}
132-
}
133-
impl FloatToInt<u128> for f64 {
134-
fn cast(self) -> u128 {
135-
self as _
136-
}
137-
unsafe fn cast_unchecked(self) -> u128 {
138-
self.to_int_unchecked()
139-
}
61+
// FIXME(f16_f128): this is just used while we don't have `to_int_unchecked` on `f16` and `f128`.
62+
// Just use `float_to_int` once available.
63+
macro_rules! float_to_int_fallback {
64+
($fty:ty => $($ity:ty),+ $(,)?) => {
65+
$(
66+
impl FloatToInt<$ity> for $fty {
67+
fn cast(self) -> $ity {
68+
self as _
69+
}
70+
unsafe fn cast_unchecked(self) -> $ity {
71+
self as _
72+
}
73+
}
74+
)*
75+
};
14076
}
14177

78+
float_to_int_fallback!(f16 => i8, u8, i16, u16, i32, u32, i64, u64, i128, u128);
79+
float_to_int!(f32=> i8, u8, i16, u16, i32, u32, i64, u64, i128, u128);
80+
float_to_int!(f64 => i8, u8, i16, u16, i32, u32, i64, u64, i128, u128);
81+
float_to_int_fallback!(f128 => i8, u8, i16, u16, i32, u32, i64, u64, i128, u128);
82+
14283
/// Test this cast both via `as` and via `approx_unchecked` (i.e., it must not saturate).
14384
#[track_caller]
14485
#[inline(never)]
@@ -153,18 +94,29 @@ where
15394

15495
fn basic() {
15596
// basic arithmetic
97+
assert_eq(6.0_f16 * 6.0_f16, 36.0_f16);
15698
assert_eq(6.0_f32 * 6.0_f32, 36.0_f32);
15799
assert_eq(6.0_f64 * 6.0_f64, 36.0_f64);
100+
assert_eq(6.0_f128 * 6.0_f128, 36.0_f128);
101+
assert_eq(-{ 5.0_f16 }, -5.0_f16);
158102
assert_eq(-{ 5.0_f32 }, -5.0_f32);
159103
assert_eq(-{ 5.0_f64 }, -5.0_f64);
104+
assert_eq(-{ 5.0_f128 }, -5.0_f128);
105+
160106
// infinities, NaN
107+
// FIXME(f16_f128): add when constants and `is_infinite` are available
161108
assert!((5.0_f32 / 0.0).is_infinite());
162109
assert_ne!({ 5.0_f32 / 0.0 }, { -5.0_f32 / 0.0 });
163110
assert!((5.0_f64 / 0.0).is_infinite());
164111
assert_ne!({ 5.0_f64 / 0.0 }, { 5.0_f64 / -0.0 });
165112
assert_ne!(f32::NAN, f32::NAN);
166113
assert_ne!(f64::NAN, f64::NAN);
114+
167115
// negative zero
116+
let posz = 0.0f16;
117+
let negz = -0.0f16;
118+
assert_eq(posz, negz);
119+
assert_ne!(posz.to_bits(), negz.to_bits());
168120
let posz = 0.0f32;
169121
let negz = -0.0f32;
170122
assert_eq(posz, negz);
@@ -173,15 +125,30 @@ fn basic() {
173125
let negz = -0.0f64;
174126
assert_eq(posz, negz);
175127
assert_ne!(posz.to_bits(), negz.to_bits());
128+
let posz = 0.0f128;
129+
let negz = -0.0f128;
130+
assert_eq(posz, negz);
131+
assert_ne!(posz.to_bits(), negz.to_bits());
132+
176133
// byte-level transmute
177-
let x: u64 = unsafe { std::mem::transmute(42.0_f64) };
178-
let y: f64 = unsafe { std::mem::transmute(x) };
179-
assert_eq(y, 42.0_f64);
134+
let x: u16 = unsafe { std::mem::transmute(42.0_f16) };
135+
let y: f16 = unsafe { std::mem::transmute(x) };
136+
assert_eq(y, 42.0_f16);
180137
let x: u32 = unsafe { std::mem::transmute(42.0_f32) };
181138
let y: f32 = unsafe { std::mem::transmute(x) };
182139
assert_eq(y, 42.0_f32);
140+
let x: u64 = unsafe { std::mem::transmute(42.0_f64) };
141+
let y: f64 = unsafe { std::mem::transmute(x) };
142+
assert_eq(y, 42.0_f64);
143+
let x: u128 = unsafe { std::mem::transmute(42.0_f128) };
144+
let y: f128 = unsafe { std::mem::transmute(x) };
145+
assert_eq(y, 42.0_f128);
183146

184147
// `%` sign behavior, some of this used to be buggy
148+
assert!((black_box(1.0f16) % 1.0).is_sign_positive());
149+
assert!((black_box(1.0f16) % -1.0).is_sign_positive());
150+
assert!((black_box(-1.0f16) % 1.0).is_sign_negative());
151+
assert!((black_box(-1.0f16) % -1.0).is_sign_negative());
185152
assert!((black_box(1.0f32) % 1.0).is_sign_positive());
186153
assert!((black_box(1.0f32) % -1.0).is_sign_positive());
187154
assert!((black_box(-1.0f32) % 1.0).is_sign_negative());
@@ -190,7 +157,12 @@ fn basic() {
190157
assert!((black_box(1.0f64) % -1.0).is_sign_positive());
191158
assert!((black_box(-1.0f64) % 1.0).is_sign_negative());
192159
assert!((black_box(-1.0f64) % -1.0).is_sign_negative());
160+
assert!((black_box(1.0f128) % 1.0).is_sign_positive());
161+
assert!((black_box(1.0f128) % -1.0).is_sign_positive());
162+
assert!((black_box(-1.0f128) % 1.0).is_sign_negative());
163+
assert!((black_box(-1.0f128) % -1.0).is_sign_negative());
193164

165+
// FIXME(f16_f128): add when `abs` is available
194166
assert_eq!((-1.0f32).abs(), 1.0f32);
195167
assert_eq!(34.2f64.abs(), 34.2f64);
196168
}

0 commit comments

Comments
 (0)