Skip to content

Commit 607208c

Browse files
committed
Auto merge of rust-lang#3176 - eduardosm:cmp, r=RalfJung
Implement all 16 AVX compare operators for 128-bit SIMD vectors. The `_mm_cmp_{ss,ps,sd,pd}` functions are AVX functions that use `llvm.x86.sse{,2}.` prefixed intrinsics, so they were "accidentally" partially implemented when SSE and SSE2 intrinsics were implemented. The 16 AVX compare operators are now implemented and tested.
2 parents 6730f22 + 81303e7 commit 607208c

8 files changed

+261
-71
lines changed

src/tools/miri/src/shims/x86/mod.rs

+66-54
Original file line numberDiff line numberDiff line change
@@ -119,53 +119,32 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
119119
}
120120
}
121121

122-
/// Floating point comparison operation
123-
///
124-
/// <https://www.felixcloutier.com/x86/cmpss>
125-
/// <https://www.felixcloutier.com/x86/cmpps>
126-
/// <https://www.felixcloutier.com/x86/cmpsd>
127-
/// <https://www.felixcloutier.com/x86/cmppd>
128-
#[derive(Copy, Clone)]
129-
enum FloatCmpOp {
130-
Eq,
131-
Lt,
132-
Le,
133-
Unord,
134-
Neq,
135-
/// Not less-than
136-
Nlt,
137-
/// Not less-or-equal
138-
Nle,
139-
/// Ordered, i.e. neither of them is NaN
140-
Ord,
141-
}
142-
143-
impl FloatCmpOp {
144-
/// Convert from the `imm` argument used to specify the comparison
145-
/// operation in intrinsics such as `llvm.x86.sse.cmp.ss`.
146-
fn from_intrinsic_imm(imm: i8, intrinsic: &str) -> InterpResult<'_, Self> {
147-
match imm {
148-
0 => Ok(Self::Eq),
149-
1 => Ok(Self::Lt),
150-
2 => Ok(Self::Le),
151-
3 => Ok(Self::Unord),
152-
4 => Ok(Self::Neq),
153-
5 => Ok(Self::Nlt),
154-
6 => Ok(Self::Nle),
155-
7 => Ok(Self::Ord),
156-
imm => {
157-
throw_unsup_format!("invalid `imm` parameter of {intrinsic}: {imm}");
158-
}
159-
}
160-
}
161-
}
162-
163122
#[derive(Copy, Clone)]
164123
enum FloatBinOp {
165124
/// Arithmetic operation
166125
Arith(mir::BinOp),
167126
/// Comparison
168-
Cmp(FloatCmpOp),
127+
///
128+
/// The semantics of this operator is a case distinction: we compare the two operands,
129+
/// and then we return one of the four booleans `gt`, `lt`, `eq`, `unord` depending on
130+
/// which class they fall into.
131+
///
132+
/// AVX supports all 16 combinations, SSE only a subset
133+
///
134+
/// <https://www.felixcloutier.com/x86/cmpss>
135+
/// <https://www.felixcloutier.com/x86/cmpps>
136+
/// <https://www.felixcloutier.com/x86/cmpsd>
137+
/// <https://www.felixcloutier.com/x86/cmppd>
138+
Cmp {
139+
/// Result when lhs > rhs
140+
gt: bool,
141+
/// Result when lhs < rhs
142+
lt: bool,
143+
/// Result when lhs == rhs
144+
eq: bool,
145+
/// Result when lhs is NaN or rhs is NaN
146+
unord: bool,
147+
},
169148
/// Minimum value (with SSE semantics)
170149
///
171150
/// <https://www.felixcloutier.com/x86/minss>
@@ -182,6 +161,44 @@ enum FloatBinOp {
182161
Max,
183162
}
184163

164+
impl FloatBinOp {
165+
/// Convert from the `imm` argument used to specify the comparison
166+
/// operation in intrinsics such as `llvm.x86.sse.cmp.ss`.
167+
fn cmp_from_imm(imm: i8, intrinsic: &str) -> InterpResult<'_, Self> {
168+
// Only bits 0..=4 are used, remaining should be zero.
169+
if imm & !0b1_1111 != 0 {
170+
throw_unsup_format!("invalid `imm` parameter of {intrinsic}: 0x{imm:x}");
171+
}
172+
// Bit 4 specifies whether the operation is quiet or signaling, which
173+
// we do not care about in Miri.
174+
// Bits 0..=2 specify the operation.
175+
// `gt` indicates the result to be returned when the LHS is strictly
176+
// greater than the RHS, and so on.
177+
let (gt, lt, eq, unord) = match imm & 0b111 {
178+
// Equal
179+
0x0 => (false, false, true, false),
180+
// Less-than
181+
0x1 => (false, true, false, false),
182+
// Less-or-equal
183+
0x2 => (false, true, true, false),
184+
// Unordered (either is NaN)
185+
0x3 => (false, false, false, true),
186+
// Not equal
187+
0x4 => (true, true, false, true),
188+
// Not less-than
189+
0x5 => (true, false, true, true),
190+
// Not less-or-equal
191+
0x6 => (true, false, false, true),
192+
// Ordered (neither is NaN)
193+
0x7 => (true, true, true, false),
194+
_ => unreachable!(),
195+
};
196+
// When bit 3 is 1 (only possible in AVX), unord is toggled.
197+
let unord = unord ^ (imm & 0b1000 != 0);
198+
Ok(Self::Cmp { gt, lt, eq, unord })
199+
}
200+
}
201+
185202
/// Performs `which` scalar operation on `left` and `right` and returns
186203
/// the result.
187204
fn bin_op_float<'tcx, F: rustc_apfloat::Float>(
@@ -195,20 +212,15 @@ fn bin_op_float<'tcx, F: rustc_apfloat::Float>(
195212
let res = this.wrapping_binary_op(which, left, right)?;
196213
Ok(res.to_scalar())
197214
}
198-
FloatBinOp::Cmp(which) => {
215+
FloatBinOp::Cmp { gt, lt, eq, unord } => {
199216
let left = left.to_scalar().to_float::<F>()?;
200217
let right = right.to_scalar().to_float::<F>()?;
201-
// FIXME: Make sure that these operations match the semantics
202-
// of cmpps/cmpss/cmppd/cmpsd
203-
let res = match which {
204-
FloatCmpOp::Eq => left == right,
205-
FloatCmpOp::Lt => left < right,
206-
FloatCmpOp::Le => left <= right,
207-
FloatCmpOp::Unord => left.is_nan() || right.is_nan(),
208-
FloatCmpOp::Neq => left != right,
209-
FloatCmpOp::Nlt => !(left < right),
210-
FloatCmpOp::Nle => !(left <= right),
211-
FloatCmpOp::Ord => !left.is_nan() && !right.is_nan(),
218+
219+
let res = match left.partial_cmp(&right) {
220+
None => unord,
221+
Some(std::cmp::Ordering::Less) => lt,
222+
Some(std::cmp::Ordering::Equal) => eq,
223+
Some(std::cmp::Ordering::Greater) => gt,
212224
};
213225
Ok(bool_to_simd_element(res, Size::from_bits(F::BITS)))
214226
}

src/tools/miri/src/shims/x86/sse.rs

+15-7
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use rustc_target::spec::abi::Abi;
55

66
use rand::Rng as _;
77

8-
use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp, FloatCmpOp};
8+
use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp};
99
use crate::*;
1010
use shims::foreign_items::EmulateForeignItemResult;
1111

