Skip to content

Commit

Permalink
Rollup merge of rust-lang#78863 - KodrAus:feat/simd-array, r=oli-obk
Browse files Browse the repository at this point in the history
Support repr(simd) on ADTs containing a single array field

This is a squash and rebase of `@gnzlbg's` rust-lang#63531

I've never actually written code in the compiler before so just fumbled my way around until it would build 😅

I imagine there'll be some work we need to do in `rustc_codegen_cranelift` too for this now, but might need some input from `@bjorn3` to know what that is.

cc `@rust-lang/project-portable-simd`

-----

This PR allows using `#[repr(simd)]` on ADTs containing a single array field:

```rust
 #[repr(simd)] struct S0([f32; 4]);
 #[repr(simd)] struct S1<const N: usize>([f32; N]);
 #[repr(simd)] struct S2<T, const N: usize>([T; N]);
```

This should allow experimenting with portable packed SIMD abstractions on nightly that make use of const generics.
  • Loading branch information
Dylan-DPC authored Nov 17, 2020
2 parents 975b50d + e217fc4 commit 80c3b55
Show file tree
Hide file tree
Showing 17 changed files with 467 additions and 157 deletions.
105 changes: 55 additions & 50 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,7 @@ fn generic_simd_intrinsic(
_ => return_error!("`{}` is not an integral type", in_ty),
};
require_simd!(arg_tys[1], "argument");
let v_len = arg_tys[1].simd_size(tcx);
let (v_len, _) = arg_tys[1].simd_size_and_type(bx.tcx());
require!(
// Allow masks for vectors with fewer than 8 elements to be
// represented with a u8 or i8.
Expand All @@ -812,8 +812,6 @@ fn generic_simd_intrinsic(
// every intrinsic below takes a SIMD vector as its first argument
require_simd!(arg_tys[0], "input");
let in_ty = arg_tys[0];
let in_elem = arg_tys[0].simd_type(tcx);
let in_len = arg_tys[0].simd_size(tcx);

let comparison = match name {
sym::simd_eq => Some(hir::BinOpKind::Eq),
Expand All @@ -825,14 +823,15 @@ fn generic_simd_intrinsic(
_ => None,
};

let (in_len, in_elem) = arg_tys[0].simd_size_and_type(bx.tcx());
if let Some(cmp_op) = comparison {
require_simd!(ret_ty, "return");

let out_len = ret_ty.simd_size(tcx);
let (out_len, out_ty) = ret_ty.simd_size_and_type(bx.tcx());
require!(
in_len == out_len,
"expected return type with length {} (same as input type `{}`), \
found `{}` with length {}",
found `{}` with length {}",
in_len,
in_ty,
ret_ty,
Expand All @@ -842,7 +841,7 @@ fn generic_simd_intrinsic(
bx.type_kind(bx.element_type(llret_ty)) == TypeKind::Integer,
"expected return type with integer elements, found `{}` with non-integer `{}`",
ret_ty,
ret_ty.simd_type(tcx)
out_ty
);

return Ok(compare_simd_types(
Expand All @@ -862,7 +861,7 @@ fn generic_simd_intrinsic(

require_simd!(ret_ty, "return");

let out_len = ret_ty.simd_size(tcx);
let (out_len, out_ty) = ret_ty.simd_size_and_type(bx.tcx());
require!(
out_len == n,
"expected return type of length {}, found `{}` with length {}",
Expand All @@ -871,13 +870,13 @@ fn generic_simd_intrinsic(
out_len
);
require!(
in_elem == ret_ty.simd_type(tcx),
in_elem == out_ty,
"expected return element type `{}` (element of input `{}`), \
found `{}` with element type `{}`",
found `{}` with element type `{}`",
in_elem,
in_ty,
ret_ty,
ret_ty.simd_type(tcx)
out_ty
);

let total_len = u128::from(in_len) * 2;
Expand Down Expand Up @@ -946,7 +945,7 @@ fn generic_simd_intrinsic(
let m_elem_ty = in_elem;
let m_len = in_len;
require_simd!(arg_tys[1], "argument");
let v_len = arg_tys[1].simd_size(tcx);
let (v_len, _) = arg_tys[1].simd_size_and_type(bx.tcx());
require!(
m_len == v_len,
"mismatched lengths: mask length `{}` != other vector length `{}`",
Expand Down Expand Up @@ -1173,25 +1172,27 @@ fn generic_simd_intrinsic(
require_simd!(ret_ty, "return");

// Of the same length:
let (out_len, _) = arg_tys[1].simd_size_and_type(bx.tcx());
let (out_len2, _) = arg_tys[2].simd_size_and_type(bx.tcx());
require!(
in_len == arg_tys[1].simd_size(tcx),
in_len == out_len,
"expected {} argument with length {} (same as input type `{}`), \
found `{}` with length {}",
found `{}` with length {}",
"second",
in_len,
in_ty,
arg_tys[1],
arg_tys[1].simd_size(tcx)
out_len
);
require!(
in_len == arg_tys[2].simd_size(tcx),
in_len == out_len2,
"expected {} argument with length {} (same as input type `{}`), \
found `{}` with length {}",
found `{}` with length {}",
"third",
in_len,
in_ty,
arg_tys[2],
arg_tys[2].simd_size(tcx)
out_len2
);

// The return type must match the first argument type
Expand All @@ -1215,39 +1216,40 @@ fn generic_simd_intrinsic(

// The second argument must be a simd vector with an element type that's a pointer
// to the element type of the first argument
let (pointer_count, underlying_ty) = match arg_tys[1].simd_type(tcx).kind() {
ty::RawPtr(p) if p.ty == in_elem => {
(ptr_count(arg_tys[1].simd_type(tcx)), non_ptr(arg_tys[1].simd_type(tcx)))
}
let (_, element_ty0) = arg_tys[0].simd_size_and_type(bx.tcx());
let (_, element_ty1) = arg_tys[1].simd_size_and_type(bx.tcx());
let (pointer_count, underlying_ty) = match element_ty1.kind() {
ty::RawPtr(p) if p.ty == in_elem => (ptr_count(element_ty1), non_ptr(element_ty1)),
_ => {
require!(
false,
"expected element type `{}` of second argument `{}` \
to be a pointer to the element type `{}` of the first \
argument `{}`, found `{}` != `*_ {}`",
arg_tys[1].simd_type(tcx),
to be a pointer to the element type `{}` of the first \
argument `{}`, found `{}` != `*_ {}`",
element_ty1,
arg_tys[1],
in_elem,
in_ty,
arg_tys[1].simd_type(tcx),
element_ty1,
in_elem
);
unreachable!();
}
};
assert!(pointer_count > 0);
assert_eq!(pointer_count - 1, ptr_count(arg_tys[0].simd_type(tcx)));
assert_eq!(underlying_ty, non_ptr(arg_tys[0].simd_type(tcx)));
assert_eq!(pointer_count - 1, ptr_count(element_ty0));
assert_eq!(underlying_ty, non_ptr(element_ty0));

// The element type of the third argument must be a signed integer type of any width:
match arg_tys[2].simd_type(tcx).kind() {
let (_, element_ty2) = arg_tys[2].simd_size_and_type(bx.tcx());
match element_ty2.kind() {
ty::Int(_) => (),
_ => {
require!(
false,
"expected element type `{}` of third argument `{}` \
to be a signed integer type",
arg_tys[2].simd_type(tcx),
element_ty2,
arg_tys[2]
);
}
Expand Down Expand Up @@ -1299,25 +1301,27 @@ fn generic_simd_intrinsic(
require_simd!(arg_tys[2], "third");

// Of the same length:
let (element_len1, _) = arg_tys[1].simd_size_and_type(bx.tcx());
let (element_len2, _) = arg_tys[2].simd_size_and_type(bx.tcx());
require!(
in_len == arg_tys[1].simd_size(tcx),
in_len == element_len1,
"expected {} argument with length {} (same as input type `{}`), \
found `{}` with length {}",
found `{}` with length {}",
"second",
in_len,
in_ty,
arg_tys[1],
arg_tys[1].simd_size(tcx)
element_len1
);
require!(
in_len == arg_tys[2].simd_size(tcx),
in_len == element_len2,
"expected {} argument with length {} (same as input type `{}`), \
found `{}` with length {}",
found `{}` with length {}",
"third",
in_len,
in_ty,
arg_tys[2],
arg_tys[2].simd_size(tcx)
element_len2
);

// This counts how many pointers
Expand All @@ -1338,39 +1342,42 @@ fn generic_simd_intrinsic(

// The second argument must be a simd vector with an element type that's a pointer
// to the element type of the first argument
let (pointer_count, underlying_ty) = match arg_tys[1].simd_type(tcx).kind() {
let (_, element_ty0) = arg_tys[0].simd_size_and_type(bx.tcx());
let (_, element_ty1) = arg_tys[1].simd_size_and_type(bx.tcx());
let (_, element_ty2) = arg_tys[2].simd_size_and_type(bx.tcx());
let (pointer_count, underlying_ty) = match element_ty1.kind() {
ty::RawPtr(p) if p.ty == in_elem && p.mutbl == hir::Mutability::Mut => {
(ptr_count(arg_tys[1].simd_type(tcx)), non_ptr(arg_tys[1].simd_type(tcx)))
(ptr_count(element_ty1), non_ptr(element_ty1))
}
_ => {
require!(
false,
"expected element type `{}` of second argument `{}` \
to be a pointer to the element type `{}` of the first \
argument `{}`, found `{}` != `*mut {}`",
arg_tys[1].simd_type(tcx),
to be a pointer to the element type `{}` of the first \
argument `{}`, found `{}` != `*mut {}`",
element_ty1,
arg_tys[1],
in_elem,
in_ty,
arg_tys[1].simd_type(tcx),
element_ty1,
in_elem
);
unreachable!();
}
};
assert!(pointer_count > 0);
assert_eq!(pointer_count - 1, ptr_count(arg_tys[0].simd_type(tcx)));
assert_eq!(underlying_ty, non_ptr(arg_tys[0].simd_type(tcx)));
assert_eq!(pointer_count - 1, ptr_count(element_ty0));
assert_eq!(underlying_ty, non_ptr(element_ty0));

// The element type of the third argument must be a signed integer type of any width:
match arg_tys[2].simd_type(tcx).kind() {
match element_ty2.kind() {
ty::Int(_) => (),
_ => {
require!(
false,
"expected element type `{}` of third argument `{}` \
to be a signed integer type",
arg_tys[2].simd_type(tcx),
be a signed integer type",
element_ty2,
arg_tys[2]
);
}
Expand Down Expand Up @@ -1567,7 +1574,7 @@ unsupported {} from `{}` with element `{}` of size `{}` to `{}`"#,

if name == sym::simd_cast {
require_simd!(ret_ty, "return");
let out_len = ret_ty.simd_size(tcx);
let (out_len, out_elem) = ret_ty.simd_size_and_type(bx.tcx());
require!(
in_len == out_len,
"expected return type with length {} (same as input type `{}`), \
Expand All @@ -1578,8 +1585,6 @@ unsupported {} from `{}` with element `{}` of size `{}` to `{}`"#,
out_len
);
// casting cares about nominal type, not just structural type
let out_elem = ret_ty.simd_type(tcx);

if in_elem == out_elem {
return Ok(args[0].immediate());
}
Expand Down Expand Up @@ -1695,7 +1700,7 @@ unsupported {} from `{}` with element `{}` of size `{}` to `{}`"#,
return_error!(
"expected element type `{}` of vector type `{}` \
to be a signed or unsigned integer type",
arg_tys[0].simd_type(tcx),
arg_tys[0].simd_size_and_type(bx.tcx()).1,
arg_tys[0]
);
}
Expand Down
Loading

0 comments on commit 80c3b55

Please sign in to comment.