diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs
index c02dd338da0e..fc96628e31a3 100644
--- a/compiler/rustc_codegen_llvm/src/intrinsic.rs
+++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs
@@ -965,6 +965,33 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
         }};
     }
 
+    /// Returns the bitwidth of the `$ty` argument if it is an `Int` type.
+    macro_rules! require_int_ty {
+        ($ty: expr, $diag: expr) => {
+            match $ty {
+                ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
+                _ => {
+                    return_error!($diag);
+                }
+            }
+        };
+    }
+
+    /// Returns the bitwidth of the `$ty` argument if it is an `Int` or `Uint` type.
+    macro_rules! require_int_or_uint_ty {
+        ($ty: expr, $diag: expr) => {
+            match $ty {
+                ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
+                ty::Uint(i) => {
+                    i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits())
+                }
+                _ => {
+                    return_error!($diag);
+                }
+            }
+        };
+    }
+
     /// Converts a vector mask, where each element has a bit width equal to the data elements it is used with,
     /// down to an i1 based mask that can be used by llvm intrinsics.
     ///
@@ -1252,10 +1279,10 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             m_len == v_len,
             InvalidMonomorphization::MismatchedLengths { span, name, m_len, v_len }
         );
-        let in_elem_bitwidth = match m_elem_ty.kind() {
-            ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
-            _ => return_error!(InvalidMonomorphization::MaskType { span, name, ty: m_elem_ty }),
-        };
+        let in_elem_bitwidth = require_int_ty!(
+            m_elem_ty.kind(),
+            InvalidMonomorphization::MaskType { span, name, ty: m_elem_ty }
+        );
         let m_i1s = vector_mask_to_bitmask(bx, args[0].immediate(), in_elem_bitwidth, m_len);
         return Ok(bx.select(m_i1s, args[1].immediate(), args[2].immediate()));
     }
@@ -1274,24 +1301,12 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
         let expected_bytes = expected_int_bits / 8 + ((expected_int_bits % 8 > 0) as u64);
 
         // Integer vector <i{in_bitwidth} x in_len>:
-        let (i_xn, in_elem_bitwidth) = match in_elem.kind() {
-            ty::Int(i) => (
-                args[0].immediate(),
-                i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
-            ),
-            ty::Uint(i) => (
-                args[0].immediate(),
-                i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
-            ),
-            _ => return_error!(InvalidMonomorphization::VectorArgument {
-                span,
-                name,
-                in_ty,
-                in_elem
-            }),
-        };
+        let in_elem_bitwidth = require_int_or_uint_ty!(
+            in_elem.kind(),
+            InvalidMonomorphization::VectorArgument { span, name, in_ty, in_elem }
+        );
 
-        let i1xn = vector_mask_to_bitmask(bx, i_xn, in_elem_bitwidth, in_len);
+        let i1xn = vector_mask_to_bitmask(bx, args[0].immediate(), in_elem_bitwidth, in_len);
         // Bitcast to iN:
         let i_ = bx.bitcast(i1xn, bx.type_ix(in_len));
 
@@ -1509,17 +1524,15 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             }
         );
 
-        let mask_elem_bitwidth = match element_ty2.kind() {
-            ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
-            _ => {
-                return_error!(InvalidMonomorphization::ThirdArgElementType {
-                    span,
-                    name,
-                    expected_element: element_ty2,
-                    third_arg: arg_tys[2]
-                })
+        let mask_elem_bitwidth = require_int_ty!(
+            element_ty2.kind(),
+            InvalidMonomorphization::ThirdArgElementType {
+                span,
+                name,
+                expected_element: element_ty2,
+                third_arg: arg_tys[2]
             }
-        };
+        );
 
         // Alignment of T, must be a constant integer value:
         let alignment_ty = bx.type_i32();
@@ -1612,8 +1625,8 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             }
         );
 
-        require!(
-            matches!(mask_elem.kind(), ty::Int(_)),
+        let m_elem_bitwidth = require_int_ty!(
+            mask_elem.kind(),
             InvalidMonomorphization::ThirdArgElementType {
                 span,
                 name,
@@ -1622,17 +1635,13 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             }
         );
 
+        let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
+        let mask_ty = bx.type_vector(bx.type_i1(), mask_len);
+
         // Alignment of T, must be a constant integer value:
         let alignment_ty = bx.type_i32();
         let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);
 
