Skip to content

Commit 9136560

Browse files
committed
Auto merge of rust-lang#115933 - oli-obk:simd_shuffle_const, r=workingjubilee
Prototype using const generic for simd_shuffle IDX array cc rust-lang#85229 r? `@workingjubilee` on the design TLDR: there is now a `fn simd_shuffle_generic<T, U, const IDX: &'static [u32]>(x: T, y: T) -> U;` intrinsic that allows replacing ```rust simd_shuffle(a, b, const { stuff }) ``` with ```rust simd_shuffle_generic::<_, _, {&stuff}>(a, b) ``` which makes the compiler implementations much simpler, if we manage to at some point eliminate `simd_shuffle`. There are some issues with this today though (can't do math without bubbling it up in the generic arguments). With this change, we can start porting the simple cases and get better data on the others.
2 parents 4efd655 + a9030e6 commit 9136560

File tree

11 files changed

+313
-46
lines changed

11 files changed

+313
-46
lines changed

compiler/rustc_codegen_cranelift/src/intrinsics/simd.rs

+49-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ fn report_simd_type_validation_error(
2121
pub(super) fn codegen_simd_intrinsic_call<'tcx>(
2222
fx: &mut FunctionCx<'_, '_, 'tcx>,
2323
intrinsic: Symbol,
24-
_args: GenericArgsRef<'tcx>,
24+
generic_args: GenericArgsRef<'tcx>,
2525
args: &[mir::Operand<'tcx>],
2626
ret: CPlace<'tcx>,
2727
target: BasicBlock,
@@ -117,6 +117,54 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
117117
});
118118
}
119119

120+
// simd_shuffle_generic<T, U, const I: &[u32]>(x: T, y: T) -> U
121+
sym::simd_shuffle_generic => {
122+
let [x, y] = args else {
123+
bug!("wrong number of args for intrinsic {intrinsic}");
124+
};
125+
let x = codegen_operand(fx, x);
126+
let y = codegen_operand(fx, y);
127+
128+
if !x.layout().ty.is_simd() {
129+
report_simd_type_validation_error(fx, intrinsic, span, x.layout().ty);
130+
return;
131+
}
132+
133+
let idx = generic_args[2]
134+
.expect_const()
135+
.eval(fx.tcx, ty::ParamEnv::reveal_all(), Some(span))
136+
.unwrap()
137+
.unwrap_branch();
138+
139+
assert_eq!(x.layout(), y.layout());
140+
let layout = x.layout();
141+
142+
let (lane_count, lane_ty) = layout.ty.simd_size_and_type(fx.tcx);
143+
let (ret_lane_count, ret_lane_ty) = ret.layout().ty.simd_size_and_type(fx.tcx);
144+
145+
assert_eq!(lane_ty, ret_lane_ty);
146+
assert_eq!(idx.len() as u64, ret_lane_count);
147+
148+
let total_len = lane_count * 2;
149+
150+
let indexes =
151+
idx.iter().map(|idx| idx.unwrap_leaf().try_to_u16().unwrap()).collect::<Vec<u16>>();
152+
153+
for &idx in &indexes {
154+
assert!(u64::from(idx) < total_len, "idx {} out of range 0..{}", idx, total_len);
155+
}
156+
157+
for (out_idx, in_idx) in indexes.into_iter().enumerate() {
158+
let in_lane = if u64::from(in_idx) < lane_count {
159+
x.value_lane(fx, in_idx.into())
160+
} else {
161+
y.value_lane(fx, u64::from(in_idx) - lane_count)
162+
};
163+
let out_lane = ret.place_lane(fx, u64::try_from(out_idx).unwrap());
164+
out_lane.write_cvalue(fx, in_lane);
165+
}
166+
}
167+
120168
// simd_shuffle<T, I, U>(x: T, y: T, idx: I) -> U
121169
sym::simd_shuffle => {
122170
let (x, y, idx) = match args {

compiler/rustc_codegen_llvm/src/intrinsic.rs

+55-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use rustc_codegen_ssa::mir::place::PlaceRef;
1515
use rustc_codegen_ssa::traits::*;
1616
use rustc_hir as hir;
1717
use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, LayoutOf};
18-
use rustc_middle::ty::{self, Ty};
18+
use rustc_middle::ty::{self, GenericArgsRef, Ty};
1919
use rustc_middle::{bug, span_bug};
2020
use rustc_span::{sym, symbol::kw, Span, Symbol};
2121
use rustc_target::abi::{self, Align, HasDataLayout, Primitive};
@@ -376,7 +376,9 @@ impl<'ll, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'_, 'll, 'tcx> {
376376
}
377377

378378
_ if name.as_str().starts_with("simd_") => {
379-
match generic_simd_intrinsic(self, name, callee_ty, args, ret_ty, llret_ty, span) {
379+
match generic_simd_intrinsic(
380+
self, name, callee_ty, fn_args, args, ret_ty, llret_ty, span,
381+
) {
380382
Ok(llval) => llval,
381383
Err(()) => return,
382384
}
@@ -911,6 +913,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
911913
bx: &mut Builder<'_, 'll, 'tcx>,
912914
name: Symbol,
913915
callee_ty: Ty<'tcx>,
916+
fn_args: GenericArgsRef<'tcx>,
914917
args: &[OperandRef<'tcx, &'ll Value>],
915918
ret_ty: Ty<'tcx>,
916919
llret_ty: &'ll Type,
@@ -1030,6 +1033,56 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
10301033
));
10311034
}
10321035

