Commit 7252870
Consistently use the most significant bit of vector masks
This improves codegen for the vector `select`, `gather`, `scatter`, and boolean reduction intrinsics, and fixes rust-lang/portable-simd#316. The current behavior of mask operations during LLVM codegen is to truncate the mask vector to `<N x i1>`, telling LLVM to use the least significant bit. Since SSE/AVX instructions are defined to use the most significant bit, LLVM has to insert a left shift before the mask can actually be used. Similarly, on AArch64, mask operations such as blend work bit by bit, so repeating the least significant bit across the whole lane involves shifting it into the sign position and then comparing against zero. By shifting right before truncating to `<N x i1>`, we tell LLVM that we only consider the most significant bit, removing the need for the extra shift instructions in the final assembly.
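As a lane-wise model of the two conventions (an illustrative sketch, not code from this commit): a well-formed mask lane is all ones or all zeroes, so its least and most significant bits always agree, and either bit can serve as the lane's `i1` value.

```rust
// Illustrative sketch only; these helpers are not part of the commit.
fn lane_bit_lsb(lane: i8) -> bool {
    // A plain `trunc` to i1 keeps the least significant bit of the lane.
    (lane & 1) != 0
}

fn lane_bit_msb(lane: i8) -> bool {
    // The new lowering does `lshr` by (bitwidth - 1) and then truncates,
    // so the sign bit becomes the i1 value.
    ((lane as u8) >> 7) != 0
}

fn main() {
    // Valid Rust SIMD mask lanes are all zeroes (0) or all ones (-1),
    // so both conventions pick the same bit value.
    for lane in [0i8, -1i8] {
        assert_eq!(lane_bit_lsb(lane), lane_bit_msb(lane));
    }
}
```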
1 parent 6b1e5d9 · commit 7252870

13 files changed: +282 −170 lines

compiler/rustc_codegen_llvm/src/intrinsic.rs (+103 −88)
@@ -1034,6 +1034,60 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
         }};
     }
 
+    /// Returns the bitwidth of the `$ty` argument if it is an `Int` type.
+    macro_rules! require_int_ty {
+        ($ty: expr, $diag: expr) => {
+            match $ty {
+                ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
+                _ => {
+                    return_error!($diag);
+                }
+            }
+        };
+    }
+
+    /// Returns the bitwidth of the `$ty` argument if it is an `Int` or `Uint` type.
+    macro_rules! require_int_or_uint_ty {
+        ($ty: expr, $diag: expr) => {
+            match $ty {
+                ty::Int(i) => i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
+                ty::Uint(i) => {
+                    i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits())
+                }
+                _ => {
+                    return_error!($diag);
+                }
+            }
+        };
+    }
+
+    /// Converts a vector mask, where each element has a bit width equal to the data elements it is used with,
+    /// down to an i1 based mask that can be used by llvm intrinsics.
+    ///
+    /// The rust simd semantics are that each element should either consist of all ones or all zeroes,
+    /// but this information is not available to llvm. Truncating the vector effectively uses the lowest bit,
+    /// but codegen for several targets is better if we consider the highest bit by shifting.
+    ///
+    /// For x86 SSE/AVX targets this is beneficial since most instructions with mask parameters only consider the highest bit.
+    /// So even though on llvm level we have an additional shift, in the final assembly there is no shift or truncate and
+    /// instead the mask can be used as is.
+    ///
+    /// For aarch64 and other targets there is a benefit because a mask from the sign bit can be more
+    /// efficiently converted to an all ones / all zeroes mask by comparing whether each element is negative.
+    fn vector_mask_to_bitmask<'a, 'll, 'tcx>(
+        bx: &mut Builder<'a, 'll, 'tcx>,
+        i_xn: &'ll Value,
+        in_elem_bitwidth: u64,
+        in_len: u64,
+    ) -> &'ll Value {
+        // Shift the MSB to the right by "in_elem_bitwidth - 1" into the first bit position.
+        let shift_idx = bx.cx.const_int(bx.type_ix(in_elem_bitwidth), (in_elem_bitwidth - 1) as _);
+        let shift_indices = vec![shift_idx; in_len as _];
+        let i_xn_msb = bx.lshr(i_xn, bx.const_vector(shift_indices.as_slice()));
+        // Truncate vector to an <i1 x N>
+        bx.trunc(i_xn_msb, bx.type_vector(bx.type_i1(), in_len))
+    }
+
     let tcx = bx.tcx();
     let sig =
         tcx.normalize_erasing_late_bound_regions(ty::ParamEnv::reveal_all(), callee_ty.fn_sig(tcx));
@@ -1294,14 +1348,11 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             m_len == v_len,
             InvalidMonomorphization::MismatchedLengths { span, name, m_len, v_len }
         );
-        match m_elem_ty.kind() {
-            ty::Int(_) => {}
-            _ => return_error!(InvalidMonomorphization::MaskType { span, name, ty: m_elem_ty }),
-        }
-        // truncate the mask to a vector of i1s
-        let i1 = bx.type_i1();
-        let i1xn = bx.type_vector(i1, m_len as u64);
-        let m_i1s = bx.trunc(args[0].immediate(), i1xn);
+        let in_elem_bitwidth = require_int_ty!(
+            m_elem_ty.kind(),
+            InvalidMonomorphization::MaskType { span, name, ty: m_elem_ty }
+        );
+        let m_i1s = vector_mask_to_bitmask(bx, args[0].immediate(), in_elem_bitwidth, m_len);
         return Ok(bx.select(m_i1s, args[1].immediate(), args[2].immediate()));
     }
 
@@ -1319,32 +1370,12 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
         let expected_bytes = expected_int_bits / 8 + ((expected_int_bits % 8 > 0) as u64);
 
         // Integer vector <i{in_bitwidth} x in_len>:
-        let (i_xn, in_elem_bitwidth) = match in_elem.kind() {
-            ty::Int(i) => (
-                args[0].immediate(),
-                i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
-            ),
-            ty::Uint(i) => (
-                args[0].immediate(),
-                i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits()),
-            ),
-            _ => return_error!(InvalidMonomorphization::VectorArgument {
-                span,
-                name,
-                in_ty,
-                in_elem
-            }),
-        };
+        let in_elem_bitwidth = require_int_or_uint_ty!(
+            in_elem.kind(),
+            InvalidMonomorphization::VectorArgument { span, name, in_ty, in_elem }
+        );
 
-        // Shift the MSB to the right by "in_elem_bitwidth - 1" into the first bit position.
-        let shift_indices =
-            vec![
-                bx.cx.const_int(bx.type_ix(in_elem_bitwidth), (in_elem_bitwidth - 1) as _);
-                in_len as _
-            ];
-        let i_xn_msb = bx.lshr(i_xn, bx.const_vector(shift_indices.as_slice()));
-        // Truncate vector to an <i1 x N>
-        let i1xn = bx.trunc(i_xn_msb, bx.type_vector(bx.type_i1(), in_len));
+        let i1xn = vector_mask_to_bitmask(bx, args[0].immediate(), in_elem_bitwidth, in_len);
         // Bitcast <i1 x N> to iN:
         let i_ = bx.bitcast(i1xn, bx.type_ix(in_len));
 
@@ -1562,28 +1593,23 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             }
         );
 
-        match element_ty2.kind() {
-            ty::Int(_) => (),
-            _ => {
-                return_error!(InvalidMonomorphization::ThirdArgElementType {
-                    span,
-                    name,
-                    expected_element: element_ty2,
-                    third_arg: arg_tys[2]
-                });
+        let mask_elem_bitwidth = require_int_ty!(
+            element_ty2.kind(),
+            InvalidMonomorphization::ThirdArgElementType {
+                span,
+                name,
+                expected_element: element_ty2,
+                third_arg: arg_tys[2]
             }
-        }
+        );
 
         // Alignment of T, must be a constant integer value:
         let alignment_ty = bx.type_i32();
         let alignment = bx.const_i32(bx.align_of(in_elem).bytes() as i32);
 
         // Truncate the mask vector to a vector of i1s:
-        let (mask, mask_ty) = {
-            let i1 = bx.type_i1();
-            let i1xn = bx.type_vector(i1, in_len);
-            (bx.trunc(args[2].immediate(), i1xn), i1xn)
-        };
+        let mask = vector_mask_to_bitmask(bx, args[2].immediate(), mask_elem_bitwidth, in_len);
+        let mask_ty = bx.type_vector(bx.type_i1(), in_len);
 
         // Type of the vector of pointers:
         let llvm_pointer_vec_ty = llvm_vector_ty(bx, element_ty1, in_len);
@@ -1668,8 +1694,8 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             }
         );
 
-        require!(
-            matches!(mask_elem.kind(), ty::Int(_)),
+        let m_elem_bitwidth = require_int_ty!(
+            mask_elem.kind(),
             InvalidMonomorphization::ThirdArgElementType {
                 span,
                 name,
@@ -1678,17 +1704,13 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             }
         );
 
+        let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
+        let mask_ty = bx.type_vector(bx.type_i1(), mask_len);
+
         // Alignment of T, must be a constant integer value:
         let alignment_ty = bx.type_i32();
         let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);
 
-        // Truncate the mask vector to a vector of i1s:
-        let (mask, mask_ty) = {
-            let i1 = bx.type_i1();
-            let i1xn = bx.type_vector(i1, mask_len);
-            (bx.trunc(args[0].immediate(), i1xn), i1xn)
-        };
-
         let llvm_pointer = bx.type_ptr();
 
         // Type of the vector of elements:
@@ -1760,8 +1782,8 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             }
         );
 
-        require!(
-            matches!(mask_elem.kind(), ty::Int(_)),
+        let m_elem_bitwidth = require_int_ty!(
+            mask_elem.kind(),
             InvalidMonomorphization::ThirdArgElementType {
                 span,
                 name,
@@ -1770,17 +1792,13 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             }
         );
 
+        let mask = vector_mask_to_bitmask(bx, args[0].immediate(), m_elem_bitwidth, mask_len);
+        let mask_ty = bx.type_vector(bx.type_i1(), mask_len);
+
         // Alignment of T, must be a constant integer value:
         let alignment_ty = bx.type_i32();
         let alignment = bx.const_i32(bx.align_of(values_elem).bytes() as i32);
 
-        // Truncate the mask vector to a vector of i1s:
-        let (mask, mask_ty) = {
-            let i1 = bx.type_i1();
-            let i1xn = bx.type_vector(i1, in_len);
-            (bx.trunc(args[0].immediate(), i1xn), i1xn)
-        };
-
         let ret_t = bx.type_void();
 
         let llvm_pointer = bx.type_ptr();
@@ -1859,28 +1877,23 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
         );
 
         // The element type of the third argument must be a signed integer type of any width:
-        match element_ty2.kind() {
-            ty::Int(_) => (),
-            _ => {
-                return_error!(InvalidMonomorphization::ThirdArgElementType {
-                    span,
-                    name,
-                    expected_element: element_ty2,
-                    third_arg: arg_tys[2]
-                });
+        let mask_elem_bitwidth = require_int_ty!(
+            element_ty2.kind(),
+            InvalidMonomorphization::ThirdArgElementType {
+                span,
+                name,
+                expected_element: element_ty2,
+                third_arg: arg_tys[2]
             }
-        }
+        );
 
         // Alignment of T, must be a constant integer value:
         let alignment_ty = bx.type_i32();
         let alignment = bx.const_i32(bx.align_of(in_elem).bytes() as i32);
 
         // Truncate the mask vector to a vector of i1s:
-        let (mask, mask_ty) = {
-            let i1 = bx.type_i1();
-            let i1xn = bx.type_vector(i1, in_len);
-            (bx.trunc(args[2].immediate(), i1xn), i1xn)
-        };
+        let mask = vector_mask_to_bitmask(bx, args[2].immediate(), mask_elem_bitwidth, in_len);
+        let mask_ty = bx.type_vector(bx.type_i1(), in_len);
 
         let ret_t = bx.type_void();
 
@@ -2018,8 +2031,13 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
             );
             args[0].immediate()
         } else {
-            match in_elem.kind() {
-                ty::Int(_) | ty::Uint(_) => {}
+            let bitwidth = match in_elem.kind() {
+                ty::Int(i) => {
+                    i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits())
+                }
+                ty::Uint(i) => {
+                    i.bit_width().unwrap_or_else(|| bx.data_layout().pointer_size.bits())
+                }
                 _ => return_error!(InvalidMonomorphization::UnsupportedSymbol {
                     span,
                     name,
@@ -2028,12 +2046,9 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
                     in_elem,
                     ret_ty
                 }),
-            }
+            };
 
-            // boolean reductions operate on vectors of i1s:
-            let i1 = bx.type_i1();
-            let i1xn = bx.type_vector(i1, in_len as u64);
-            bx.trunc(args[0].immediate(), i1xn)
+            vector_mask_to_bitmask(bx, args[0].immediate(), bitwidth, in_len as _)
         };
         return match in_elem.kind() {
             ty::Int(_) | ty::Uint(_) => {
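The doc comment on `vector_mask_to_bitmask` above notes that a sign-bit mask is cheap to widen back into an all-ones / all-zeroes lane by testing for negativity. A scalar Rust sketch of that equivalence (illustrative only, not part of the commit):

```rust
// Widening a lane's sign bit back to an all-ones / all-zeroes lane via an
// "is negative" test, the pattern that maps to a single vector compare
// (e.g. a compare-less-than-zero on aarch64). Illustrative model only.
fn widen_sign_bit(lane: i32) -> i32 {
    // `lane < 0` inspects only the most significant bit; negating the
    // resulting 0/1 yields 0 (all zeroes) or -1 (all ones).
    -((lane < 0) as i32)
}

fn main() {
    assert_eq!(widen_sign_bit(0), 0);
    assert_eq!(widen_sign_bit(-1), -1);
    assert_eq!(widen_sign_bit(i32::MIN), -1); // any lane with the sign bit set
}
```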

tests/assembly/simd-intrinsic-gather.rs (+2 −2)
@@ -36,8 +36,8 @@ pub unsafe extern "C" fn gather_f64x4(mask: m64x4, ptrs: pf64x4) -> f64x4 {
     // FIXME: This should also get checked to generate a gather instruction for avx2.
     // Currently llvm scalarizes this code, see https://github.com/llvm/llvm-project/issues/59789
     //
-    // x86-avx512: vpsllq ymm0, ymm0, 63
-    // x86-avx512-NEXT: vpmovq2m k1, ymm0
+    // x86-avx512-NOT: vpsllq
+    // x86-avx512: vpmovq2m k1, ymm0
     // x86-avx512-NEXT: vpxor xmm0, xmm0, xmm0
     // x86-avx512-NEXT: vgatherqpd ymm0 {k1}, ymmword ptr [1*ymm1]
     simd_gather(f64x4([0_f64, 0_f64, 0_f64, 0_f64]), ptrs, mask)

tests/assembly/simd-intrinsic-mask-load.rs (+13 −14)
@@ -46,9 +46,9 @@ extern "rust-intrinsic" {
 pub unsafe extern "C" fn load_i8x16(mask: m8x16, pointer: *const i8) -> i8x16 {
     // Since avx2 supports no masked loads for bytes, the code tests each individual bit
     // and jumps to code that inserts individual bytes.
-    // x86-avx2: vpsllw xmm0, xmm0, 7
-    // x86-avx2-NEXT: vpmovmskb eax, xmm0
-    // x86-avx2-NEXT: vpxor xmm0, xmm0
+    // x86-avx2-NOT: vpsllw
+    // x86-avx2-DAG: vpmovmskb eax
+    // x86-avx2-DAG: vpxor
     // x86-avx2-NEXT: test al, 1
     // x86-avx2-NEXT: jne
     // x86-avx2-NEXT: test al, 2
@@ -57,32 +57,31 @@ pub unsafe extern "C" fn load_i8x16(mask: m8x16, pointer: *const i8) -> i8x16 {
     // x86-avx2-NEXT: vmovd xmm0, [[REG]]
     // x86-avx2-DAG: vpinsrb xmm0, xmm0, byte ptr [rdi + 1], 1
     //
-    // x86-avx512: vpsllw xmm0, xmm0, 7
-    // x86-avx512-NEXT: vpmovb2m k1, xmm0
+    // x86-avx512-NOT: vpsllw
+    // x86-avx512: vpmovb2m k1, xmm0
     // x86-avx512-NEXT: vmovdqu8 xmm0 {k1} {z}, xmmword ptr [rdi]
     simd_masked_load(mask, pointer, i8x16([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]))
 }
 
 // CHECK-LABEL: load_f32x8
 #[no_mangle]
 pub unsafe extern "C" fn load_f32x8(mask: m32x8, pointer: *const f32) -> f32x8 {
-    // x86-avx2: vpslld ymm0, ymm0, 31
-    // x86-avx2-NEXT: vmaskmovps ymm0, ymm0, ymmword ptr [rdi]
+    // x86-avx2-NOT: vpslld
+    // x86-avx2: vmaskmovps ymm0, ymm0, ymmword ptr [rdi]
     //
-    // x86-avx512: vpslld ymm0, ymm0, 31
-    // x86-avx512-NEXT: vpmovd2m k1, ymm0
+    // x86-avx512-NOT: vpslld
+    // x86-avx512: vpmovd2m k1, ymm0
     // x86-avx512-NEXT: vmovups ymm0 {k1} {z}, ymmword ptr [rdi]
     simd_masked_load(mask, pointer, f32x8([0_f32, 0_f32, 0_f32, 0_f32, 0_f32, 0_f32, 0_f32, 0_f32]))
 }
 
 // CHECK-LABEL: load_f64x4
 #[no_mangle]
 pub unsafe extern "C" fn load_f64x4(mask: m64x4, pointer: *const f64) -> f64x4 {
-    // x86-avx2: vpsllq ymm0, ymm0, 63
-    // x86-avx2-NEXT: vmaskmovpd ymm0, ymm0, ymmword ptr [rdi]
+    // x86-avx2-NOT: vpsllq
+    // x86-avx2: vmaskmovpd ymm0, ymm0, ymmword ptr [rdi]
     //
-    // x86-avx512: vpsllq ymm0, ymm0, 63
-    // x86-avx512-NEXT: vpmovq2m k1, ymm0
-    // x86-avx512-NEXT: vmovupd ymm0 {k1} {z}, ymmword ptr [rdi]
+    // x86-avx512-NOT: vpsllq
+    // x86-avx512: vpmovq2m k1, ymm0
     simd_masked_load(mask, pointer, f64x4([0_f64, 0_f64, 0_f64, 0_f64]))
 }
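At the user level, the intrinsics touched by this commit are reached through the nightly `std::simd` API. A minimal usage sketch (assuming the `mask32x4`/`f32x4` aliases from `std::simd`; not part of this commit) of a `select` whose mask is now consumed via its sign bits:

```rust
#![feature(portable_simd)]
use std::simd::{f32x4, mask32x4};

fn main() {
    // Each mask lane is materialized as all ones or all zeroes; with this
    // commit, codegen keys off the most significant bit of each lane.
    let mask = mask32x4::from_array([true, false, true, false]);
    let on_true = f32x4::from_array([1.0, 2.0, 3.0, 4.0]);
    let on_false = f32x4::splat(0.0);
    // `Mask::select` lowers to the `simd_select` intrinsic changed above.
    let blended = mask.select(on_true, on_false);
    assert_eq!(blended.to_array(), [1.0, 0.0, 3.0, 0.0]);
}
```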