@@ -95,33 +95,41 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
9595

9696
unary_op_ps(this, which, op, dest)?;
9797
}
98-
// Used to implement the _mm_cmp_ss function.
98+
// Used to implement the _mm_cmp*_ss functions.
9999
// Performs a comparison operation on the first component of `left`
100100
// and `right`, returning 0 if false or `u32::MAX` if true. The remaining
101101
// components are copied from `left`.
102+
// _mm_cmp_ss is actually an AVX function where the operation is specified
103+
// by a const parameter.
104+
// _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_ss are SSE functions
105+
// with hard-coded operations.
102106
"cmp.ss" => {
103107
let [left, right, imm] =
104108
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
105109

106-
let which = FloatBinOp::Cmp(FloatCmpOp::from_intrinsic_imm(
110+
let which = FloatBinOp::cmp_from_imm(
107111
this.read_scalar(imm)?.to_i8()?,
108112
"llvm.x86.sse.cmp.ss",
109-
)?);
113+
)?;
110114

111115
bin_op_simd_float_first::<Single>(this, which, left, right, dest)?;
112116
}
113-
// Used to implement the _mm_cmp_ps function.
117+
// Used to implement the _mm_cmp*_ps functions.
114118
// Performs a comparison operation on each component of `left`
115119
// and `right`. For each component, returns 0 if false or u32::MAX
116120
// if true.
121+
// _mm_cmp_ps is actually an AVX function where the operation is specified
122+
// by a const parameter.
123+
// _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_ps are SSE functions
124+
// with hard-coded operations.
117125
"cmp.ps" => {
118126
let [left, right, imm] =
119127
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
120128

121-
let which = FloatBinOp::Cmp(FloatCmpOp::from_intrinsic_imm(
129+
let which = FloatBinOp::cmp_from_imm(
122130
this.read_scalar(imm)?.to_i8()?,
123131
"llvm.x86.sse.cmp.ps",
124-
)?);
132+
)?;
125133

126134
bin_op_simd_float_all::<Single>(this, which, left, right, dest)?;
127135
}