-        // Truncate the mask vector to a vector of i1s:
-        let (mask, mask_ty) = {
-            let i1 = bx.type_i1();
-            let i1xn = bx.type_vector(i1, mask_len);
-            (bx.trunc(args[0].immediate(), i1xn), i1xn)
-        };
-
         let llvm_pointer = bx.type_ptr();
 
         // Type of the vector of elements:
@@ -1704,8 +1713,8 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             }
         );
 
-        require!(
-            matches!(mask_elem.kind(), ty::Int(_)),
+        let m_elem_bitwidth = require_int_ty!(
+            mask_elem.kind(),
             InvalidMonomorphization::ThirdArgElementType {
                 span,
                 name,
@@ -1714,17 +1723,13 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             }
         );
 
+        let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
+        let mask_ty = bx.type_vector(bx.type_i1(), mask_len);
+
         // Alignment of T, must be a constant integer value:
        let alignment_ty = bx.type_i32();
         let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);
 
-        // Truncate the mask vector to a vector of i1s:
-        let (mask, mask_ty) = {
-            let i1 = bx.type_i1();
-            let i1xn = bx.type_vector(i1, in_len);
-            (bx.trunc(args[0].immediate(), i1xn), i1xn)
-        };
-
         let ret_t = bx.type_void();
 
         let llvm_pointer = bx.type_ptr();
@@ -1803,17 +1808,15 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
         );
 
         // The element type of the third argument must be a signed integer type of any width:
