Skip to content

Commit c3136b2

Browse files
committed
Auto merge of #3429 - eduardosm:shift, r=RalfJung
De-duplicate SSE2 sll/srl/sra code
2 parents 788a1db + 474a047 commit c3136b2

File tree

2 files changed

+97
-159
lines changed

2 files changed

+97
-159
lines changed

src/tools/miri/src/shims/x86/mod.rs

+80
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,86 @@ fn unary_op_ps<'tcx>(
468468
Ok(())
469469
}
470470

471+
enum ShiftOp {
472+
/// Shift left, logically (shift in zeros) -- same as shift left, arithmetically
473+
Left,
474+
/// Shift right, logically (shift in zeros)
475+
RightLogic,
476+
/// Shift right, arithmetically (shift in sign)
477+
RightArith,
478+
}
479+
480+
/// Shifts each element of `left` by a scalar amount. The shift amount
481+
/// is determined by the lowest 64 bits of `right` (which is a 128-bit vector).
482+
///
483+
/// For logic shifts, when right is larger than BITS - 1, zero is produced.
484+
/// For arithmetic right-shifts, when right is larger than BITS - 1, the sign
485+
/// bit is copied to remaining bits.
486+
fn shift_simd_by_scalar<'tcx>(
487+
this: &mut crate::MiriInterpCx<'_, 'tcx>,
488+
left: &OpTy<'tcx, Provenance>,
489+
right: &OpTy<'tcx, Provenance>,
490+
which: ShiftOp,
491+
dest: &MPlaceTy<'tcx, Provenance>,
492+
) -> InterpResult<'tcx, ()> {
493+
let (left, left_len) = this.operand_to_simd(left)?;
494+
let (dest, dest_len) = this.mplace_to_simd(dest)?;
495+
496+
assert_eq!(dest_len, left_len);
497+
// `right` may have a different length, and we only care about its
498+
// lowest 64bit anyway.
499+
500+
// Get the 64-bit shift operand and convert it to the type expected
501+
// by checked_{shl,shr} (u32).
502+
// It is ok to saturate the value to u32::MAX because any value
503+
// above BITS - 1 will produce the same result.
504+
let shift = u32::try_from(extract_first_u64(this, right)?).unwrap_or(u32::MAX);
505+
506+
for i in 0..dest_len {
507+
let left = this.read_scalar(&this.project_index(&left, i)?)?;
508+
let dest = this.project_index(&dest, i)?;
509+
510+
let res = match which {
511+
ShiftOp::Left => {
512+
let left = left.to_uint(dest.layout.size)?;
513+
let res = left.checked_shl(shift).unwrap_or(0);
514+
// `truncate` is needed as left-shift can make the absolute value larger.
515+
Scalar::from_uint(dest.layout.size.truncate(res), dest.layout.size)
516+
}
517+
ShiftOp::RightLogic => {
518+
let left = left.to_uint(dest.layout.size)?;
519+
let res = left.checked_shr(shift).unwrap_or(0);
520+
// No `truncate` needed as right-shift can only make the absolute value smaller.
521+
Scalar::from_uint(res, dest.layout.size)
522+
}
523+
ShiftOp::RightArith => {
524+
let left = left.to_int(dest.layout.size)?;
525+
// On overflow, copy the sign bit to the remaining bits
526+
let res = left.checked_shr(shift).unwrap_or(left >> 127);
527+
// No `truncate` needed as right-shift can only make the absolute value smaller.
528+
Scalar::from_int(res, dest.layout.size)
529+
}
530+
};
531+
this.write_scalar(res, &dest)?;
532+
}
533+
534+
Ok(())
535+
}
536+
537+
/// Takes a 128-bit vector, transmutes it to `[u64; 2]` and extracts
538+
/// the first value.
539+
fn extract_first_u64<'tcx>(
540+
this: &crate::MiriInterpCx<'_, 'tcx>,
541+
op: &OpTy<'tcx, Provenance>,
542+
) -> InterpResult<'tcx, u64> {
543+
// Transmute vector to `[u64; 2]`
544+
let array_layout = this.layout_of(Ty::new_array(this.tcx.tcx, this.tcx.types.u64, 2))?;
545+
let op = op.transmute(array_layout, this)?;
546+
547+
// Get the first u64 from the array
548+
this.read_scalar(&this.project_index(&op, 0)?)?.to_u64()
549+
}
550+
471551
// Rounds the first element of `right` according to `rounding`
472552
// and copies the remaining elements from `left`.
473553
fn round_first<'tcx, F: rustc_apfloat::Float>(

src/tools/miri/src/shims/x86/sse2.rs

