Skip to content

Commit

Permalink
Merge branch 'fix_ptx_kernel_abi' into HEAD
Browse files Browse the repository at this point in the history
  • Loading branch information
bjorn3 committed Dec 6, 2024
2 parents e1f75dd + 0a44f67 commit f456ebf
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 44 deletions.
44 changes: 30 additions & 14 deletions compiler/rustc_target/src/callconv/nvptx64.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{ArgAttribute, ArgAttributes, ArgExtension, CastTarget};
use crate::abi::call::{ArgAbi, FnAbi, PassMode, Reg, Size, Uniform};
use crate::abi::call::{ArgAbi, FnAbi, Reg, Size, Uniform};
use crate::abi::{HasDataLayout, TyAbiInterface};

fn classify_ret<Ty>(ret: &mut ArgAbi<'_, Ty>) {
Expand Down Expand Up @@ -53,21 +53,37 @@ where
Ty: TyAbiInterface<'a, C> + Copy,
C: HasDataLayout,
{
if matches!(arg.mode, PassMode::Pair(..)) && (arg.layout.is_adt() || arg.layout.is_tuple()) {
let align_bytes = arg.layout.align.abi.bytes();
match arg.mode {
super::PassMode::Ignore | super::PassMode::Direct(_) => return,
super::PassMode::Pair(_, _) => {}
super::PassMode::Cast { .. } => unreachable!(),
super::PassMode::Indirect { .. } => {}
}

// FIXME only allow structs and wide pointers here
// panic!(
// "`extern \"ptx-kernel\"` doesn't allow passing types other than primitives and structs"
// );

let align_bytes = arg.layout.align.abi.bytes();

let unit = match align_bytes {
1 => Reg::i8(),
2 => Reg::i16(),
4 => Reg::i32(),
8 => Reg::i64(),
16 => Reg::i128(),
_ => unreachable!("Align is given as power of 2 no larger than 16 bytes"),
};
arg.cast_to(Uniform::new(unit, Size::from_bytes(2 * align_bytes)));
let unit = match align_bytes {
1 => Reg::i8(),
2 => Reg::i16(),
4 => Reg::i32(),
8 => Reg::i64(),
16 => Reg::i128(),
_ => unreachable!("Align is given as power of 2 no larger than 16 bytes"),
};
if arg.layout.size.bytes() / align_bytes == 1 {
// Make sure we pass the struct as array at the LLVM IR level and not as a single integer.
arg.cast_to(CastTarget {
prefix: [Some(unit), None, None, None, None, None, None, None],
rest: Uniform::new(unit, Size::ZERO),
attrs: ArgAttributes::new(),
});
} else {
// FIXME: find a better way to do this. See https://github.com/rust-lang/rust/issues/117271.
arg.make_direct_deprecated();
arg.cast_to(Uniform::new(unit, arg.layout.size));
}
}

Expand Down
31 changes: 18 additions & 13 deletions compiler/rustc_ty_utils/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,20 +473,25 @@ fn fn_abi_sanity_check<'tcx>(
// This really shouldn't happen even for sized aggregates, since
// `immediate_llvm_type` will use `layout.fields` to turn this Rust type into an
// LLVM type. This means all sorts of Rust type details leak into the ABI.
// However wasm sadly *does* currently use this mode so we have to allow it --
// but we absolutely shouldn't let any more targets do that.
// (Also see <https://github.com/rust-lang/rust/issues/115666>.)
// However wasm sadly *does* currently use this mode for it's "C" ABI so we
// have to allow it -- but we absolutely shouldn't let any more targets do
// that. (Also see <https://github.com/rust-lang/rust/issues/115666>.)
//
// The unstable abi `PtxKernel` also uses Direct for now.
// It needs to switch to something else before stabilization can happen.
// (See issue: https://github.com/rust-lang/rust/issues/117271)
assert!(
matches!(&*tcx.sess.target.arch, "wasm32" | "wasm64")
|| matches!(spec_abi, ExternAbi::PtxKernel | ExternAbi::Unadjusted),
"`PassMode::Direct` for aggregates only allowed for \"unadjusted\" and \"ptx-kernel\" functions and on wasm\n\
Problematic type: {:#?}",
arg.layout,
);
// The unadjusted ABI also uses Direct for all args and is ill-specified,
// but unfortunately we need it for calling certain LLVM intrinsics.

match spec_abi {
ExternAbi::Unadjusted => {}
ExternAbi::C { unwind: _ }
if matches!(&*tcx.sess.target.arch, "wasm32" | "wasm64") => {}
_ => {
panic!(
"`PassMode::Direct` for aggregates only allowed for \"unadjusted\" functions and on wasm\n\
Problematic type: {:#?}",
arg.layout,
);
}
}
}
}
}
Expand Down
18 changes: 1 addition & 17 deletions tests/assembly/nvptx-kernel-abi/nvptx-kernel-args-abi-v7.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,22 +242,6 @@ pub unsafe extern "ptx-kernel" fn f_float_array_arg(_a: [f32; 5]) {}
//pub unsafe extern "ptx-kernel" fn f_u128_array_arg(_a: [u128; 5]) {}

// CHECK: .visible .entry f_u32_slice_arg(
// CHECK: .param .u64 f_u32_slice_arg_param_0
// CHECK: .param .u64 f_u32_slice_arg_param_1
// CHECK: .param .align 8 .b8 f_u32_slice_arg_param_0[16]
#[no_mangle]
pub unsafe extern "ptx-kernel" fn f_u32_slice_arg(_a: &[u32]) {}

// CHECK: .visible .entry f_tuple_u8_u8_arg(
// CHECK: .param .align 1 .b8 f_tuple_u8_u8_arg_param_0[2]
#[no_mangle]
pub unsafe extern "ptx-kernel" fn f_tuple_u8_u8_arg(_a: (u8, u8)) {}

// CHECK: .visible .entry f_tuple_u32_u32_arg(
// CHECK: .param .align 4 .b8 f_tuple_u32_u32_arg_param_0[8]
#[no_mangle]
pub unsafe extern "ptx-kernel" fn f_tuple_u32_u32_arg(_a: (u32, u32)) {}

// CHECK: .visible .entry f_tuple_u8_u8_u32_arg(
// CHECK: .param .align 4 .b8 f_tuple_u8_u8_u32_arg_param_0[8]
#[no_mangle]
pub unsafe extern "ptx-kernel" fn f_tuple_u8_u8_u32_arg(_a: (u8, u8, u32)) {}

0 comments on commit f456ebf

Please sign in to comment.