1036+
if name == sym::simd_shuffle_generic {
1037+
let idx = fn_args[2]
1038+
.expect_const()
1039+
.eval(tcx, ty::ParamEnv::reveal_all(), Some(span))
1040+
.unwrap()
1041+
.unwrap_branch();
1042+
let n = idx.len() as u64;
1043+
1044+
require_simd!(ret_ty, InvalidMonomorphization::SimdReturn { span, name, ty: ret_ty });
1045+
let (out_len, out_ty) = ret_ty.simd_size_and_type(bx.tcx());
1046+
require!(
1047+
out_len == n,
1048+
InvalidMonomorphization::ReturnLength { span, name, in_len: n, ret_ty, out_len }
1049+
);
1050+
require!(
1051+
in_elem == out_ty,
1052+
InvalidMonomorphization::ReturnElement { span, name, in_elem, in_ty, ret_ty, out_ty }
1053+
);
1054+
1055+
let total_len = in_len * 2;
1056+
1057+
let indices: Option<Vec<_>> = idx
1058+
.iter()
1059+
.enumerate()
1060+
.map(|(arg_idx, val)| {
1061+
let idx = val.unwrap_leaf().try_to_i32().unwrap();
1062+
if idx >= i32::try_from(total_len).unwrap() {
1063+
bx.sess().emit_err(InvalidMonomorphization::ShuffleIndexOutOfBounds {
1064+
span,
1065+
name,
1066+
arg_idx: arg_idx as u64,
1067+
total_len: total_len.into(),
1068+
});
1069+
None
1070+
} else {
1071+
Some(bx.const_i32(idx))
1072+
}
1073+
})
1074+
.collect();
1075+
let Some(indices) = indices else {
1076+
return Ok(bx.const_null(llret_ty));
1077+
};
1078+
1079+
return Ok(bx.shuffle_vector(
1080+
args[0].immediate(),
1081+
args[1].immediate(),
1082+
bx.const_vector(&indices),
1083+
));
1084+
}
1085+
10331086
if name == sym::simd_shuffle {
10341087
// Make sure this is actually an array, since typeck only checks the length-suffixed
10351088
// version of this intrinsic.

compiler/rustc_hir_analysis/src/check/intrinsic.rs

+23-21
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ fn equate_intrinsic_type<'tcx>(
2020
it: &hir::ForeignItem<'_>,
2121
n_tps: usize,
2222
n_lts: usize,
23+
n_cts: usize,
2324
sig: ty::PolyFnSig<'tcx>,
2425
) {
2526
let (own_counts, span) = match &it.kind {
@@ -51,7 +52,7 @@ fn equate_intrinsic_type<'tcx>(
5152

5253
if gen_count_ok(own_counts.lifetimes, n_lts, "lifetime")
5354
&& gen_count_ok(own_counts.types, n_tps, "type")
54-
&& gen_count_ok(own_counts.consts, 0, "const")
55+
&& gen_count_ok(own_counts.consts, n_cts, "const")
5556
{
5657
let it_def_id = it.owner_id.def_id;
5758
check_function_signature(
@@ -489,7 +490,7 @@ pub fn check_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>) {
489490
};
490491
let sig = tcx.mk_fn_sig(inputs, output, false, unsafety, Abi::RustIntrinsic);
491492
let sig = ty::Binder::bind_with_vars(sig, bound_vars);
492-
equate_intrinsic_type(tcx, it, n_tps, n_lts, sig)
493+
equate_intrinsic_type(tcx, it, n_tps, n_lts, 0, sig)
493494
}
494495

495496
/// Type-check `extern "platform-intrinsic" { ... }` functions.
@@ -501,9 +502,9 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)
501502

502503
let name = it.ident.name;
503504

504-
let (n_tps, inputs, output) = match name {
505+
let (n_tps, n_cts, inputs, output) = match name {
505506
sym::simd_eq | sym::simd_ne | sym::simd_lt | sym::simd_le | sym::simd_gt | sym::simd_ge => {
506-
(2, vec![param(0), param(0)], param(1))
507+
(2, 0, vec![param(0), param(0)], param(1))
507508
}
508509
sym::simd_add
509510
| sym::simd_sub
@@ -519,8 +520,8 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)
519520
| sym::simd_fmax
520521
| sym::simd_fpow
521522
| sym::simd_saturating_add
522-
| sym::simd_saturating_sub => (1, vec![param(0), param(0)], param(0)),
523-
sym::simd_arith_offset => (2, vec![param(0), param(1)], param(0)),
523+
| sym::simd_saturating_sub => (1, 0, vec![param(0), param(0)], param(0)),
524+
sym::simd_arith_offset => (2, 0, vec![param(0), param(1)], param(0)),
524525
sym::simd_neg
525526
| sym::simd_bswap
526527
| sym::simd_bitreverse
@@ -538,25 +539,25 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)
538539
| sym::simd_ceil
539540
| sym::simd_floor
540541
| sym::simd_round
541-
| sym::simd_trunc => (1, vec![param(0)], param(0)),
542-
sym::simd_fpowi => (1, vec![param(0), tcx.types.i32], param(0)),
543-
sym::simd_fma => (1, vec![param(0), param(0), param(0)], param(0)),
544-
sym::simd_gather => (3, vec![param(0), param(1), param(2)], param(0)),
545-
sym::simd_scatter => (3, vec![param(0), param(1), param(2)], Ty::new_unit(tcx)),
546-
sym::simd_insert => (2, vec![param(0), tcx.types.u32, param(1)], param(0)),
547-
sym::simd_extract => (2, vec![param(0), tcx.types.u32], param(1)),
542+
| sym::simd_trunc => (1, 0, vec![param(0)], param(0)),
543+
sym::simd_fpowi => (1, 0, vec![param(0), tcx.types.i32], param(0)),
544+
sym::simd_fma => (1, 0, vec![param(0), param(0), param(0)], param(0)),
545+
sym::simd_gather => (3, 0, vec![param(0), param(1), param(2)], param(0)),
546+
sym::simd_scatter => (3, 0, vec![param(0), param(1), param(2)], Ty::new_unit(tcx)),
547+
sym::simd_insert => (2, 0, vec![param(0), tcx.types.u32, param(1)], param(0)),
548+
sym::simd_extract => (2, 0, vec![param(0), tcx.types.u32], param(1)),
548549
sym::simd_cast
549550
| sym::simd_as
550551
| sym::simd_cast_ptr
551552
| sym::simd_expose_addr
552-
| sym::simd_from_exposed_addr => (2, vec![param(0)], param(1)),
553-
sym::simd_bitmask => (2, vec![param(0)], param(1)),
553+
| sym::simd_from_exposed_addr => (2, 0, vec![param(0)], param(1)),
554+
sym::simd_bitmask => (2, 0, vec![param(0)], param(1)),
554555
sym::simd_select | sym::simd_select_bitmask => {
555-
(2, vec![param(0), param(1), param(1)], param(1))
556+
(2, 0, vec![param(0), param(1), param(1)], param(1))
556557
}
557-
sym::simd_reduce_all | sym::simd_reduce_any => (1, vec![param(0)], tcx.types.bool),
558+
sym::simd_reduce_all | sym::simd_reduce_any => (1, 0, vec![param(0)], tcx.types.bool),
558559
sym::simd_reduce_add_ordered | sym::simd_reduce_mul_ordered => {
559-
(2, vec![param(0), param(1)], param(1))
560+
(2, 0, vec![param(0), param(1)], param(1))
560561
}
561562
sym::simd_reduce_add_unordered
562563
| sym::simd_reduce_mul_unordered
@@ -566,8 +567,9 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)
566567
| sym::simd_reduce_min
567568
| sym::simd_reduce_max
568569
| sym::simd_reduce_min_nanless
569-
| sym::simd_reduce_max_nanless => (2, vec![param(0)], param(1)),
570-
sym::simd_shuffle => (3, vec![param(0), param(0), param(1)], param(2)),
570+
| sym::simd_reduce_max_nanless => (2, 0, vec![param(0)], param(1)),
571+
sym::simd_shuffle => (3, 0, vec![param(0), param(0), param(1)], param(2)),
572+
sym::simd_shuffle_generic => (2, 1, vec![param(0), param(0)], param(1)),
571573
_ => {
572574
let msg = format!("unrecognized platform-specific intrinsic function: `{name}`");
573575
tcx.sess.struct_span_err(it.span, msg).emit();
@@ -577,5 +579,5 @@ pub fn check_platform_intrinsic_type(tcx: TyCtxt<'_>, it: &hir::ForeignItem<'_>)
577579

578580
let sig = tcx.mk_fn_sig(inputs, output, false, hir::Unsafety::Unsafe, Abi::PlatformIntrinsic);
579581
let sig = ty::Binder::dummy(sig);
580-
equate_intrinsic_type(tcx, it, n_tps, 0, sig)
582+
equate_intrinsic_type(tcx, it, n_tps, 0, n_cts, sig)
581583
}