-        let mask_elem_bitwidth = match element_ty2.kind() {
-            ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
-            _ => {
-                return_error!(InvalidMonomorphization::ThirdArgElementType {
-                    span,
-                    name,
-                    expected_element: element_ty2,
-                    third_arg: arg_tys[2]
-                });
+        let mask_elem_bitwidth = require_int_ty!(
+            element_ty2.kind(),
+            InvalidMonomorphization::ThirdArgElementType {
+                span,
+                name,
+                expected_element: element_ty2,
+                third_arg: arg_tys[2]
             }
-        };
+        );
 
         // Alignment of T, must be a constant integer value:
         let alignment_ty = bx.type_i32();
diff --git a/tests/assembly/simd-intrinsic-mask-load.rs b/tests/assembly/simd-intrinsic-mask-load.rs
new file mode 100644
index 000000000000..b80af1141cc5
--- /dev/null
+++ b/tests/assembly/simd-intrinsic-mask-load.rs
@@ -0,0 +1,48 @@
+// verify that simd mask reductions do not introduce additional bit shift operations
+//@ revisions: x86 aarch64
+//@ [x86] compile-flags: --target=x86_64-unknown-linux-gnu -C llvm-args=-x86-asm-syntax=intel
+//@ [x86] needs-llvm-components: x86
+//@ [aarch64] compile-flags: --target=aarch64-unknown-linux-gnu
+//@ [aarch64] needs-llvm-components: aarch64
+//@ [aarch64] min-llvm-version: 15.0
+//@ assembly-output: emit-asm
+//@ compile-flags: --crate-type=lib -O
+
+#![feature(no_core, lang_items, repr_simd, intrinsics)]
+#![no_core]
+#![allow(non_camel_case_types)]
+
+// Because we don't have core yet.
+#[lang = "sized"] +pub trait Sized {} + +#[lang = "copy"] +trait Copy {} + +#[repr(simd)] +pub struct mask8x16([i8; 16]); + +extern "rust-intrinsic" { + fn simd_reduce_all(x: T) -> bool; + fn simd_reduce_any(x: T) -> bool; +} + +// CHECK-LABEL: mask_reduce_all: +#[no_mangle] +pub unsafe fn mask_reduce_all(m: mask8x16) -> bool { + // x86: movdqa + // x86-NEXT: pmovmskb + // aarch64: cmge + // aarch64-NEXT: umaxv + simd_reduce_all(m) +} + +// CHECK-LABEL: mask_reduce_any: +#[no_mangle] +pub unsafe fn mask_reduce_any(m: mask8x16) -> bool { + // x86: movdqa + // x86-NEXT: pmovmskb + // aarch64: cmlt + // aarch64-NEXT: umaxv + simd_reduce_any(m) +} diff --git a/tests/codegen/simd-intrinsic/simd-intrinsic-generic-masked-load.rs b/tests/codegen/simd-intrinsic/simd-intrinsic-generic-masked-load.rs index b41c42810aaf..0b9a21320336 100644 --- a/tests/codegen/simd-intrinsic/simd-intrinsic-generic-masked-load.rs +++ b/tests/codegen/simd-intrinsic/simd-intrinsic-generic-masked-load.rs @@ -21,7 +21,9 @@ extern "rust-intrinsic" { #[no_mangle] pub unsafe fn load_f32x2(mask: Vec2, pointer: *const f32, values: Vec2) -> Vec2 { - // CHECK: call <2 x float> @llvm.masked.load.v2f32.p0(ptr {{.*}}, i32 4, <2 x i1> {{.*}}, <2 x float> {{.*}}) + // CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, + // CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1> + // CHECK: call <2 x float> @llvm.masked.load.v2f32.p0(ptr {{.*}}, i32 4, <2 x i1> [[B]], <2 x float> {{.*}}) simd_masked_load(mask, pointer, values) } @@ -29,6 +31,8 @@ pub unsafe fn load_f32x2(mask: Vec2, pointer: *const f32, #[no_mangle] pub unsafe fn load_pf32x4(mask: Vec4, pointer: *const *const f32, values: Vec4<*const f32>) -> Vec4<*const f32> { - // CHECK: call <4 x ptr> @llvm.masked.load.v4p0.p0(ptr {{.*}}, i32 {{.*}}, <4 x i1> {{.*}}, <4 x ptr> {{.*}}) + // CHECK: [[A:%[0-9]+]] = lshr <4 x i32> {{.*}}, + // CHECK: [[B:%[0-9]+]] = trunc <4 x i32> [[A]] to <4 x i1> + // CHECK: call <4 x ptr> @llvm.masked.load.v4p0.p0(ptr {{.*}}, i32 {{.*}}, <4 x i1> [[B]], <4 x ptr> {{.*}}) simd_masked_load(mask, pointer, values) } diff --git a/tests/codegen/simd-intrinsic/simd-intrinsic-generic-masked-store.rs b/tests/codegen/simd-intrinsic/simd-intrinsic-generic-masked-store.rs index 066392bcde68..407eaecc1bf6 100644 --- a/tests/codegen/simd-intrinsic/simd-intrinsic-generic-masked-store.rs +++ b/tests/codegen/simd-intrinsic/simd-intrinsic-generic-masked-store.rs @@ -20,13 +20,17 @@ extern "rust-intrinsic" { // CHECK-LABEL: @store_f32x2 #[no_mangle] pub unsafe fn store_f32x2(mask: Vec2, pointer: *mut f32, values: Vec2) { - // CHECK: call void @llvm.masked.store.v2f32.p0(<2 x float> {{.*}}, ptr {{.*}}, i32 4, <2 x i1> {{.*}}) + // CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, + // CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1> + // CHECK: call void @llvm.masked.store.v2f32.p0(<2 x float> {{.*}}, ptr {{.*}}, i32 4, <2 x i1> [[B]]) simd_masked_store(mask, pointer, values) } // CHECK-LABEL: @store_pf32x4 #[no_mangle] pub unsafe fn store_pf32x4(mask: Vec4, pointer: *mut *const f32, values: Vec4<*const f32>) { - // CHECK: call void @llvm.masked.store.v4p0.p0(<4 x ptr> {{.*}}, ptr {{.*}}, i32 {{.*}}, <4 x i1> {{.*}}) + // CHECK: [[A:%[0-9]+]] = lshr <4 x i32> {{.*}}, + // CHECK: [[B:%[0-9]+]] = trunc <4 x i32> [[A]] to <4 x i1> + // CHECK: call void @llvm.masked.store.v4p0.p0(<4 x ptr> {{.*}}, ptr {{.*}}, i32 {{.*}}, <4 x i1> [[B]]) simd_masked_store(mask, pointer, values) }