+17-159
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
use rustc_apfloat::ieee::Double;
2-
use rustc_middle::ty::layout::LayoutOf as _;
3-
use rustc_middle::ty::Ty;
42
use rustc_span::Symbol;
53
use rustc_target::spec::abi::Abi;
64

7-
use super::{bin_op_simd_float_all, bin_op_simd_float_first, convert_float_to_int, FloatBinOp};
5+
use super::{
6+
bin_op_simd_float_all, bin_op_simd_float_first, convert_float_to_int, shift_simd_by_scalar,
7+
FloatBinOp, ShiftOp,
8+
};
89
use crate::*;
910
use shims::foreign_items::EmulateForeignItemResult;
1011

@@ -109,156 +110,27 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
109110
this.write_scalar(Scalar::from_u64(res.into()), &dest)?;
110111
}
111112
}
112-
// Used to implement the _mm_{sll,srl,sra}_epi16 functions.
113-
// Shifts 16-bit packed integers in left by the amount in right.
114-
// Both operands are vectors of 16-bit integers. However, right is
115-
// interpreted as a single 64-bit integer (remaining bits are ignored).
116-
// For logic shifts, when right is larger than 15, zero is produced.
117-
// For arithmetic shifts, when right is larger than 15, the sign bit
113+
// Used to implement the _mm_{sll,srl,sra}_epi{16,32,64} functions
114+
// (except _mm_sra_epi64, which is not available in SSE2).
115+
// Shifts N-bit packed integers in left by the amount in right.
116+
// Both operands are 128-bit vectors. However, right is interpreted as
117+
// a single 64-bit integer (remaining bits are ignored).
118+
// For logic shifts, when right is larger than N - 1, zero is produced.
119+
// For arithmetic shifts, when right is larger than N - 1, the sign bit
118120
// is copied to remaining bits.
119-
"psll.w" | "psrl.w" | "psra.w" => {
121+
"psll.w" | "psrl.w" | "psra.w" | "psll.d" | "psrl.d" | "psra.d" | "psll.q"
122+
| "psrl.q" => {
120123
let [left, right] =
121124
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
122125

123-
let (left, left_len) = this.operand_to_simd(left)?;
124-
let (right, right_len) = this.operand_to_simd(right)?;
125-
let (dest, dest_len) = this.mplace_to_simd(dest)?;
126-
127-
assert_eq!(dest_len, left_len);
128-
assert_eq!(dest_len, right_len);
129-
130-
enum ShiftOp {
131-
Sll,
132-
Srl,
133-
Sra,
134-
}
135126
let which = match unprefixed_name {
136-
"psll.w" => ShiftOp::Sll,
137-
"psrl.w" => ShiftOp::Srl,
138-
"psra.w" => ShiftOp::Sra,
127+
"psll.w" | "psll.d" | "psll.q" => ShiftOp::Left,
128+
"psrl.w" | "psrl.d" | "psrl.q" => ShiftOp::RightLogic,
129+
"psra.w" | "psra.d" => ShiftOp::RightArith,
139130
_ => unreachable!(),
140131
};
141132

142-
// Get the 64-bit shift operand and convert it to the type expected
143-
// by checked_{shl,shr} (u32).
144-
// It is ok to saturate the value to u32::MAX because any value
145-
// above 15 will produce the same result.
146-
let shift = extract_first_u64(this, &right)?.try_into().unwrap_or(u32::MAX);
147-
148-
for i in 0..dest_len {
149-
let left = this.read_scalar(&this.project_index(&left, i)?)?.to_u16()?;
150-
let dest = this.project_index(&dest, i)?;
151-
152-
let res = match which {
153-
ShiftOp::Sll => left.checked_shl(shift).unwrap_or(0),
154-
ShiftOp::Srl => left.checked_shr(shift).unwrap_or(0),
155-
#[allow(clippy::cast_possible_wrap, clippy::cast_sign_loss)]
156-
ShiftOp::Sra => {
157-
// Convert u16 to i16 to use arithmetic shift
158-
let left = left as i16;
159-
// Copy the sign bit to the remaining bits
160-
left.checked_shr(shift).unwrap_or(left >> 15) as u16
161-
}
162-
};
163-
164-
this.write_scalar(Scalar::from_u16(res), &dest)?;
165-
}
166-
}
167-
// Used to implement the _mm_{sll,srl,sra}_epi32 functions.
168-
// 32-bit equivalent to the shift functions above.
169-
"psll.d" | "psrl.d" | "psra.d" => {
170-
let [left, right] =
171-
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
172-
173-
let (left, left_len) = this.operand_to_simd(left)?;
174-
let (right, right_len) = this.operand_to_simd(right)?;
175-
let (dest, dest_len) = this.mplace_to_simd(dest)?;
176-
177-
assert_eq!(dest_len, left_len);
178-
assert_eq!(dest_len, right_len);
179-
180-
enum ShiftOp {
181-
Sll,
182-
Srl,
183-
Sra,
184-
}
185-
let which = match unprefixed_name {
186-
"psll.d" => ShiftOp::Sll,
187-
"psrl.d" => ShiftOp::Srl,
188-
"psra.d" => ShiftOp::Sra,
189-
_ => unreachable!(),
190-
};
191-
192-
// Get the 64-bit shift operand and convert it to the type expected
193-
// by checked_{shl,shr} (u32).
194-
// It is ok to saturate the value to u32::MAX because any value
195-
// above 31 will produce the same result.
196-
let shift = extract_first_u64(this, &right)?.try_into().unwrap_or(u32::MAX);
197-
198-
for i in 0..dest_len {
199-
let left = this.read_scalar(&this.project_index(&left, i)?)?.to_u32()?;
200-
let dest = this.project_index(&dest, i)?;
201-
202-
let res = match which {
203-
ShiftOp::Sll => left.checked_shl(shift).unwrap_or(0),
204-
ShiftOp::Srl => left.checked_shr(shift).unwrap_or(0),
205-
#[allow(clippy::cast_possible_wrap, clippy::cast_sign_loss)]
206-
ShiftOp::Sra => {
207-
// Convert u32 to i32 to use arithmetic shift
208-
let left = left as i32;
209-
// Copy the sign bit to the remaining bits
210-
left.checked_shr(shift).unwrap_or(left >> 31) as u32
211-
}
212-
};
213-
214-
this.write_scalar(Scalar::from_u32(res), &dest)?;
215-
}
216-
}
217-
// Used to implement the _mm_{sll,srl}_epi64 functions.
218-
// 64-bit equivalent to the shift functions above, except _mm_sra_epi64,
219-
// which is not available in SSE2.
220-
"psll.q" | "psrl.q" => {
221-
let [left, right] =
222-
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
223-
224-
let (left, left_len) = this.operand_to_simd(left)?;
225-
let (right, right_len) = this.operand_to_simd(right)?;
226-
let (dest, dest_len) = this.mplace_to_simd(dest)?;
227-
228-
assert_eq!(dest_len, left_len);
229-
assert_eq!(dest_len, right_len);
230-
231-
enum ShiftOp {
232-
Sll,
233-
Srl,
234-
}
235-
let which = match unprefixed_name {
236-
"psll.q" => ShiftOp::Sll,
237-
"psrl.q" => ShiftOp::Srl,
238-
_ => unreachable!(),
239-
};
240-
241-
// Get the 64-bit shift operand and convert it to the type expected
242-
// by checked_{shl,shr} (u32).
243-
// It is ok to saturate the value to u32::MAX because any value
244-
// above 63 will produce the same result.
245-
let shift = this
246-
.read_scalar(&this.project_index(&right, 0)?)?
247-
.to_u64()?
248-
.try_into()
249-
.unwrap_or(u32::MAX);
250-
251-
for i in 0..dest_len {
252-
let left = this.read_scalar(&this.project_index(&left, i)?)?.to_u64()?;
253-
let dest = this.project_index(&dest, i)?;
254-
255-
let res = match which {
256-
ShiftOp::Sll => left.checked_shl(shift).unwrap_or(0),
257-
ShiftOp::Srl => left.checked_shr(shift).unwrap_or(0),
258-
};
259-
260-
this.write_scalar(Scalar::from_u64(res), &dest)?;
261-
}
133+
shift_simd_by_scalar(this, left, right, which, dest)?;
262134
}
263135
// Used to implement the _mm_cvtps_epi32, _mm_cvttps_epi32, _mm_cvtpd_epi32
264136
// and _mm_cvttpd_epi32 functions.
@@ -585,17 +457,3 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
585457
Ok(EmulateForeignItemResult::NeedsJumping)
586458
}
587459
}
588-
589-
/// Takes a 128-bit vector, transmutes it to `[u64; 2]` and extracts
590-
/// the first value.
591-
fn extract_first_u64<'tcx>(
592-
this: &crate::MiriInterpCx<'_, 'tcx>,
593-
op: &MPlaceTy<'tcx, Provenance>,
594-
) -> InterpResult<'tcx, u64> {
595-
// Transmute vector to `[u64; 2]`
596-
let u64_array_layout = this.layout_of(Ty::new_array(this.tcx.tcx, this.tcx.types.u64, 2))?;
597-
let op = op.transmute(u64_array_layout, this)?;
598-
599-
// Get the first u64 from the array
600-
this.read_scalar(&this.project_index(&op, 0)?)?.to_u64()
601-
}

0 commit comments

Comments
 (0)