src/tools/miri/src/shims/x86/sse2.rs

+14-6
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use rustc_middle::ty::Ty;
44
use rustc_span::Symbol;
55
use rustc_target::spec::abi::Abi;
66

7-
use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp, FloatCmpOp};
7+
use super::{bin_op_simd_float_all, bin_op_simd_float_first, FloatBinOp};
88
use crate::*;
99
use shims::foreign_items::EmulateForeignItemResult;
1010

@@ -461,33 +461,41 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
461461
this.write_scalar(res, &dest)?;
462462
}
463463
}
464-
// Used to implement the _mm_cmp*_sd function.
464+
// Used to implement the _mm_cmp*_sd functions.
465465
// Performs a comparison operation on the first component of `left`
466466
// and `right`, returning 0 if false or `u64::MAX` if true. The remaining
467467
// components are copied from `left`.
468+
// _mm_cmp_sd is actually an AVX function where the operation is specified
469+
// by a const parameter.
470+
// _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_sd are SSE2 functions
471+
// with hard-coded operations.
468472
"cmp.sd" => {
469473
let [left, right, imm] =
470474
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
471475

472-
let which = FloatBinOp::Cmp(FloatCmpOp::from_intrinsic_imm(
476+
let which = FloatBinOp::cmp_from_imm(
473477
this.read_scalar(imm)?.to_i8()?,
474478
"llvm.x86.sse2.cmp.sd",
475-
)?);
479+
)?;
476480

477481
bin_op_simd_float_first::<Double>(this, which, left, right, dest)?;
478482
}
479483
// Used to implement the _mm_cmp*_pd functions.
480484
// Performs a comparison operation on each component of `left`
481485
// and `right`. For each component, returns 0 if false or `u64::MAX`
482486
// if true.
487+
// _mm_cmp_pd is actually an AVX function where the operation is specified
488+
// by a const parameter.
489+
// _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_pd are SSE2 functions
490+
// with hard-coded operations.
483491
"cmp.pd" => {
484492
let [left, right, imm] =
485493
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
486494

487-
let which = FloatBinOp::Cmp(FloatCmpOp::from_intrinsic_imm(
495+
let which = FloatBinOp::cmp_from_imm(
488496
this.read_scalar(imm)?.to_i8()?,
489497
"llvm.x86.sse2.cmp.pd",
490-
)?);
498+
)?;
491499

492500
bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
493501
}

src/tools/miri/tests/pass/intrinsics-x86-aes-vaes.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// Ignore everything except x86 and x86_64
2-
// Any additional target are added to CI should be ignored here
2+
// Any new targets that are added to CI should be ignored here.
33
// (We cannot use `cfg`-based tricks here since the `target-feature` flags below only work on x86.)
44
//@ignore-target-aarch64
55
//@ignore-target-arm

0 commit comments

Comments
 (0)