
Commit 3779b8e

Consistently use the most significant bit of vector masks
This improves the codegen for vector `select`, `gather`, `scatter` and boolean reduction intrinsics and fixes rust-lang/portable-simd#316.

The current behavior of most mask operations during llvm codegen is to truncate the mask vector to <N x i1>, telling llvm to use the least significant bit. The exception is the `simd_bitmask` intrinsic, which already used the most significant bit.

Since sse/avx instructions are defined to use the most significant bit, truncating means that llvm has to insert a left shift to move the bit into the most significant position before the mask can actually be used. Similarly, on aarch64 mask operations like blend work bit by bit; repeating the least significant bit across the whole lane involves shifting it into the sign position and then comparing against zero.

By shifting before truncating to <N x i1>, we tell llvm that we only consider the most significant bit, removing the need for additional shift instructions in the assembly.
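
For illustration (not part of the commit), here is a minimal scalar model of the two per-lane interpretations described above, assuming an `i32` mask lane where "true" is encoded as all ones:

```rust
/// Old lowering: plain truncation to i1, so the least significant bit decides.
fn lane_is_set_lsb(lane: i32) -> bool {
    (lane & 1) != 0
}

/// New lowering: logical shift right by (bit width - 1), then truncate,
/// so the most significant (sign) bit decides -- the bit SSE/AVX blends read.
fn lane_is_set_msb(lane: i32) -> bool {
    ((lane as u32) >> 31) != 0
}

fn main() {
    // For well-formed mask lanes (all zeros or all ones) both agree...
    for lane in [0, -1] {
        assert_eq!(lane_is_set_lsb(lane), lane_is_set_msb(lane));
    }
    // ...but only the MSB form matches the hardware convention directly, so LLVM
    // no longer needs a left shift to reposition the bit before using the mask.
    assert!(lane_is_set_msb(i32::MIN) && !lane_is_set_lsb(i32::MIN));
}
```
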
1 parent c2270be commit 3779b8e

13 files changed (+280 -172 lines changed)

compiler/rustc_codegen_llvm/src/intrinsic.rs (+104 -93)
@@ -1182,6 +1182,60 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
         }};
     }

+    /// Returns the bitwidth of the `$ty` argument if it is an `Int` type.
+    macro_rules! require_int_ty {
+        ($ty: expr, $diag: expr) => {
+            match $ty {
+                ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
+                _ => {
+                    return_error!($diag);
+                }
+            }
+        };
+    }
+
+    /// Returns the bitwidth of the `$ty` argument if it is an `Int` or `Uint` type.
+    macro_rules! require_int_or_uint_ty {
+        ($ty: expr, $diag: expr) => {
+            match $ty {
+                ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
+                ty::Uint(i) => {
+                    i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits())
+                }
+                _ => {
+                    return_error!($diag);
+                }
+            }
+        };
+    }
+
+    /// Converts a vector mask, where each element has a bit width equal to the data elements it is used with,
+    /// down to an i1 based mask that can be used by llvm intrinsics.
+    ///
+    /// The rust simd semantics are that each element should either consist of all ones or all zeroes,
+    /// but this information is not available to llvm. Truncating the vector effectively uses the lowest bit,
+    /// but codegen for several targets is better if we consider the highest bit by shifting.
+    ///
+    /// For x86 SSE/AVX targets this is beneficial since most instructions with mask parameters only consider the highest bit.
+    /// So even though on llvm level we have an additional shift, in the final assembly there is no shift or truncate and
+    /// instead the mask can be used as is.
+    ///
+    /// For aarch64 and other targets there is a benefit because a mask from the sign bit can be more
+    /// efficiently converted to an all ones / all zeroes mask by comparing whether each element is negative.
+    fn vector_mask_to_bitmask<'a, 'll, 'tcx>(
+        bx: &mut Builder<'a, 'll, 'tcx>,
+        i_xn: &'ll Value,
+        in_elem_bitwidth: u64,
+        in_len: u64,
+    ) -> &'ll Value {
+        // Shift the MSB to the right by "in_elem_bitwidth - 1" into the first bit position.
+        let shift_idx = bx.cx.const_int(bx.type_ix(in_elem_bitwidth), (in_elem_bitwidth - 1) as _);
+        let shift_indices = vec![shift_idx; in_len as _];
+        let i_xn_msb = bx.lshr(i_xn, bx.const_vector(shift_indices.as_slice()));
+        // Truncate vector to an <i1 x N>
+        bx.trunc(i_xn_msb, bx.type_vector(bx.type_i1(), in_len))
+    }
+
     let tcx = bx.tcx();
     let sig = tcx.normalize_erasing_late_bound_regions(bx.typing_env(), callee_ty.fn_sig(tcx));
     let arg_tys = sig.inputs();
@@ -1433,14 +1487,13 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             m_len,
             v_len
         });
-        match m_elem_ty.kind() {
-            ty::Int(_) => {}
-            _ => return_error!(InvalidMonomorphization::MaskType { span, name, ty: m_elem_ty }),
-        }
-        // truncate the mask to a vector of i1s
-        let i1 = bx.type_i1();
-        let i1xn = bx.type_vector(i1, m_len as u64);
-        let m_i1s = bx.trunc(args[0].immediate(), i1xn);
+        let in_elem_bitwidth =
+            require_int_ty!(m_elem_ty.kind(), InvalidMonomorphization::MaskType {
+                span,
+                name,
+                ty: m_elem_ty
+            });
+        let m_i1s = vector_mask_to_bitmask(bx, args[0].immediate(), in_elem_bitwidth, m_len);
         return Ok(bx.select(m_i1s, args[1].immediate(), args[2].immediate()));
     }

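As context for the hunk above (not part of the diff): `simd_select` is what `Mask::select` from nightly `std::simd` lowers to, so a snippet like the following exercises this path; the concrete lane types are only an example.

```rust
#![feature(portable_simd)]
use std::simd::prelude::*;

fn main() {
    let mask = mask32x4::from_array([true, false, false, true]);
    let a = i32x4::from_array([1, 2, 3, 4]);
    let b = i32x4::from_array([10, 20, 30, 40]);
    // Lowers to the `simd_select` intrinsic handled above; with this change the mask
    // reaches LLVM as an MSB test (lshr + trunc) instead of a plain LSB truncate.
    assert_eq!(mask.select(a, b), i32x4::from_array([1, 20, 30, 4]));
}
```
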
@@ -1457,33 +1510,15 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
         let expected_bytes = in_len.div_ceil(8);

         // Integer vector <i{in_bitwidth} x in_len>:
-        let (i_xn, in_elem_bitwidth) = match in_elem.kind() {
-            ty::Int(i) => (
-                args[0].immediate(),
-                i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
-            ),
-            ty::Uint(i) => (
-                args[0].immediate(),
-                i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
-            ),
-            _ => return_error!(InvalidMonomorphization::VectorArgument {
+        let in_elem_bitwidth =
+            require_int_or_uint_ty!(in_elem.kind(), InvalidMonomorphization::VectorArgument {
                 span,
                 name,
                 in_ty,
                 in_elem
-            }),
-        };
+            });

-        // LLVM doesn't always know the inputs are `0` or `!0`, so we shift here so it optimizes to
-        // `pmovmskb` and similar on x86.
-        let shift_indices =
-            vec![
-                bx.cx.const_int(bx.type_ix(in_elem_bitwidth), (in_elem_bitwidth - 1) as _);
-                in_len as _
-            ];
-        let i_xn_msb = bx.lshr(i_xn, bx.const_vector(shift_indices.as_slice()));
-        // Truncate vector to an <i1 x N>
-        let i1xn = bx.trunc(i_xn_msb, bx.type_vector(bx.type_i1(), in_len));
+        let i1xn = vector_mask_to_bitmask(bx, args[0].immediate(), in_elem_bitwidth, in_len);
         // Bitcast <i1 x N> to iN:
         let i_ = bx.bitcast(i1xn, bx.type_ix(in_len));

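For comparison (also not part of the diff), `simd_bitmask` backs `Mask::to_bitmask` in nightly `std::simd`, and it already used the most-significant-bit lowering that `vector_mask_to_bitmask` now applies everywhere; a minimal sketch:

```rust
#![feature(portable_simd)]
use std::simd::prelude::*;

fn main() {
    let m = mask32x4::from_array([true, false, true, true]);
    // Bit i of the result corresponds to lane i; on x86 this maps onto movmsk-style
    // instructions, which read the sign bit of each lane.
    assert_eq!(m.to_bitmask(), 0b1101);
}
```
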
@@ -1704,28 +1739,21 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             }
         );

-        match element_ty2.kind() {
-            ty::Int(_) => (),
-            _ => {
-                return_error!(InvalidMonomorphization::ThirdArgElementType {
-                    span,
-                    name,
-                    expected_element: element_ty2,
-                    third_arg: arg_tys[2]
-                });
-            }
-        }
+        let mask_elem_bitwidth =
+            require_int_ty!(element_ty2.kind(), InvalidMonomorphization::ThirdArgElementType {
+                span,
+                name,
+                expected_element: element_ty2,
+                third_arg: arg_tys[2]
+            });

         // Alignment of T, must be a constant integer value:
         let alignment_ty = bx.type_i32();
         let alignment = bx.const_i32(bx.align_of(in_elem).bytes() as i32);

         // Truncate the mask vector to a vector of i1s:
-        let (mask, mask_ty) = {
-            let i1 = bx.type_i1();
-            let i1xn = bx.type_vector(i1, in_len);
-            (bx.trunc(args[2].immediate(), i1xn), i1xn)
-        };
+        let mask = vector_mask_to_bitmask(bx, args[2].immediate(), mask_elem_bitwidth, in_len);
+        let mask_ty = bx.type_vector(bx.type_i1(), in_len);

         // Type of the vector of pointers:
         let llvm_pointer_vec_ty = llvm_vector_ty(bx, element_ty1, in_len);
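
The hunk above is the `simd_gather` lowering. A hedged, nightly-only example of the user-facing API that reaches it (`Simd::gather_select`); the data and indices are made up:

```rust
#![feature(portable_simd)]
use std::simd::prelude::*;

fn main() {
    let data = [10.0_f64, 11.0, 12.0, 13.0];
    let idxs = Simd::<usize, 4>::from_array([0, 1, 2, 3]);
    let enable = Mask::<isize, 4>::from_array([true, false, true, false]);
    // `gather_select` lowers to `simd_gather`; the `enable` mask goes through the
    // conversion that now uses `vector_mask_to_bitmask`.
    let got = f64x4::gather_select(&data, enable, idxs, f64x4::splat(0.0));
    assert_eq!(got.to_array(), [10.0, 0.0, 12.0, 0.0]);
}
```
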
@@ -1810,27 +1838,21 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             }
         );

-        require!(
-            matches!(mask_elem.kind(), ty::Int(_)),
-            InvalidMonomorphization::ThirdArgElementType {
+        let m_elem_bitwidth =
+            require_int_ty!(mask_elem.kind(), InvalidMonomorphization::ThirdArgElementType {
                 span,
                 name,
                 expected_element: values_elem,
                 third_arg: mask_ty,
-            }
-        );
+            });
+
+        let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
+        let mask_ty = bx.type_vector(bx.type_i1(), mask_len);

         // Alignment of T, must be a constant integer value:
         let alignment_ty = bx.type_i32();
         let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);

-        // Truncate the mask vector to a vector of i1s:
-        let (mask, mask_ty) = {
-            let i1 = bx.type_i1();
-            let i1xn = bx.type_vector(i1, mask_len);
-            (bx.trunc(args[0].immediate(), i1xn), i1xn)
-        };
-
         let llvm_pointer = bx.type_ptr();

         // Type of the vector of elements:
@@ -1901,27 +1923,21 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             }
         );

-        require!(
-            matches!(mask_elem.kind(), ty::Int(_)),
-            InvalidMonomorphization::ThirdArgElementType {
+        let m_elem_bitwidth =
+            require_int_ty!(mask_elem.kind(), InvalidMonomorphization::ThirdArgElementType {
                 span,
                 name,
                 expected_element: values_elem,
                 third_arg: mask_ty,
-            }
-        );
+            });
+
+        let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
+        let mask_ty = bx.type_vector(bx.type_i1(), mask_len);

         // Alignment of T, must be a constant integer value:
         let alignment_ty = bx.type_i32();
         let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);

-        // Truncate the mask vector to a vector of i1s:
-        let (mask, mask_ty) = {
-            let i1 = bx.type_i1();
-            let i1xn = bx.type_vector(i1, in_len);
-            (bx.trunc(args[0].immediate(), i1xn), i1xn)
-        };
-
         let ret_t = bx.type_void();

         let llvm_pointer = bx.type_ptr();
@@ -1995,28 +2011,21 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
         );

         // The element type of the third argument must be a signed integer type of any width:
-        match element_ty2.kind() {
-            ty::Int(_) => (),
-            _ => {
-                return_error!(InvalidMonomorphization::ThirdArgElementType {
-                    span,
-                    name,
-                    expected_element: element_ty2,
-                    third_arg: arg_tys[2]
-                });
-            }
-        }
+        let mask_elem_bitwidth =
+            require_int_ty!(element_ty2.kind(), InvalidMonomorphization::ThirdArgElementType {
+                span,
+                name,
+                expected_element: element_ty2,
+                third_arg: arg_tys[2]
+            });

         // Alignment of T, must be a constant integer value:
         let alignment_ty = bx.type_i32();
         let alignment = bx.const_i32(bx.align_of(in_elem).bytes() as i32);

         // Truncate the mask vector to a vector of i1s:
-        let (mask, mask_ty) = {
-            let i1 = bx.type_i1();
-            let i1xn = bx.type_vector(i1, in_len);
-            (bx.trunc(args[2].immediate(), i1xn), i1xn)
-        };
+        let mask = vector_mask_to_bitmask(bx, args[2].immediate(), mask_elem_bitwidth, in_len);
+        let mask_ty = bx.type_vector(bx.type_i1(), in_len);

         let ret_t = bx.type_void();

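This hunk is the matching change for `simd_scatter`. A similar nightly-only sketch using `Simd::scatter_select` (again, the values are only illustrative):

```rust
#![feature(portable_simd)]
use std::simd::prelude::*;

fn main() {
    let mut data = [0.0_f64; 4];
    let vals = f64x4::from_array([1.0, 2.0, 3.0, 4.0]);
    let idxs = Simd::<usize, 4>::from_array([0, 1, 2, 3]);
    let enable = Mask::<isize, 4>::from_array([true, false, true, false]);
    // `scatter_select` lowers to `simd_scatter`; only enabled lanes are written.
    vals.scatter_select(&mut data, enable, idxs);
    assert_eq!(data, [1.0, 0.0, 3.0, 0.0]);
}
```
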
@@ -2164,8 +2173,13 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             });
             args[0].immediate()
         } else {
-            match in_elem.kind() {
-                ty::Int(_) | ty::Uint(_) => {}
+            let bitwidth = match in_elem.kind() {
+                ty::Int(i) => {
+                    i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits())
+                }
+                ty::Uint(i) => {
+                    i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits())
+                }
                 _ => return_error!(InvalidMonomorphization::UnsupportedSymbol {
                     span,
                     name,
@@ -2174,12 +2188,9 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
                     in_elem,
                     ret_ty
                 }),
-            }
+            };

-            // boolean reductions operate on vectors of i1s:
-            let i1 = bx.type_i1();
-            let i1xn = bx.type_vector(i1, in_len as u64);
-            bx.trunc(args[0].immediate(), i1xn)
+            vector_mask_to_bitmask(bx, args[0].immediate(), bitwidth, in_len as _)
         };
         return match in_elem.kind() {
             ty::Int(_) | ty::Uint(_) => {

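The two hunks above affect the boolean reduction intrinsics named in the commit title; in nightly `std::simd`, `Mask::all` and `Mask::any` are the methods that reduce a mask this way, as in this sketch:

```rust
#![feature(portable_simd)]
use std::simd::prelude::*;

fn main() {
    let m = mask32x4::from_array([true, true, false, true]);
    // These reduce the whole mask to a single bool via the boolean reduction
    // intrinsics; the mask now reaches LLVM through its sign bits.
    assert!(m.any());
    assert!(!m.all());
}
```
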
tests/assembly/simd-intrinsic-gather.rs (+2 -2)
@@ -35,8 +35,8 @@ pub unsafe extern "C" fn gather_f64x4(mask: m64x4, ptrs: pf64x4) -> f64x4 {
     // FIXME: This should also get checked to generate a gather instruction for avx2.
     // Currently llvm scalarizes this code, see https://github.com/llvm/llvm-project/issues/59789
     //
-    // x86-avx512: vpsllq ymm0, ymm0, 63
-    // x86-avx512-NEXT: vpmovq2m k1, ymm0
+    // x86-avx512-NOT: vpsllq
+    // x86-avx512: vpmovq2m k1, ymm0
     // x86-avx512-NEXT: vpxor xmm0, xmm0, xmm0
     // x86-avx512-NEXT: vgatherqpd ymm0 {k1}, {{(ymmword)|(qword)}} ptr [1*ymm1]
     simd_gather(f64x4([0_f64, 0_f64, 0_f64, 0_f64]), ptrs, mask)

tests/assembly/simd-intrinsic-mask-load.rs (+13 -14)
@@ -47,9 +47,9 @@ extern "rust-intrinsic" {
 pub unsafe extern "C" fn load_i8x16(mask: m8x16, pointer: *const i8) -> i8x16 {
     // Since avx2 supports no masked loads for bytes, the code tests each individual bit
     // and jumps to code that inserts individual bytes.
-    // x86-avx2: vpsllw xmm0, xmm0, 7
-    // x86-avx2-NEXT: vpmovmskb eax, xmm0
-    // x86-avx2-NEXT: vpxor xmm0, xmm0
+    // x86-avx2-NOT: vpsllw
+    // x86-avx2-DAG: vpmovmskb eax
+    // x86-avx2-DAG: vpxor
     // x86-avx2-NEXT: test al, 1
     // x86-avx2-NEXT: jne
     // x86-avx2-NEXT: test al, 2
@@ -58,32 +58,31 @@ pub unsafe extern "C" fn load_i8x16(mask: m8x16, pointer: *const i8) -> i8x16 {
     // x86-avx2-NEXT: vmovd xmm0, [[REG]]
     // x86-avx2-DAG: vpinsrb xmm0, xmm0, byte ptr [rdi + 1], 1
     //
-    // x86-avx512: vpsllw xmm0, xmm0, 7
-    // x86-avx512-NEXT: vpmovb2m k1, xmm0
+    // x86-avx512-NOT: vpsllw
+    // x86-avx512: vpmovb2m k1, xmm0
     // x86-avx512-NEXT: vmovdqu8 xmm0 {k1} {z}, xmmword ptr [rdi]
     simd_masked_load(mask, pointer, i8x16([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
 }

 // CHECK-LABEL: load_f32x8
 #[no_mangle]
 pub unsafe extern "C" fn load_f32x8(mask: m32x8, pointer: *const f32) -> f32x8 {
-    // x86-avx2: vpslld ymm0, ymm0, 31
-    // x86-avx2-NEXT: vmaskmovps ymm0, ymm0, ymmword ptr [rdi]
+    // x86-avx2-NOT: vpslld
+    // x86-avx2: vmaskmovps ymm0, ymm0, ymmword ptr [rdi]
     //
-    // x86-avx512: vpslld ymm0, ymm0, 31
-    // x86-avx512-NEXT: vpmovd2m k1, ymm0
+    // x86-avx512-NOT: vpslld
+    // x86-avx512: vpmovd2m k1, ymm0
     // x86-avx512-NEXT: vmovups ymm0 {k1} {z}, ymmword ptr [rdi]
     simd_masked_load(mask, pointer, f32x8([0_f32, 0_f32, 0_f32, 0_f32, 0_f32, 0_f32, 0_f32, 0_f32]))
 }

 // CHECK-LABEL: load_f64x4
 #[no_mangle]
 pub unsafe extern "C" fn load_f64x4(mask: m64x4, pointer: *const f64) -> f64x4 {
-    // x86-avx2: vpsllq ymm0, ymm0, 63
-    // x86-avx2-NEXT: vmaskmovpd ymm0, ymm0, ymmword ptr [rdi]
+    // x86-avx2-NOT: vpsllq
+    // x86-avx2: vmaskmovpd ymm0, ymm0, ymmword ptr [rdi]
     //
-    // x86-avx512: vpsllq ymm0, ymm0, 63
-    // x86-avx512-NEXT: vpmovq2m k1, ymm0
-    // x86-avx512-NEXT: vmovupd ymm0 {k1} {z}, ymmword ptr [rdi]
+    // x86-avx512-NOT: vpsllq
+    // x86-avx512: vpmovq2m k1, ymm0
     simd_masked_load(mask, pointer, f64x4([0_f64, 0_f64, 0_f64, 0_f64]))
 }
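
To observe the behavior these updated FileCheck lines encode, one could compile a small function like the following on nightly (file name and flags are only an example, e.g. `rustc -O -C target-feature=+avx2 --emit asm --crate-type lib blend.rs`) and check that no `vpslld` precedes the blend:

```rust
#![feature(portable_simd)]
use std::simd::prelude::*;

#[no_mangle]
pub fn blend(mask: mask32x8, a: f32x8, b: f32x8) -> f32x8 {
    // Previously the generated assembly shifted each mask lane left by 31 to move the
    // decisive bit into the sign position; now a variable blend (something like
    // vblendvps) can consume the mask as-is.
    mask.select(a, b)
}
```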
