Use same code for masked conversion for masked stores and loads
jhorstmann committed Mar 1, 2024
1 parent 2abd0f0 commit df7fcb1
Showing 4 changed files with 122 additions and 63 deletions.
121 changes: 62 additions & 59 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
@@ -965,6 +965,33 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
}};
}

/// Returns the bitwidth of the `$ty` argument if it is an `Int` type.
macro_rules! require_int_ty {
($ty: expr, $diag: expr) => {
match $ty {
ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
_ => {
return_error!($diag);
}
}
};
}

/// Returns the bitwidth of the `$ty` argument if it is an `Int` or `Uint` type.
macro_rules! require_int_or_uint_ty {
($ty: expr, $diag: expr) => {
match $ty {
ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
ty::Uint(i) => {
i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits())
}
_ => {
return_error!($diag);
}
}
};
}

/// Converts a vector mask, where each element has a bit width equal to that of the data elements it is used with,
/// down to an i1-based mask that can be used by LLVM intrinsics.
///
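The body of vector_mask_to_bitmask is collapsed here. A minimal sketch of what it plausibly does, inferred from its call sites in this diff and from the lshr-by-(bitwidth - 1) plus trunc sequence that the updated codegen tests below check for (the exact signature and builder-method names, such as const_int, const_vector, lshr, and trunc, are assumptions, not taken from this commit):

fn vector_mask_to_bitmask<'a, 'll, 'tcx>(
    bx: &mut Builder<'a, 'll, 'tcx>,
    i_xn: &'ll Value,
    in_elem_bitwidth: u64,
    in_len: u64,
) -> &'ll Value {
    // Sketch only: isolate each lane's sign bit, then narrow every lane to i1.
    // Splat the shift amount (element bit width - 1) across the vector.
    let shift_idx = bx.const_int(bx.type_ix(in_elem_bitwidth), (in_elem_bitwidth - 1) as i64);
    let shift_indices = vec![shift_idx; in_len as usize];
    // Move the most significant bit of every lane into the least significant position.
    let i_xn_msb = bx.lshr(i_xn, bx.const_vector(shift_indices.as_slice()));
    // Truncate the vector of lanes down to <in_len x i1>.
    bx.trunc(i_xn_msb, bx.type_vector(bx.type_i1(), in_len))
}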
@@ -1252,10 +1279,10 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
m_len == v_len,
InvalidMonomorphization::MismatchedLengths { span, name, m_len, v_len }
);
let in_elem_bitwidth = match m_elem_ty.kind() {
ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
_ => return_error!(InvalidMonomorphization::MaskType { span, name, ty: m_elem_ty }),
};
let in_elem_bitwidth = require_int_ty!(
m_elem_ty.kind(),
InvalidMonomorphization::MaskType { span, name, ty: m_elem_ty }
);
let m_i1s = vector_mask_to_bitmask(bx, args[0].immediate(), in_elem_bitwidth, m_len);
return Ok(bx.select(m_i1s, args[1].immediate(), args[2].immediate()));
}
@@ -1274,24 +1301,12 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
let expected_bytes = expected_int_bits / 8 + ((expected_int_bits % 8 > 0) as u64);

// Integer vector <i{in_bitwidth} x in_len>:
let (i_xn, in_elem_bitwidth) = match in_elem.kind() {
ty::Int(i) => (
args[0].immediate(),
i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
),
ty::Uint(i) => (
args[0].immediate(),
i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
),
_ => return_error!(InvalidMonomorphization::VectorArgument {
span,
name,
in_ty,
in_elem
}),
};
let in_elem_bitwidth = require_int_or_uint_ty!(
in_elem.kind(),
InvalidMonomorphization::VectorArgument { span, name, in_ty, in_elem }
);

let i1xn = vector_mask_to_bitmask(bx, i_xn, in_elem_bitwidth, in_len);
let i1xn = vector_mask_to_bitmask(bx, args[0].immediate(), in_elem_bitwidth, in_len);
// Bitcast <i1 x N> to iN:
let i_ = bx.bitcast(i1xn, bx.type_ix(in_len));

@@ -1509,17 +1524,15 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
}
);

let mask_elem_bitwidth = match element_ty2.kind() {
ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
_ => {
return_error!(InvalidMonomorphization::ThirdArgElementType {
span,
name,
expected_element: element_ty2,
third_arg: arg_tys[2]
})
let mask_elem_bitwidth = require_int_ty!(
element_ty2.kind(),
InvalidMonomorphization::ThirdArgElementType {
span,
name,
expected_element: element_ty2,
third_arg: arg_tys[2]
}
};
);

// Alignment of T, must be a constant integer value:
let alignment_ty = bx.type_i32();
@@ -1612,8 +1625,8 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
}
);

require!(
matches!(mask_elem.kind(), ty::Int(_)),
let m_elem_bitwidth = require_int_ty!(
mask_elem.kind(),
InvalidMonomorphization::ThirdArgElementType {
span,
name,
@@ -1622,17 +1635,13 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
}
);

let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
let mask_ty = bx.type_vector(bx.type_i1(), mask_len);

// Alignment of T, must be a constant integer value:
let alignment_ty = bx.type_i32();
let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);

// Truncate the mask vector to a vector of i1s:
let (mask, mask_ty) = {
let i1 = bx.type_i1();
let i1xn = bx.type_vector(i1, mask_len);
(bx.trunc(args[0].immediate(), i1xn), i1xn)
};

let llvm_pointer = bx.type_ptr();

// Type of the vector of elements:
@@ -1704,8 +1713,8 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
}
);

require!(
matches!(mask_elem.kind(), ty::Int(_)),
let m_elem_bitwidth = require_int_ty!(
mask_elem.kind(),
InvalidMonomorphization::ThirdArgElementType {
span,
name,
@@ -1714,17 +1723,13 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
}
);

let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
let mask_ty = bx.type_vector(bx.type_i1(), mask_len);

// Alignment of T, must be a constant integer value:
let alignment_ty = bx.type_i32();
let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);

// Truncate the mask vector to a vector of i1s:
let (mask, mask_ty) = {
let i1 = bx.type_i1();
let i1xn = bx.type_vector(i1, in_len);
(bx.trunc(args[0].immediate(), i1xn), i1xn)
};

let ret_t = bx.type_void();

let llvm_pointer = bx.type_ptr();
@@ -1803,17 +1808,15 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
);

// The element type of the third argument must be a signed integer type of any width:
let mask_elem_bitwidth = match element_ty2.kind() {
ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
_ => {
return_error!(InvalidMonomorphization::ThirdArgElementType {
span,
name,
expected_element: element_ty2,
third_arg: arg_tys[2]
});
let mask_elem_bitwidth = require_int_ty!(
element_ty2.kind(),
InvalidMonomorphization::ThirdArgElementType {
span,
name,
expected_element: element_ty2,
third_arg: arg_tys[2]
}
};
);

// Alignment of T, must be a constant integer value:
let alignment_ty = bx.type_i32();
48 changes: 48 additions & 0 deletions tests/assembly/simd-intrinsic-mask-load.rs
@@ -0,0 +1,48 @@
// verify that simd mask reductions do not introduce additional bit shift operations
//@ revisions: x86 aarch64
//@ [x86] compile-flags: --target=x86_64-unknown-linux-gnu -C llvm-args=-x86-asm-syntax=intel
//@ [x86] needs-llvm-components: x86
//@ [aarch64] compile-flags: --target=aarch64-unknown-linux-gnu
//@ [aarch64] needs-llvm-components: aarch64
//@ [aarch64] min-llvm-version: 15.0
//@ assembly-output: emit-asm
//@ compile-flags: --crate-type=lib -O

#![feature(no_core, lang_items, repr_simd, intrinsics)]
#![no_core]
#![allow(non_camel_case_types)]

// Because we don't have core yet.
#[lang = "sized"]
pub trait Sized {}

#[lang = "copy"]
trait Copy {}

#[repr(simd)]
pub struct mask8x16([i8; 16]);

extern "rust-intrinsic" {
fn simd_reduce_all<T>(x: T) -> bool;
fn simd_reduce_any<T>(x: T) -> bool;
}

// CHECK-LABEL: mask_reduce_all:
#[no_mangle]
pub unsafe fn mask_reduce_all(m: mask8x16) -> bool {
// x86: movdqa
// x86-NEXT: pmovmskb
// aarch64: cmge
// aarch64-NEXT: umaxv
simd_reduce_all(m)
}

// CHECK-LABEL: mask_reduce_any:
#[no_mangle]
pub unsafe fn mask_reduce_any(m: mask8x16) -> bool {
// x86: movdqa
// x86-NEXT: pmovmskb
// aarch64: cmlt
// aarch64-NEXT: umaxv
simd_reduce_any(m)
}
8 changes: 6 additions & 2 deletions (simd_masked_load codegen test)
@@ -21,14 +21,18 @@ extern "rust-intrinsic" {
#[no_mangle]
pub unsafe fn load_f32x2(mask: Vec2<i32>, pointer: *const f32,
values: Vec2<f32>) -> Vec2<f32> {
// CHECK: call <2 x float> @llvm.masked.load.v2f32.p0(ptr {{.*}}, i32 4, <2 x i1> {{.*}}, <2 x float> {{.*}})
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, <i32 31, i32 31>
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
// CHECK: call <2 x float> @llvm.masked.load.v2f32.p0(ptr {{.*}}, i32 4, <2 x i1> [[B]], <2 x float> {{.*}})
simd_masked_load(mask, pointer, values)
}

// CHECK-LABEL: @load_pf32x4
#[no_mangle]
pub unsafe fn load_pf32x4(mask: Vec4<i32>, pointer: *const *const f32,
values: Vec4<*const f32>) -> Vec4<*const f32> {
// CHECK: call <4 x ptr> @llvm.masked.load.v4p0.p0(ptr {{.*}}, i32 {{.*}}, <4 x i1> {{.*}}, <4 x ptr> {{.*}})
// CHECK: [[A:%[0-9]+]] = lshr <4 x i32> {{.*}}, <i32 31, i32 31, i32 31, i32 31>
// CHECK: [[B:%[0-9]+]] = trunc <4 x i32> [[A]] to <4 x i1>
// CHECK: call <4 x ptr> @llvm.masked.load.v4p0.p0(ptr {{.*}}, i32 {{.*}}, <4 x i1> [[B]], <4 x ptr> {{.*}})
simd_masked_load(mask, pointer, values)
}
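
For illustration, a hypothetical additional case in the same style (not part of this commit): a 64-bit mask element would be shifted by 63 instead of 31 to isolate its sign bit. This assumes the test's generic Vec2<T> type and simd_masked_load declaration cover these element types.

// CHECK-LABEL: @load_f64x2
#[no_mangle]
pub unsafe fn load_f64x2(mask: Vec2<i64>, pointer: *const f64,
                         values: Vec2<f64>) -> Vec2<f64> {
    // CHECK: [[A:%[0-9]+]] = lshr <2 x i64> {{.*}}, <i64 63, i64 63>
    // CHECK: [[B:%[0-9]+]] = trunc <2 x i64> [[A]] to <2 x i1>
    // CHECK: call <2 x double> @llvm.masked.load.v2f64.p0(ptr {{.*}}, i32 8, <2 x i1> [[B]], <2 x double> {{.*}})
    simd_masked_load(mask, pointer, values)
}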
8 changes: 6 additions & 2 deletions (simd_masked_store codegen test)
@@ -20,13 +20,17 @@ extern "rust-intrinsic" {
// CHECK-LABEL: @store_f32x2
#[no_mangle]
pub unsafe fn store_f32x2(mask: Vec2<i32>, pointer: *mut f32, values: Vec2<f32>) {
// CHECK: call void @llvm.masked.store.v2f32.p0(<2 x float> {{.*}}, ptr {{.*}}, i32 4, <2 x i1> {{.*}})
// CHECK: [[A:%[0-9]+]] = lshr <2 x i32> {{.*}}, <i32 31, i32 31>
// CHECK: [[B:%[0-9]+]] = trunc <2 x i32> [[A]] to <2 x i1>
// CHECK: call void @llvm.masked.store.v2f32.p0(<2 x float> {{.*}}, ptr {{.*}}, i32 4, <2 x i1> [[B]])
simd_masked_store(mask, pointer, values)
}

// CHECK-LABEL: @store_pf32x4
#[no_mangle]
pub unsafe fn store_pf32x4(mask: Vec4<i32>, pointer: *mut *const f32, values: Vec4<*const f32>) {
// CHECK: call void @llvm.masked.store.v4p0.p0(<4 x ptr> {{.*}}, ptr {{.*}}, i32 {{.*}}, <4 x i1> {{.*}})
// CHECK: [[A:%[0-9]+]] = lshr <4 x i32> {{.*}}, <i32 31, i32 31, i32 31, i32 31>
// CHECK: [[B:%[0-9]+]] = trunc <4 x i32> [[A]] to <4 x i1>
// CHECK: call void @llvm.masked.store.v4p0.p0(<4 x ptr> {{.*}}, ptr {{.*}}, i32 {{.*}}, <4 x i1> [[B]])
simd_masked_store(mask, pointer, values)
}
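
The store side follows the same shape; a hypothetical 64-bit counterpart under the same assumptions as the load example above (again not part of this commit):

// CHECK-LABEL: @store_f64x2
#[no_mangle]
pub unsafe fn store_f64x2(mask: Vec2<i64>, pointer: *mut f64, values: Vec2<f64>) {
    // CHECK: [[A:%[0-9]+]] = lshr <2 x i64> {{.*}}, <i64 63, i64 63>
    // CHECK: [[B:%[0-9]+]] = trunc <2 x i64> [[A]] to <2 x i1>
    // CHECK: call void @llvm.masked.store.v2f64.p0(<2 x double> {{.*}}, ptr {{.*}}, i32 8, <2 x i1> [[B]])
    simd_masked_store(mask, pointer, values)
}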
