diff --git a/compiler/rustc_codegen_llvm/src/base.rs b/compiler/rustc_codegen_llvm/src/base.rs index d00e70638b45a..a75734d0271a5 100644 --- a/compiler/rustc_codegen_llvm/src/base.rs +++ b/compiler/rustc_codegen_llvm/src/base.rs @@ -99,7 +99,7 @@ pub(crate) fn compile_codegen_unit( .unstable_opts .offload .iter() - .any(|o| matches!(o, Offload::Host(_) | Offload::Test)); + .any(|o| matches!(o, Offload::Host(_) | Offload::Test | Offload::Args)); if has_host_offload && !cx.sess().target.is_like_gpu { cx.offload_globals.replace(Some(OffloadGlobals::declare(&cx))); } diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 7a49ba64029e5..5da235094f3dc 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -173,19 +173,6 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> { load } } - - fn memset(&mut self, ptr: &'ll Value, fill_byte: &'ll Value, size: &'ll Value, align: Align) { - unsafe { - llvm::LLVMRustBuildMemSet( - self.llbuilder, - ptr, - align.bytes() as c_uint, - fill_byte, - size, - false, - ); - } - } } /// Empty string, to be used where LLVM expects an instruction name, indicating diff --git a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs index 084d40317ba89..ef98312a38958 100644 --- a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs +++ b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs @@ -8,7 +8,7 @@ use rustc_middle::bug; use rustc_middle::ty::offload_meta::OffloadMetadata; use crate::builder::Builder; -use crate::common::CodegenCx; +use crate::common::{AsCCharPtr, CodegenCx}; use crate::llvm::AttributePlace::Function; use crate::llvm::{self, Linkage, Type, Value}; use crate::{SimpleCx, attributes}; @@ -18,8 +18,6 @@ pub(crate) struct OffloadGlobals<'ll> { pub launcher_fn: &'ll llvm::Value, pub launcher_ty: &'ll llvm::Type, - pub bin_desc: &'ll llvm::Type, - pub kernel_args_ty: &'ll llvm::Type, pub offload_entry_ty: &'ll llvm::Type, @@ -30,8 +28,6 @@ pub(crate) struct OffloadGlobals<'ll> { pub ident_t_global: &'ll llvm::Value, - pub register_lib: &'ll llvm::Value, - pub unregister_lib: &'ll llvm::Value, pub init_rtls: &'ll llvm::Value, } @@ -43,15 +39,6 @@ impl<'ll> OffloadGlobals<'ll> { let (begin_mapper, _, end_mapper, mapper_fn_ty) = gen_tgt_data_mappers(cx); let ident_t_global = generate_at_one(cx); - let tptr = cx.type_ptr(); - let ti32 = cx.type_i32(); - let tgt_bin_desc_ty = vec![ti32, tptr, tptr, tptr]; - let bin_desc = cx.type_named_struct("struct.__tgt_bin_desc"); - cx.set_struct_body(bin_desc, &tgt_bin_desc_ty, false); - - let reg_lib_decl = cx.type_func(&[cx.type_ptr()], cx.type_void()); - let register_lib = declare_offload_fn(&cx, "__tgt_register_lib", reg_lib_decl); - let unregister_lib = declare_offload_fn(&cx, "__tgt_unregister_lib", reg_lib_decl); let init_ty = cx.type_func(&[], cx.type_void()); let init_rtls = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty); @@ -62,20 +49,83 @@ impl<'ll> OffloadGlobals<'ll> { OffloadGlobals { launcher_fn, launcher_ty, - bin_desc, kernel_args_ty, offload_entry_ty, begin_mapper, end_mapper, mapper_fn_ty, ident_t_global, - register_lib, - unregister_lib, init_rtls, } } } +// We need to register offload before using it. We should also unregister it once we are done, for +// good measure. Previously we did so before and after each individual offload intrinsic +// call, but that comes at a performance cost.
The repeated (un)register calls might also confuse +// the LLVM OpenMPOpt pass, which tries to move operations to a better location. The easiest solution, +// which we copy from clang, is to have those two calls just once, in the global ctor/dtor section +// of the final binary. +pub(crate) fn register_offload<'ll>(cx: &CodegenCx<'ll, '_>) { + // First we quickly check whether we have already done our setup, in which case we return early. + // Shouldn't be needed for correctness. + if cx.get_function("__tgt_register_lib").is_some() { + return; + } + + let reg_lib_decl = cx.type_func(&[cx.type_ptr()], cx.type_void()); + let register_lib = declare_offload_fn(&cx, "__tgt_register_lib", reg_lib_decl); + let unregister_lib = declare_offload_fn(&cx, "__tgt_unregister_lib", reg_lib_decl); + + let ptr_null = cx.const_null(cx.type_ptr()); + let const_struct = cx.const_struct(&[cx.get_const_i32(0), ptr_null, ptr_null, ptr_null], false); + let omp_descriptor = + add_global(cx, ".omp_offloading.descriptor", const_struct, InternalLinkage); + // With device images present, clang would emit a populated descriptor: + // @.omp_offloading.descriptor = internal constant %__tgt_bin_desc { i32 1, ptr @.omp_offloading.device_images, ptr @__start_llvm_offload_entries, ptr @__stop_llvm_offload_entries } + // For now we emit an empty one instead: + // @.omp_offloading.descriptor = internal constant %__tgt_bin_desc { i32 0, ptr null, ptr null, ptr null } + + let atexit = cx.type_func(&[cx.type_ptr()], cx.type_i32()); + let atexit_fn = declare_offload_fn(cx, "atexit", atexit); + + let desc_ty = cx.type_func(&[], cx.type_void()); + let reg_name = ".omp_offloading.descriptor_reg"; + let unreg_name = ".omp_offloading.descriptor_unreg"; + let desc_reg_fn = declare_offload_fn(cx, reg_name, desc_ty); + let desc_unreg_fn = declare_offload_fn(cx, unreg_name, desc_ty); + llvm::set_linkage(desc_reg_fn, InternalLinkage); + llvm::set_linkage(desc_unreg_fn, InternalLinkage); + llvm::set_section(desc_reg_fn, c".text.startup"); + llvm::set_section(desc_unreg_fn, c".text.startup"); + + // define internal void @.omp_offloading.descriptor_reg() section ".text.startup" { + // entry: + // call void @__tgt_register_lib(ptr @.omp_offloading.descriptor) + // %0 = call i32 @atexit(ptr @.omp_offloading.descriptor_unreg) + // ret void + // } + let bb = Builder::append_block(cx, desc_reg_fn, "entry"); + let mut a = Builder::build(cx, bb); + a.call(reg_lib_decl, None, None, register_lib, &[omp_descriptor], None, None); + a.call(atexit, None, None, atexit_fn, &[desc_unreg_fn], None, None); + a.ret_void(); + + // define internal void @.omp_offloading.descriptor_unreg() section ".text.startup" { + // entry: + // call void @__tgt_unregister_lib(ptr @.omp_offloading.descriptor) + // ret void + // } + let bb = Builder::append_block(cx, desc_unreg_fn, "entry"); + let mut a = Builder::build(cx, bb); + a.call(reg_lib_decl, None, None, unregister_lib, &[omp_descriptor], None, None); + a.ret_void(); + + // @llvm.global_ctors = appending global [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 101, ptr @.omp_offloading.descriptor_reg, ptr null }] + let args = vec![cx.get_const_i32(101), desc_reg_fn, ptr_null]; + let const_struct = cx.const_struct(&args, false); + let arr = cx.const_array(cx.val_ty(const_struct), &[const_struct]); + add_global(cx, "llvm.global_ctors", arr, AppendingLinkage); +} + pub(crate) struct OffloadKernelDims<'ll> { num_workgroups: &'ll Value, threads_per_block: &'ll Value, @@ -292,8 +342,8 @@ impl KernelArgsTy { pub(crate) struct OffloadKernelGlobals<'ll> { pub offload_sizes: &'ll llvm::Value, pub memtransfer_types: &'ll llvm::Value, - pub region_id:
&'ll llvm::Value, - pub offload_entry: &'ll llvm::Value, + pub region_id: Option<&'ll llvm::Value>, + pub offload_entry: Option<&'ll llvm::Value>, } fn gen_tgt_data_mappers<'ll>( @@ -364,6 +414,7 @@ pub(crate) fn gen_define_handling<'ll>( types: &[&'ll Type], symbol: String, offload_globals: &OffloadGlobals<'ll>, + host: bool, ) -> OffloadKernelGlobals<'ll> { if let Some(entry) = cx.offload_kernel_cache.borrow().get(&symbol) { return *entry; @@ -394,33 +445,38 @@ pub(crate) fn gen_define_handling<'ll>( // Next: For each function, generate these three entries. A weak constant, // the llvm.rodata entry name, and the llvm_offload_entries value - let name = format!(".{symbol}.region_id"); - let initializer = cx.get_const_i8(0); - let region_id = add_global(&cx, &name, initializer, WeakAnyLinkage); - - let c_entry_name = CString::new(symbol.clone()).unwrap(); - let c_val = c_entry_name.as_bytes_with_nul(); - let offload_entry_name = format!(".offloading.entry_name.{symbol}"); - - let initializer = crate::common::bytes_in_context(cx.llcx, c_val); - let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage); - llvm::set_alignment(llglobal, Align::ONE); - llvm::set_section(llglobal, c".llvm.rodata.offloading"); - - let name = format!(".offloading.entry.{symbol}"); - - // See the __tgt_offload_entry documentation above. - let elems = TgtOffloadEntry::new(&cx, region_id, llglobal); - - let initializer = crate::common::named_struct(offload_entry_ty, &elems); - let c_name = CString::new(name).unwrap(); - let offload_entry = llvm::add_global(cx.llmod, offload_entry_ty, &c_name); - llvm::set_global_constant(offload_entry, true); - llvm::set_linkage(offload_entry, WeakAnyLinkage); - llvm::set_initializer(offload_entry, initializer); - llvm::set_alignment(offload_entry, Align::EIGHT); - let c_section_name = CString::new("llvm_offload_entries").unwrap(); - llvm::set_section(offload_entry, &c_section_name); + let (offload_entry, region_id) = if !host { + let name = format!(".{symbol}.region_id"); + let initializer = cx.get_const_i8(0); + let region_id = add_global(&cx, &name, initializer, WeakAnyLinkage); + + let c_entry_name = CString::new(symbol.clone()).unwrap(); + let c_val = c_entry_name.as_bytes_with_nul(); + let offload_entry_name = format!(".offloading.entry_name.{symbol}"); + + let initializer = crate::common::bytes_in_context(cx.llcx, c_val); + let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage); + llvm::set_alignment(llglobal, Align::ONE); + llvm::set_section(llglobal, c".llvm.rodata.offloading"); + + let name = format!(".offloading.entry.{symbol}"); + + // See the __tgt_offload_entry documentation above. 
+ let elems = TgtOffloadEntry::new(&cx, region_id, llglobal); + + let initializer = crate::common::named_struct(offload_entry_ty, &elems); + let c_name = CString::new(name).unwrap(); + let offload_entry = llvm::add_global(cx.llmod, offload_entry_ty, &c_name); + llvm::set_global_constant(offload_entry, true); + llvm::set_linkage(offload_entry, WeakAnyLinkage); + llvm::set_initializer(offload_entry, initializer); + llvm::set_alignment(offload_entry, Align::EIGHT); + let c_section_name = CString::new("llvm_offload_entries").unwrap(); + llvm::set_section(offload_entry, &c_section_name); + (Some(offload_entry), Some(region_id)) + } else { + (None, None) + }; let result = OffloadKernelGlobals { offload_sizes, memtransfer_types, region_id, offload_entry }; @@ -471,20 +527,20 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>( types: &[&Type], metadata: &[OffloadMetadata], offload_globals: &OffloadGlobals<'ll>, - offload_dims: &OffloadKernelDims<'ll>, + offload_dims: Option<&OffloadKernelDims<'ll>>, + host: bool, + host_llfn: &'ll Value, + host_llty: &'ll Type, ) { let cx = builder.cx; let OffloadKernelGlobals { offload_sizes, offload_entry, memtransfer_types, region_id } = offload_data; - let OffloadKernelDims { num_workgroups, threads_per_block, workgroup_dims, thread_dims } = - offload_dims; let tgt_decl = offload_globals.launcher_fn; let tgt_target_kernel_ty = offload_globals.launcher_ty; - // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr } - let tgt_bin_desc = offload_globals.bin_desc; - let tgt_kernel_decl = offload_globals.kernel_args_ty; let begin_mapper_decl = offload_globals.begin_mapper; let end_mapper_decl = offload_globals.end_mapper; @@ -495,7 +551,12 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>( // FIXME(Sa4dUs): dummy loads are a temp workaround, we should find a proper way to prevent these // variables from being optimized away - for val in [offload_sizes, offload_entry] { + let to_keep: &[&llvm::Value] = if let Some(offload_entry) = offload_entry { + &[offload_sizes, offload_entry] + } else { + &[offload_sizes] + }; + for val in to_keep { unsafe { let dummy = llvm::LLVMBuildLoad2( &builder.llbuilder, @@ -508,12 +569,9 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>( } // Step 0) - // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr } - // %6 = alloca %struct.__tgt_bin_desc, align 8 unsafe { llvm::LLVMRustPositionBuilderPastAllocas(&builder.llbuilder, builder.llfn()); } - let tgt_bin_desc_alloca = builder.direct_alloca(tgt_bin_desc, Align::EIGHT, "EmptyDesc"); let ty = cx.type_array(cx.type_ptr(), num_args); // Baseptr are just the input pointer to the kernel, stored in a local alloca @@ -531,7 +589,6 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>( unsafe { llvm::LLVMPositionBuilderAtEnd(&builder.llbuilder, bb); } - builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT); // Now we allocate once per function param, a copy to be passed to one of our maps. let mut vals = vec![]; @@ -543,15 +600,9 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>( geps.push(gep); } - let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void()); - let register_lib_decl = offload_globals.register_lib; - let unregister_lib_decl = offload_globals.unregister_lib; let init_ty = cx.type_func(&[], cx.type_void()); let init_rtls_decl = offload_globals.init_rtls; - // FIXME(offload): Later we want to add them to the wrapper code, rather than our main function.
- // call void @__tgt_register_lib(ptr noundef %6) - builder.call(mapper_fn_ty, None, None, register_lib_decl, &[tgt_bin_desc_alloca], None, None); // call void @__tgt_init_all_rtls() builder.call(init_ty, None, None, init_rtls_decl, &[], None, None); @@ -615,27 +666,53 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>( num_args, s_ident_t, ); - let values = - KernelArgsTy::new(&cx, num_args, memtransfer_types, geps, workgroup_dims, thread_dims); - - // Step 3) - // Here we fill the KernelArgsTy, see the documentation above - for (i, value) in values.iter().enumerate() { - let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]); - builder.store(value.1, ptr, value.0); - } - let args = vec![ - s_ident_t, - // FIXME(offload) give users a way to select which GPU to use. - cx.get_const_i64(u64::MAX), // MAX == -1. - num_workgroups, - threads_per_block, - region_id, - a5, - ]; - builder.call(tgt_target_kernel_ty, None, None, tgt_decl, &args, None, None); - // %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args) + if host { + let fn_name = "omp_get_mapped_ptr"; + let ty2: &'ll Type = cx.type_func(&[cx.type_ptr(), cx.type_i32()], cx.type_ptr()); + let mapper_fn = unsafe { + llvm::LLVMRustGetOrInsertFunction( + builder.llmod, + fn_name.as_c_char_ptr(), + fn_name.len(), + ty2, + ) + }; + + let mut device_vals = Vec::with_capacity(vals.len()); + // We currently hardcode device 0 for the mapped-pointer lookup; see the GPU-selection FIXME below. + let device_num = cx.get_const_i32(0); + for arg in vals { + let device_arg = + builder.call(ty2, None, None, mapper_fn, &[arg, device_num], None, None); + device_vals.push(device_arg); + } + builder.call(host_llty, None, None, host_llfn, &device_vals, None, None); + } else { + let OffloadKernelDims { num_workgroups, threads_per_block, workgroup_dims, thread_dims } = + offload_dims.unwrap(); + let values = + KernelArgsTy::new(&cx, num_args, memtransfer_types, geps, workgroup_dims, thread_dims); + + // Step 3) + // Here we fill the KernelArgsTy, see the documentation above + for (i, value) in values.iter().enumerate() { + let ptr = + builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]); + builder.store(value.1, ptr, value.0); + } + // `host` is false in this branch, so by construction `offload_dims` and `region_id` are set. + let args = vec![ + s_ident_t, + // FIXME(offload) give users a way to select which GPU to use. + cx.get_const_i64(u64::MAX), // MAX == -1.
num_workgroups, + threads_per_block, + region_id.unwrap(), + a5, + ]; + // %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args) + builder.call(tgt_target_kernel_ty, None, None, tgt_decl, &args, None, None); + } // Step 4) let geps = get_geps(builder, ty, ty2, a1, a2, a4); @@ -648,6 +725,4 @@ pub(crate) fn gen_call_handling<'ll, 'tcx>( num_args, s_ident_t, ); - - builder.call(mapper_fn_ty, None, None, unregister_lib_decl, &[tgt_bin_desc_alloca], None, None); } diff --git a/compiler/rustc_codegen_llvm/src/common.rs b/compiler/rustc_codegen_llvm/src/common.rs index b0cf9925019d2..f2261ab79340f 100644 --- a/compiler/rustc_codegen_llvm/src/common.rs +++ b/compiler/rustc_codegen_llvm/src/common.rs @@ -124,6 +124,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> { pub(crate) fn const_null(&self, t: &'ll Type) -> &'ll Value { unsafe { llvm::LLVMConstNull(t) } } + + pub(crate) fn const_struct(&self, elts: &[&'ll Value], packed: bool) -> &'ll Value { + struct_in_context(self.llcx(), elts, packed) + } } impl<'ll, 'tcx> ConstCodegenMethods for CodegenCx<'ll, 'tcx> { diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 565db7d298bc9..4ca1a9268ce3f 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -30,7 +30,9 @@ use tracing::debug; use crate::abi::FnAbiLlvmExt; use crate::builder::Builder; use crate::builder::autodiff::{adjust_activity_to_abi, generate_enzyme_call}; -use crate::builder::gpu_offload::{OffloadKernelDims, gen_call_handling, gen_define_handling}; +use crate::builder::gpu_offload::{ + OffloadKernelDims, gen_call_handling, gen_define_handling, register_offload, +}; use crate::context::CodegenCx; use crate::declare::declare_raw_fn; use crate::errors::{ @@ -215,7 +217,19 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { let _ = tcx.dcx().emit_almost_fatal(OffloadWithoutFatLTO); } - codegen_offload(self, tcx, instance, args); + codegen_offload(self, tcx, instance, args, false); + return Ok(()); + } + sym::offload_args => { + if tcx.sess.opts.unstable_opts.offload.is_empty() { + let _ = tcx.dcx().emit_almost_fatal(OffloadWithoutEnable); + } + + if tcx.sess.lto() != rustc_session::config::Lto::Fat { + let _ = tcx.dcx().emit_almost_fatal(OffloadWithoutFatLTO); + } + + codegen_offload(self, tcx, instance, args, true); + return Ok(()); + } sym::is_val_statically_known => { @@ -1368,6 +1382,7 @@ fn codegen_offload<'ll, 'tcx>( tcx: TyCtxt<'tcx>, instance: ty::Instance<'tcx>, args: &[OperandRef<'tcx, &'ll Value>], + host: bool, ) { let cx = bx.cx; let fn_args = instance.args; @@ -1390,8 +1405,18 @@ } }; - let offload_dims = OffloadKernelDims::from_operands(bx, &args[1], &args[2]); - let args = get_args_from_tuple(bx, args[3], fn_target); + let llfn = cx.get_fn(fn_target); + let (offload_dims, args) = if host { + // If we only map arguments to the gpu and otherwise work on host code, there is no need to + // handle block or thread dimensions.
+ let args = get_args_from_tuple(bx, args[1], fn_target); + (None, args) + } else { + let offload_dims = OffloadKernelDims::from_operands(bx, &args[1], &args[2]); + let args = get_args_from_tuple(bx, args[3], fn_target); + (Some(offload_dims), args) + }; + let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target, LOCAL_CRATE); let sig = tcx.fn_sig(fn_target.def_id()).skip_binder().skip_binder(); @@ -1409,8 +1434,25 @@ return; } }; - let offload_data = gen_define_handling(&cx, &metadata, &types, target_symbol, offload_globals); - gen_call_handling(bx, &offload_data, &args, &types, &metadata, offload_globals, &offload_dims); + register_offload(cx); + let instance = rustc_middle::ty::Instance::mono(tcx, fn_target.def_id()); + let fn_abi = cx.fn_abi_of_instance(instance, tcx.mk_type_list(&[])); + let host_fn_ty = fn_abi.llvm_type(cx); + + let offload_data = + gen_define_handling(&cx, &metadata, &types, target_symbol.clone(), offload_globals, host); + gen_call_handling( + bx, + &offload_data, + &args, + &types, + &metadata, + offload_globals, + offload_dims.as_ref(), + host, + llfn, + host_fn_ty, + ); } fn get_args_from_tuple<'ll, 'tcx>( diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs index d3d167f6e2544..bae10c2799038 100644 --- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs +++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs @@ -165,6 +165,7 @@ fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -> hi | sym::mul_with_overflow | sym::needs_drop | sym::offload + | sym::offload_args | sym::offset_of | sym::overflow_checks | sym::powf16 @@ -338,6 +339,7 @@ pub(crate) fn check_intrinsic_type( ], param(2), ), + sym::offload_args => (3, 0, vec![param(0), param(1)], param(2)), sym::offset => (2, 0, vec![param(0), param(1)], param(0)), sym::arith_offset => ( 1, diff --git a/compiler/rustc_monomorphize/src/collector/autodiff.rs b/compiler/rustc_monomorphize/src/collector/autodiff.rs index e3646596e75e6..cb79d0700c72f 100644 --- a/compiler/rustc_monomorphize/src/collector/autodiff.rs +++ b/compiler/rustc_monomorphize/src/collector/autodiff.rs @@ -15,7 +15,10 @@ pub(crate) fn collect_autodiff_fn<'tcx>( intrinsic: IntrinsicDef, output: &mut MonoItems<'tcx>, ) { - if intrinsic.name != rustc_span::sym::autodiff { + if intrinsic.name != rustc_span::sym::autodiff + && intrinsic.name != rustc_span::sym::offload + && intrinsic.name != rustc_span::sym::offload_args + { return; }; diff --git a/compiler/rustc_session/src/config.rs b/compiler/rustc_session/src/config.rs index f8b9ae040568a..805d3e556ea75 100644 --- a/compiler/rustc_session/src/config.rs +++ b/compiler/rustc_session/src/config.rs @@ -196,6 +196,8 @@ pub enum Offload { Device, /// Second step in the offload pipeline, generates the host code to call kernels. Host(String), + /// Only map the arguments to the device; the called function itself still runs as host (=CPU) code. + Args, /// Test is similar to Host, but allows testing without a device artifact.
Test, } diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs index 39213a002383d..275507fc63139 100644 --- a/compiler/rustc_session/src/options.rs +++ b/compiler/rustc_session/src/options.rs @@ -796,7 +796,7 @@ mod desc { "a comma-separated list of strings, with elements beginning with + or -"; pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintTAFn`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`, `NoTT`"; pub(crate) const parse_offload: &str = - "a comma separated list of settings: `Host=`, `Device`, `Test`"; + "a comma separated list of settings: `Host=`, `Device`, `Test`, `Args`"; pub(crate) const parse_comma_list: &str = "a comma-separated list of strings"; pub(crate) const parse_opt_comma_list: &str = parse_comma_list; pub(crate) const parse_number: &str = "a number"; @@ -1480,6 +1480,13 @@ pub mod parse { } Offload::Test } + "Args" => { + if arg.is_some() { + // Args does not accept a value + return false; + } + Offload::Args + } _ => { // FIXME(ZuseZ4): print an error saying which value is not recognized return false; } @@ -2546,10 +2553,12 @@ options! { normalize_docs: bool = (false, parse_bool, [TRACKED], "normalize associated items in rustdoc when generating documentation"), offload: Vec<crate::config::Offload> = (Vec::new(), parse_offload, [TRACKED], - "a list of offload flags to enable - Mandatory setting: - `=Enable` - Currently the only option available"), + "a list of offload flags to enable: + `=Device` + `=Host(path)` + `=Test` + `=Args` + Multiple options can be combined with commas."), on_broken_pipe: OnBrokenPipe = (OnBrokenPipe::Default, parse_on_broken_pipe, [TRACKED], "behavior of std::io::ErrorKind::BrokenPipe (SIGPIPE)"), osx_rpath_install_name: bool = (false, parse_bool, [TRACKED], diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index d21580c16db20..7fbeb66384052 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -1609,6 +1609,7 @@ symbols! { of, off, offload, + offload_args, offset, offset_of, offset_of_enum, diff --git a/library/core/src/intrinsics/mod.rs b/library/core/src/intrinsics/mod.rs index 8d112b4c5d187..4109c72d2cae4 100644 --- a/library/core/src/intrinsics/mod.rs +++ b/library/core/src/intrinsics/mod.rs @@ -3384,6 +3384,67 @@ pub const fn copysignf128(x: f128, y: f128) -> f128; #[rustc_intrinsic] pub const fn autodiff<F, G, T, R>(f: F, df: G, args: T) -> R; +/// This intrinsic maps the given args from the host (=CPU) to a GPU device. It then calls the given +/// function. Unlike the full `offload` intrinsic, this intrinsic expects a host function, within which +/// we will replace all usages of the given host args with their device versions. This enables +/// support for various GPU libraries like `cuBLAS`, `cuDNN`, or `rocBLAS`, which *must* be called +/// from the host, but expect a mixture of host and device arguments. +/// +/// Type Parameters: +/// - `F`: The host function to call. Must be a function item. +/// - `T`: A tuple of arguments passed to `f`. +/// - `R`: The return type of `f`. +/// +/// Arguments: +/// - `f`: The host function to be called. +/// - `args`: A tuple of arguments, which will be mapped to the GPU and forwarded to `f`.
+/// +/// Example usage (pseudocode): +/// +/// ```rust,ignore (pseudocode) +/// fn host_caller(A: &[f32; 6], x: &[f32; 3], y: &mut [f32; 2]) { +/// core::intrinsics::offload_args(sgemv_wrapper, (A, x, y)) +/// } +/// +/// extern "C" { +/// pub fn rocblas_sgemv( +/// alpha: *const f32, +/// A: *const f32, +/// x: *const f32, +/// beta: *const f32, +/// y: *mut f32, +/// ); +/// } +/// +/// fn sgemv_wrapper(A: &[f32; 6], x: &[f32; 3], y: &mut [f32; 2]) { +/// // rocblas expects scalars to be passed as host pointers. +/// let alpha: f32 = 1.0; +/// let beta: f32 = 1.0; +/// unsafe { +/// rocblas_sgemv( +/// // Host ptr +/// &alpha as *const f32, +/// // Replaced by device ptr +/// A.as_ptr(), +/// // Replaced by device ptr +/// x.as_ptr(), +/// // Host ptr +/// &beta as *const f32, +/// // Replaced by device ptr +/// y.as_mut_ptr() +/// ); +/// } +/// } +/// ``` +/// +/// For reference, see the Clang documentation on offloading. +#[rustc_nounwind] +#[rustc_intrinsic] +pub const fn offload_args<F, T, R>(f: F, args: T) -> R; + /// Generates the LLVM body of a wrapper function to offload a kernel `f`. /// /// Type Parameters: diff --git a/tests/codegen-llvm/gpu_offload/gpu_host.rs b/tests/codegen-llvm/gpu_offload/gpu_host.rs index 27ff6f325aa0f..d0bc34ec66b20 100644 --- a/tests/codegen-llvm/gpu_offload/gpu_host.rs +++ b/tests/codegen-llvm/gpu_offload/gpu_host.rs @@ -2,9 +2,10 @@ //@ no-prefer-dynamic //@ needs-offload -// This test is verifying that we generate __tgt_target_data_*_mapper before and after a call to the -// kernel_1. Better documentation to what each global or variable means is available in the gpu -// offload code, or the LLVM offload documentation. +// This test verifies that we generate __tgt_target_data_*_mapper calls before and after the call to +// __tgt_target_kernel, and that we initialize all needed variables. It also verifies some related +// globals. Better documentation of what each global or variable means is available in the gpu +// offload code, or the LLVM offload documentation.
#![feature(rustc_attrs)] #![feature(core_intrinsics)] @@ -17,10 +18,8 @@ fn main() { core::hint::black_box(&x); } -#[unsafe(no_mangle)] -#[inline(never)] pub fn kernel_1(x: &mut [f32; 256]) { - core::intrinsics::offload(_kernel_1, [256, 1, 1], [32, 1, 1], (x,)) + core::intrinsics::offload(kernel_1, [256, 1, 1], [32, 1, 1], (x,)) } #[unsafe(no_mangle)] @@ -33,75 +32,75 @@ pub fn _kernel_1(x: &mut [f32; 256]) { // CHECK: %struct.ident_t = type { i32, i32, i32, i32, ptr } // CHECK: %struct.__tgt_offload_entry = type { i64, i16, i16, i32, ptr, ptr, i64, i64, ptr } -// CHECK: %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr } // CHECK: %struct.__tgt_kernel_arguments = type { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr, i64, i64, [3 x i32], [3 x i32], i32 } -// CHECK: @anon.{{.*}}.0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1 -// CHECK: @anon.{{.*}}.1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @anon.{{.*}}.0 }, align 8 +// CHECK: @anon.[[ID:.*]].0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1 +// CHECK: @anon.{{.*}}.1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @anon.[[ID]].0 }, align 8 -// CHECK: @.offload_sizes._kernel_1 = private unnamed_addr constant [1 x i64] [i64 1024] -// CHECK: @.offload_maptypes._kernel_1 = private unnamed_addr constant [1 x i64] [i64 35] -// CHECK: @._kernel_1.region_id = internal constant i8 0 -// CHECK: @.offloading.entry_name._kernel_1 = internal unnamed_addr constant [10 x i8] c"_kernel_1\00", section ".llvm.rodata.offloading", align 1 -// CHECK: @.offloading.entry._kernel_1 = internal constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @._kernel_1.region_id, ptr @.offloading.entry_name._kernel_1, i64 0, i64 0, ptr null }, section "llvm_offload_entries", align 8 +// CHECK-DAG: @.omp_offloading.descriptor = internal constant { i32, ptr, ptr, ptr } zeroinitializer +// CHECK-DAG: @llvm.global_ctors = appending constant [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 101, ptr @.omp_offloading.descriptor_reg, ptr null }] +// CHECK-DAG: @.offload_sizes.[[K:[^ ]*kernel_1]] = private unnamed_addr constant [1 x i64] [i64 1024] +// CHECK-DAG: @.offload_maptypes.[[K]] = private unnamed_addr constant [1 x i64] [i64 35] +// CHECK-DAG: @.[[K]].region_id = internal constant i8 0 +// CHECK-DAG: @.offloading.entry_name.[[K]] = internal unnamed_addr constant [{{[0-9]+}} x i8] c"[[K]]{{\\00}}", section ".llvm.rodata.offloading", align 1 +// CHECK-DAG: @.offloading.entry.[[K]] = internal constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.[[K]].region_id, ptr @.offloading.entry_name.[[K]], i64 0, i64 0, ptr null }, section "llvm_offload_entries", align 8 // CHECK: declare i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr) -// CHECK: declare void @__tgt_register_lib(ptr) local_unnamed_addr -// CHECK: declare void @__tgt_unregister_lib(ptr) local_unnamed_addr - -// CHECK: define{{( dso_local)?}} void @main() -// CHECK-NEXT: start: -// CHECK-NEXT: %0 = alloca [8 x i8], align 8 -// CHECK-NEXT: %x = alloca [1024 x i8], align 16 -// CHECK: call void @kernel_1(ptr noalias noundef nonnull align 4 dereferenceable(1024) %x) -// CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %0) -// CHECK-NEXT: store ptr %x, ptr %0, align 8 -// CHECK-NEXT: call void asm sideeffect "", "r,~{memory}"(ptr nonnull %0) -// CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %0) -// CHECK-NEXT: call 
void @llvm.lifetime.end.p0(i64 1024, ptr nonnull %x) -// CHECK-NEXT: ret void -// CHECK-NEXT: } -// CHECK: define{{( dso_local)?}} void @kernel_1(ptr noalias noundef align 4 dereferenceable(1024) %x) +// CHECK-LABEL: define{{( dso_local)?}} void @main() // CHECK-NEXT: start: -// CHECK-NEXT: %EmptyDesc = alloca %struct.__tgt_bin_desc, align 8 +// CHECK-NEXT: %0 = alloca [8 x i8], align 8 +// CHECK-NEXT: %x = alloca [1024 x i8], align 16 // CHECK-NEXT: %.offload_baseptrs = alloca [1 x ptr], align 8 // CHECK-NEXT: %.offload_ptrs = alloca [1 x ptr], align 8 // CHECK-NEXT: %.offload_sizes = alloca [1 x i64], align 8 // CHECK-NEXT: %kernel_args = alloca %struct.__tgt_kernel_arguments, align 8 -// CHECK-NEXT: %dummy = load volatile ptr, ptr @.offload_sizes._kernel_1, align 8 -// CHECK-NEXT: %dummy1 = load volatile ptr, ptr @.offloading.entry._kernel_1, align 8 -// CHECK-NEXT: call void @llvm.memset.p0.i64(ptr noundef nonnull align 8 dereferenceable(32) %EmptyDesc, i8 0, i64 32, i1 false) -// CHECK-NEXT: call void @__tgt_register_lib(ptr nonnull %EmptyDesc) +// CHECK: %dummy = load volatile ptr, ptr @.offload_sizes.[[K]], align 8 +// CHECK-NEXT: %dummy1 = load volatile ptr, ptr @.offloading.entry.[[K]], align 8 // CHECK-NEXT: call void @__tgt_init_all_rtls() // CHECK-NEXT: store ptr %x, ptr %.offload_baseptrs, align 8 // CHECK-NEXT: store ptr %x, ptr %.offload_ptrs, align 8 // CHECK-NEXT: store i64 1024, ptr %.offload_sizes, align 8 -// CHECK-NEXT: call void @__tgt_target_data_begin_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes._kernel_1, ptr null, ptr null) +// CHECK-NEXT: call void @__tgt_target_data_begin_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.[[K]], ptr null, ptr null) // CHECK-NEXT: store i32 3, ptr %kernel_args, align 8 -// CHECK-NEXT: %0 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 4 -// CHECK-NEXT: store i32 1, ptr %0, align 4 -// CHECK-NEXT: %1 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 8 -// CHECK-NEXT: store ptr %.offload_baseptrs, ptr %1, align 8 -// CHECK-NEXT: %2 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 16 -// CHECK-NEXT: store ptr %.offload_ptrs, ptr %2, align 8 -// CHECK-NEXT: %3 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 24 -// CHECK-NEXT: store ptr %.offload_sizes, ptr %3, align 8 -// CHECK-NEXT: %4 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 32 -// CHECK-NEXT: store ptr @.offload_maptypes._kernel_1, ptr %4, align 8 -// CHECK-NEXT: %5 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 40 -// CHECK-NEXT: %6 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 72 -// CHECK-NEXT: call void @llvm.memset.p0.i64(ptr noundef nonnull align 8 dereferenceable(32) %5, i8 0, i64 32, i1 false) -// CHECK-NEXT: store <4 x i32> <i32 256, i32 1, i32 1, i32 32>, ptr %6, align 8 -// CHECK-NEXT: %.fca.1.gep5 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 88 -// CHECK-NEXT: store i32 1, ptr %.fca.1.gep5, align 8 -// CHECK-NEXT: %.fca.2.gep7 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 92 -// CHECK-NEXT: store i32 1, ptr %.fca.2.gep7, align 4 -// CHECK-NEXT: %7 = getelementptr inbounds nuw i8, ptr %kernel_args, i64 96 -// CHECK-NEXT: store i32 0, ptr %7, align 8 -// CHECK-NEXT: %8 = call i32 @__tgt_target_kernel(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 256, i32 32, ptr nonnull @._kernel_1.region_id,
ptr nonnull %kernel_args) -// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes._kernel_1, ptr null, ptr null) -// CHECK-NEXT: call void @__tgt_unregister_lib(ptr nonnull %EmptyDesc) +// CHECK-NEXT: [[P4:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 4 +// CHECK-NEXT: store i32 1, ptr [[P4]], align 4 +// CHECK-NEXT: [[P8:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 8 +// CHECK-NEXT: store ptr %.offload_baseptrs, ptr [[P8]], align 8 +// CHECK-NEXT: [[P16:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 16 +// CHECK-NEXT: store ptr %.offload_ptrs, ptr [[P16]], align 8 +// CHECK-NEXT: [[P24:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 24 +// CHECK-NEXT: store ptr %.offload_sizes, ptr [[P24]], align 8 +// CHECK-NEXT: [[P32:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 32 +// CHECK-NEXT: store ptr @.offload_maptypes.[[K]], ptr [[P32]], align 8 +// CHECK-NEXT: [[P40:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 40 +// CHECK-NEXT: [[P72:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 72 +// CHECK-NEXT: call void @llvm.memset.p0.i64(ptr noundef nonnull align 8 dereferenceable(32) [[P40]], i8 0, i64 32, i1 false) +// CHECK-NEXT: store <4 x i32> <i32 256, i32 1, i32 1, i32 32>, ptr [[P72]], align 8 +// CHECK-NEXT: [[P88:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 88 +// CHECK-NEXT: store i32 1, ptr [[P88]], align 8 +// CHECK-NEXT: [[P92:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 92 +// CHECK-NEXT: store i32 1, ptr [[P92]], align 4 +// CHECK-NEXT: [[P96:%[^ ]+]] = getelementptr inbounds nuw i8, ptr %kernel_args, i64 96 +// CHECK-NEXT: store i32 0, ptr [[P96]], align 8 +// CHECK-NEXT: {{%[^ ]+}} = call i32 @__tgt_target_kernel(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 256, i32 32, ptr nonnull @.[[K]].region_id, ptr nonnull %kernel_args) +// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr nonnull @anon.{{.*}}.1, i64 -1, i32 1, ptr nonnull %.offload_baseptrs, ptr nonnull %.offload_ptrs, ptr nonnull %.offload_sizes, ptr nonnull @.offload_maptypes.[[K]], ptr null, ptr null) +// CHECK: ret void +// CHECK-NEXT: } + +// CHECK: declare void @__tgt_register_lib(ptr) local_unnamed_addr +// CHECK: declare void @__tgt_unregister_lib(ptr) local_unnamed_addr + +// CHECK-LABEL: define internal void @.omp_offloading.descriptor_reg() section ".text.startup" { +// CHECK-NEXT: entry: +// CHECK-NEXT: call void @__tgt_register_lib(ptr nonnull @.omp_offloading.descriptor) +// CHECK-NEXT: %0 = {{tail }}call i32 @atexit(ptr nonnull @.omp_offloading.descriptor_unreg) +// CHECK-NEXT: ret void +// CHECK-NEXT: } + +// CHECK-LABEL: define internal void @.omp_offloading.descriptor_unreg() section ".text.startup" { +// CHECK-NEXT: entry: +// CHECK-NEXT: call void @__tgt_unregister_lib(ptr nonnull @.omp_offloading.descriptor) // CHECK-NEXT: ret void // CHECK-NEXT: } diff --git a/tests/codegen-llvm/gpu_offload/offload_args.rs b/tests/codegen-llvm/gpu_offload/offload_args.rs new file mode 100644 index 0000000000000..8b0e1f15c3eec --- /dev/null +++ b/tests/codegen-llvm/gpu_offload/offload_args.rs @@ -0,0 +1,82 @@ +//@ compile-flags: -Zoffload=Args -Zno-link -Zunstable-options -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-offload + +// This test is meant to verify that we are able to map CPU arguments to a device and pass
those to +// a GPU library like cuBLAS or rocBLAS. We don't really want to link those libraries in CI, nor do +// we want to deal with the creation and destruction of the handles they require, since that is just +// noise. We do, however, test that we can combine host pointers (like alpha, beta) with device +// pointers (A, x, y). We also test std support while we are at it. + +#![allow(internal_features, non_camel_case_types, non_snake_case)] +#![feature(rustc_attrs)] +#![feature(core_intrinsics)] + +fn main() { + let mut A: [f32; 3 * 2] = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]; + let mut x: [f32; 3] = [1.0, 1.0, 1.0]; + let mut y: [f32; 2] = [0.0, 0.0]; + for _ in 0..10 { + core::intrinsics::offload_args::<_, _, ()>(rocblas_sgemv_wrapper, (&mut A, &mut x, &mut y)); + // CHECK-LABEL: ; offload_args::main + // CHECK: call void @__tgt_target_data_begin_mapper( + // CHECK-NEXT: [[A:%.*]] = call ptr @omp_get_mapped_ptr(ptr nonnull %A, i32 0) + // CHECK-NEXT: [[X:%.*]] = call ptr @omp_get_mapped_ptr(ptr nonnull %x, i32 0) + // CHECK-NEXT: [[Y:%.*]] = call ptr @omp_get_mapped_ptr(ptr nonnull %y, i32 0) + // CHECK-NEXT: ; call offload_args::rocblas_sgemv_wrapper + // CHECK-NEXT: call {{.*}}void {{@_RNv.*rocblas_sgemv_wrapper.*}}(ptr [[A]], ptr [[X]], ptr [[Y]]) + // CHECK-NEXT: call void @__tgt_target_data_end_mapper( + } + println!("{:?}", y); +} + +unsafe extern "C" { + pub fn fake_gpublas_sgemv( + m: i32, + n: i32, + alpha: *const f32, + A: *const f32, + lda: i32, + x: *const f32, + incx: i32, + beta: *const f32, + y: *mut f32, + incy: i32, + ) -> i32; +} + +#[inline(never)] +pub fn rocblas_sgemv_wrapper(A: &mut [f32; 6], x: &mut [f32; 3], y: &mut [f32; 2]) { + let m: i32 = 2; + let n: i32 = 3; + let incx: i32 = 1; + let incy: i32 = 1; + let lda = m; + // These two should remain host pointers by default: + let alpha: f32 = 1.0; + let beta: f32 = 1.0; + + // CHECK-LABEL: ; offload_args::rocblas_sgemv_wrapper + // CHECK: define {{.*}}void {{.*}}rocblas_sgemv_wrapper{{.*}}(ptr{{.*}} %A, ptr{{.*}} %x, ptr{{.*}} %y) + // CHECK-DAG: %alpha = alloca [4 x i8] + // CHECK-DAG: %beta = alloca [4 x i8] + // CHECK-DAG: store float 1.000000e+00, ptr %alpha + // CHECK-DAG: store float 1.000000e+00, ptr %beta + // CHECK: call noundef i32 @fake_gpublas_sgemv(i32 noundef 2, i32 noundef 3, ptr{{.*}} %alpha, ptr{{.*}} %A, i32 noundef 2, ptr{{.*}} %x, i32 noundef 1, ptr{{.*}} %beta, ptr{{.*}} %y, i32 noundef 1) + + unsafe { + let st_res = fake_gpublas_sgemv( + m, + n, + &alpha as *const f32, + A.as_ptr(), + lda, + x.as_ptr(), + incx, + &beta as *const f32, + y.as_mut_ptr(), + incy, + ); + assert_eq!(st_res, 1); + }; +}
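
For readers who want to try the new intrinsic outside the test suite, here is a minimal end-to-end sketch mirroring the call pattern of the test above. It is illustrative only: `fake_gpublas_scale` and `scale_wrapper` are hypothetical stand-ins for a real device-library entry point and its wrapper, and the build flags are taken from the test's compile-flags line.

```rust
#![feature(core_intrinsics)]
// Build with: -Zunstable-options -Zoffload=Args -Clto=fat
// (plus -Zno-link if, as in the test, no real device library is linked).

unsafe extern "C" {
    // Hypothetical device-library function: scales n elements of the device
    // vector x by *alpha. alpha stays a host pointer; x must be a device pointer.
    pub fn fake_gpublas_scale(n: i32, alpha: *const f32, x: *mut f32) -> i32;
}

#[inline(never)]
pub fn scale_wrapper(x: &mut [f32; 4]) {
    let alpha: f32 = 2.0; // passed via host pointer, not mapped
    unsafe {
        // By the time this host function runs, `x` has already been replaced
        // by its device pointer (via omp_get_mapped_ptr in the generated code),
        // so it is only forwarded here, never dereferenced on the host.
        fake_gpublas_scale(4, &alpha as *const f32, x.as_mut_ptr());
    }
}

fn main() {
    let mut x = [1.0_f32; 4];
    // Maps `x` to the device, calls `scale_wrapper` with the mapped pointer,
    // and unmaps (copying results back) afterwards.
    core::intrinsics::offload_args::<_, _, ()>(scale_wrapper, (&mut x,));
    println!("{x:?}");
}
```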