compiler/rustc_span/src/symbol.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1465,6 +1465,7 @@ symbols! {
14651465
simd_shl,
14661466
simd_shr,
14671467
simd_shuffle,
1468+
simd_shuffle_generic,
14681469
simd_sub,
14691470
simd_trunc,
14701471
simd_xor,

src/tools/miri/src/shims/intrinsics/mod.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
6060
}
6161

6262
// The rest jumps to `ret` immediately.
63-
this.emulate_intrinsic_by_name(intrinsic_name, args, dest)?;
63+
this.emulate_intrinsic_by_name(intrinsic_name, instance.args, args, dest)?;
6464

6565
trace!("{:?}", this.dump_place(dest));
6666
this.go_to_block(ret);
@@ -71,6 +71,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
7171
fn emulate_intrinsic_by_name(
7272
&mut self,
7373
intrinsic_name: &str,
74+
generic_args: ty::GenericArgsRef<'tcx>,
7475
args: &[OpTy<'tcx, Provenance>],
7576
dest: &PlaceTy<'tcx, Provenance>,
7677
) -> InterpResult<'tcx> {
@@ -80,7 +81,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
8081
return this.emulate_atomic_intrinsic(name, args, dest);
8182
}
8283
if let Some(name) = intrinsic_name.strip_prefix("simd_") {
83-
return this.emulate_simd_intrinsic(name, args, dest);
84+
return this.emulate_simd_intrinsic(name, generic_args, args, dest);
8485
}
8586

8687
match intrinsic_name {

src/tools/miri/src/shims/intrinsics/simd.rs

+33
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
1212
fn emulate_simd_intrinsic(
1313
&mut self,
1414
intrinsic_name: &str,
15+
generic_args: ty::GenericArgsRef<'tcx>,
1516
args: &[OpTy<'tcx, Provenance>],
1617
dest: &PlaceTy<'tcx, Provenance>,
1718
) -> InterpResult<'tcx> {
@@ -488,6 +489,38 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
488489
this.write_immediate(*val, &dest)?;
489490
}
490491
}
492+
"shuffle_generic" => {
493+
let [left, right] = check_arg_count(args)?;
494+
let (left, left_len) = this.operand_to_simd(left)?;
495+
let (right, right_len) = this.operand_to_simd(right)?;
496+
let (dest, dest_len) = this.place_to_simd(dest)?;
497+
498+
let index = generic_args[2].expect_const().eval(*this.tcx, this.param_env(), Some(this.tcx.span)).unwrap().unwrap_branch();
499+
let index_len = index.len();
500+
501+
assert_eq!(left_len, right_len);
502+
assert_eq!(index_len as u64, dest_len);
503+
504+
for i in 0..dest_len {
505+
let src_index: u64 = index[i as usize].unwrap_leaf()
506+
.try_to_u32().unwrap()
507+
.into();
508+
let dest = this.project_index(&dest, i)?;
509+
510+
let val = if src_index < left_len {
511+
this.read_immediate(&this.project_index(&left, src_index)?)?
512+
} else if src_index < left_len.checked_add(right_len).unwrap() {
513+
let right_idx = src_index.checked_sub(left_len).unwrap();
514+
this.read_immediate(&this.project_index(&right, right_idx)?)?
515+
} else {
516+
span_bug!(
517+
this.cur_span(),
518+
"simd_shuffle index {src_index} is out of bounds for 2 vectors of size {left_len}",
519+
);
520+
};
521+
this.write_immediate(*val, &dest)?;
522+
}
523+
}
491524
"shuffle" => {
492525
let [left, right, index] = check_arg_count(args)?;
493526
let (left, left_len) = this.operand_to_simd(left)?;

src/tools/miri/tests/pass/portable-simd.rs

+20-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
//@compile-flags: -Zmiri-strict-provenance
2-
#![feature(portable_simd, platform_intrinsics)]
2+
#![feature(portable_simd, platform_intrinsics, adt_const_params, inline_const)]
3+
#![allow(incomplete_features)]
34
use std::simd::*;
45

56
extern "platform-intrinsic" {
@@ -390,6 +391,8 @@ fn simd_intrinsics() {
390391
fn simd_reduce_any<T>(x: T) -> bool;
391392
fn simd_reduce_all<T>(x: T) -> bool;
392393
fn simd_select<M, T>(m: M, yes: T, no: T) -> T;
394+
fn simd_shuffle_generic<T, U, const IDX: &'static [u32]>(x: T, y: T) -> U;
395+
fn simd_shuffle<T, IDX, U>(x: T, y: T, idx: IDX) -> U;
393396
}
394397
unsafe {
395398
// Make sure simd_eq returns all-1 for `true`
@@ -413,6 +416,22 @@ fn simd_intrinsics() {
413416
simd_select(i8x4::from_array([0, -1, -1, 0]), b, a),
414417
i32x4::from_array([10, 2, 10, 10])
415418
);
419+
assert_eq!(
420+
simd_shuffle_generic::<_, i32x4, {&[3, 1, 0, 2]}>(a, b),
421+
a,
422+
);
423+
assert_eq!(
424+
simd_shuffle::<_, _, i32x4>(a, b, const {[3, 1, 0, 2]}),
425+
a,
426+
);
427+
assert_eq!(
428+
simd_shuffle_generic::<_, i32x4, {&[7, 5, 4, 6]}>(a, b),
429+
i32x4::from_array([4, 2, 1, 10]),
430+
);
431+
assert_eq!(
432+
simd_shuffle::<_, _, i32x4>(a, b, const {[7, 5, 4, 6]}),
433+
i32x4::from_array([4, 2, 1, 10]),
434+
);
416435
}
417436
}
418437

0 commit comments

Comments
 (0)