From 3debe14ca3aa052bdb0ed84c0dc5f52861b48f87 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 18 Jun 2025 15:25:29 -0700 Subject: [PATCH 1/5] make more builder functions generic --- compiler/rustc_codegen_llvm/src/declare.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/declare.rs b/compiler/rustc_codegen_llvm/src/declare.rs index 2419ec1f88854..47b37156d689d 100644 --- a/compiler/rustc_codegen_llvm/src/declare.rs +++ b/compiler/rustc_codegen_llvm/src/declare.rs @@ -215,7 +215,9 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> { llfn } +} +impl<'ll, CX: Borrow>> GenericCx<'ll, CX> { /// Declare a global with an intention to define it. /// /// Use this function when you intend to define a global. This function will @@ -234,13 +236,13 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> { /// /// Use this function when you intend to define a global without a name. pub(crate) fn define_private_global(&self, ty: &'ll Type) -> &'ll Value { - unsafe { llvm::LLVMRustInsertPrivateGlobal(self.llmod, ty) } + unsafe { llvm::LLVMRustInsertPrivateGlobal(self.llmod(), ty) } } /// Gets declared value by name. pub(crate) fn get_declared_value(&self, name: &str) -> Option<&'ll Value> { debug!("get_declared_value(name={:?})", name); - unsafe { llvm::LLVMRustGetNamedValue(self.llmod, name.as_c_char_ptr(), name.len()) } + unsafe { llvm::LLVMRustGetNamedValue(self.llmod(), name.as_c_char_ptr(), name.len()) } } /// Gets defined or externally defined (AvailableExternally linkage) value by From 22a499fc5f1d9bb2d9fb5e63de40607b615a7960 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 18 Jun 2025 15:29:43 -0700 Subject: [PATCH 2/5] add -Zoffload=Enable flag, to enable gpu (host) code generation --- compiler/rustc_codegen_ssa/src/back/write.rs | 2 ++ compiler/rustc_session/src/config.rs | 10 +++++++- compiler/rustc_session/src/options.rs | 27 ++++++++++++++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index bbf9cceef2a00..9313b39a2b100 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -120,6 +120,7 @@ pub struct ModuleConfig { pub emit_lifetime_markers: bool, pub llvm_plugins: Vec, pub autodiff: Vec, + pub offload: Vec, } impl ModuleConfig { @@ -268,6 +269,7 @@ impl ModuleConfig { emit_lifetime_markers: sess.emit_lifetime_markers(), llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]), autodiff: if_regular!(sess.opts.unstable_opts.autodiff.clone(), vec![]), + offload: if_regular!(sess.opts.unstable_opts.offload.clone(), vec![]), } } diff --git a/compiler/rustc_session/src/config.rs b/compiler/rustc_session/src/config.rs index 04ca0b75c3113..13335f83959af 100644 --- a/compiler/rustc_session/src/config.rs +++ b/compiler/rustc_session/src/config.rs @@ -226,6 +226,13 @@ pub enum CoverageLevel { Mcdc, } +// The different settings that the `-Z offload` flag can have. +#[derive(Clone, Copy, PartialEq, Hash, Debug)] +pub enum Offload { + /// Enable the llvm offload pipeline + Enable, +} + /// The different settings that the `-Z autodiff` flag can have. #[derive(Clone, Copy, PartialEq, Hash, Debug)] pub enum AutoDiff { @@ -3073,7 +3080,7 @@ pub(crate) mod dep_tracking { AutoDiff, BranchProtection, CFGuard, CFProtection, CollapseMacroDebuginfo, CoverageOptions, CrateType, DebugInfo, DebugInfoCompression, ErrorOutputType, FmtDebug, FunctionReturn, InliningThreshold, InstrumentCoverage, InstrumentXRay, LinkerPluginLto, LocationDetail, - LtoCli, MirStripDebugInfo, NextSolverConfig, OomStrategy, OptLevel, OutFileName, + LtoCli, MirStripDebugInfo, NextSolverConfig, Offload, OomStrategy, OptLevel, OutFileName, OutputType, OutputTypes, PatchableFunctionEntry, Polonius, RemapPathScopeComponents, ResolveDocLinks, SourceFileHashAlgorithm, SplitDwarfKind, SwitchWithOptPath, SymbolManglingVersion, WasiExecModel, @@ -3120,6 +3127,7 @@ pub(crate) mod dep_tracking { impl_dep_tracking_hash_via_hash!( (), AutoDiff, + Offload, bool, usize, NonZero, diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs index 9ca405333f43d..aaee2020c1ca4 100644 --- a/compiler/rustc_session/src/options.rs +++ b/compiler/rustc_session/src/options.rs @@ -720,6 +720,7 @@ mod desc { pub(crate) const parse_list_with_polarity: &str = "a comma-separated list of strings, with elements beginning with + or -"; pub(crate) const parse_autodiff: &str = "a comma separated list of settings: `Enable`, `PrintSteps`, `PrintTA`, `PrintAA`, `PrintPerf`, `PrintModBefore`, `PrintModAfter`, `PrintModFinal`, `PrintPasses`, `NoPostopt`, `LooseTypes`, `Inline`"; + pub(crate) const parse_offload: &str = "a comma separated list of settings: `Enable`"; pub(crate) const parse_comma_list: &str = "a comma-separated list of strings"; pub(crate) const parse_opt_comma_list: &str = parse_comma_list; pub(crate) const parse_number: &str = "a number"; @@ -1351,6 +1352,27 @@ pub mod parse { } } + pub(crate) fn parse_offload(slot: &mut Vec, v: Option<&str>) -> bool { + let Some(v) = v else { + *slot = vec![]; + return true; + }; + let mut v: Vec<&str> = v.split(",").collect(); + v.sort_unstable(); + for &val in v.iter() { + let variant = match val { + "Enable" => Offload::Enable, + _ => { + // FIXME(ZuseZ4): print an error saying which value is not recognized + return false; + } + }; + slot.push(variant); + } + + true + } + pub(crate) fn parse_autodiff(slot: &mut Vec, v: Option<&str>) -> bool { let Some(v) = v else { *slot = vec![]; @@ -2378,6 +2400,11 @@ options! { "do not use unique names for text and data sections when -Z function-sections is used"), normalize_docs: bool = (false, parse_bool, [TRACKED], "normalize associated items in rustdoc when generating documentation"), + offload: Vec = (Vec::new(), parse_offload, [TRACKED], + "a list of offload flags to enable + Mandatory setting: + `=Enable` + Currently the only option available"), on_broken_pipe: OnBrokenPipe = (OnBrokenPipe::Default, parse_on_broken_pipe, [TRACKED], "behavior of std::io::ErrorKind::BrokenPipe (SIGPIPE)"), oom: OomStrategy = (OomStrategy::Abort, parse_oom_strategy, [TRACKED], From 3dc0c3b4b998159c5756f19353c720a555ee53c1 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 18 Jun 2025 15:40:30 -0700 Subject: [PATCH 3/5] various wrappers --- compiler/rustc_codegen_llvm/src/builder.rs | 64 +++++++++++++++++++ compiler/rustc_codegen_llvm/src/common.rs | 9 +++ compiler/rustc_codegen_llvm/src/context.rs | 16 +++++ .../rustc_codegen_llvm/src/llvm/enzyme_ffi.rs | 10 ++- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 19 +++++- .../rustc_llvm/llvm-wrapper/RustWrapper.cpp | 38 +++++++++++ 6 files changed, 154 insertions(+), 2 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 5e9594dd06bb7..976f021251bb2 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -117,6 +117,70 @@ impl<'a, 'll, CX: Borrow>> GenericBuilder<'a, 'll, CX> { } bx } + + pub(crate) fn my_alloca2(&mut self, ty: &'ll Type, align: Align, name: &str) -> &'ll Value { + let val = unsafe { + let alloca = llvm::LLVMBuildAlloca(self.llbuilder, ty, UNNAMED); + llvm::LLVMSetAlignment(alloca, align.bytes() as c_uint); + // Cast to default addrspace if necessary + llvm::LLVMBuildPointerCast(self.llbuilder, alloca, self.cx.type_ptr(), UNNAMED) + }; + if name != "" { + let name = std::ffi::CString::new(name).unwrap(); + llvm::set_value_name(val, &name.as_bytes()); + } + val + } + + pub(crate) fn inbounds_gep( + &mut self, + ty: &'ll Type, + ptr: &'ll Value, + indices: &[&'ll Value], + ) -> &'ll Value { + unsafe { + llvm::LLVMBuildGEPWithNoWrapFlags( + self.llbuilder, + ty, + ptr, + indices.as_ptr(), + indices.len() as c_uint, + UNNAMED, + GEPNoWrapFlags::InBounds, + ) + } + } + + pub(crate) fn store(&mut self, val: &'ll Value, ptr: &'ll Value, align: Align) -> &'ll Value { + debug!("Store {:?} -> {:?}", val, ptr); + assert_eq!(self.cx.type_kind(self.cx.val_ty(ptr)), TypeKind::Pointer); + unsafe { + let store = llvm::LLVMBuildStore(self.llbuilder, val, ptr); + llvm::LLVMSetAlignment(store, align.bytes() as c_uint); + store + } + } + + pub(crate) fn load(&mut self, ty: &'ll Type, ptr: &'ll Value, align: Align) -> &'ll Value { + unsafe { + let load = llvm::LLVMBuildLoad2(self.llbuilder, ty, ptr, UNNAMED); + llvm::LLVMSetAlignment(load, align.bytes() as c_uint); + load + } + } + + fn memset(&mut self, ptr: &'ll Value, fill_byte: &'ll Value, size: &'ll Value, align: Align) { + unsafe { + llvm::LLVMRustBuildMemSet( + self.llbuilder, + ptr, + align.bytes() as c_uint, + fill_byte, + size, + false, + ); + } + } } /// Empty string, to be used where LLVM expects an instruction name, indicating diff --git a/compiler/rustc_codegen_llvm/src/common.rs b/compiler/rustc_codegen_llvm/src/common.rs index ae5add59322fe..00a759f25d1b9 100644 --- a/compiler/rustc_codegen_llvm/src/common.rs +++ b/compiler/rustc_codegen_llvm/src/common.rs @@ -119,6 +119,10 @@ impl<'ll, CX: Borrow>> GenericCx<'ll, CX> { r } } + + pub(crate) fn const_null(&self, t: &'ll Type) -> &'ll Value { + unsafe { llvm::LLVMConstNull(t) } + } } impl<'ll, 'tcx> ConstCodegenMethods for CodegenCx<'ll, 'tcx> { @@ -373,6 +377,11 @@ pub(crate) fn bytes_in_context<'ll>(llcx: &'ll llvm::Context, bytes: &[u8]) -> & } } +pub(crate) fn named_struct<'ll>(ty: &'ll Type, elts: &[&'ll Value]) -> &'ll Value { + let len = c_uint::try_from(elts.len()).expect("LLVMConstStructInContext elements len overflow"); + unsafe { llvm::LLVMConstNamedStruct(ty, elts.as_ptr(), len) } +} + fn struct_in_context<'ll>( llcx: &'ll llvm::Context, elts: &[&'ll Value], diff --git a/compiler/rustc_codegen_llvm/src/context.rs b/compiler/rustc_codegen_llvm/src/context.rs index 0324dff6ff256..12f5d3426392a 100644 --- a/compiler/rustc_codegen_llvm/src/context.rs +++ b/compiler/rustc_codegen_llvm/src/context.rs @@ -683,6 +683,22 @@ impl<'ll, CX: Borrow>> GenericCx<'ll, CX> { unsafe { llvm::LLVMConstInt(ty, val, llvm::False) } } + pub(crate) fn get_const_i64(&self, n: u64) -> &'ll Value { + self.get_const_int(self.type_i64(), n) + } + + pub(crate) fn get_const_i32(&self, n: u64) -> &'ll Value { + self.get_const_int(self.type_i32(), n) + } + + pub(crate) fn get_const_i16(&self, n: u64) -> &'ll Value { + self.get_const_int(self.type_i16(), n) + } + + pub(crate) fn get_const_i8(&self, n: u64) -> &'ll Value { + self.get_const_int(self.type_i8(), n) + } + pub(crate) fn get_function(&self, name: &str) -> Option<&'ll Value> { let name = SmallCStr::new(name); unsafe { llvm::LLVMGetNamedFunction((**self).borrow().llmod, name.as_ptr()) } diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs index 2ad39fc853819..097ca72bc8aa2 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -5,7 +5,7 @@ use libc::{c_char, c_uint}; use super::MetadataKindId; use super::ffi::{AttributeKind, BasicBlock, Metadata, Module, Type, Value}; -use crate::llvm::Bool; +use crate::llvm::{Bool, Builder}; #[link(name = "llvm-wrapper", kind = "static")] unsafe extern "C" { @@ -32,6 +32,14 @@ unsafe extern "C" { index: c_uint, kind: AttributeKind, ); + pub(crate) fn LLVMRustPositionBefore<'a>(B: &'a Builder<'_>, I: &'a Value); + pub(crate) fn LLVMRustPositionAfter<'a>(B: &'a Builder<'_>, I: &'a Value); + pub(crate) fn LLVMRustGetFunctionCall( + F: &Value, + name: *const c_char, + NameLen: libc::size_t, + ) -> Option<&Value>; + } unsafe extern "C" { diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 91ada856d5977..bbe761309fdef 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1004,13 +1004,22 @@ unsafe extern "C" { SLen: c_uint, ) -> MetadataKindId; - // Create modules. + // Create, print, and destroy modules. + pub(crate) fn LLVMPrintModuleToFile( + M: &Module, + Name: *const c_char, + Error_message: *mut c_char, + ); + pub(crate) fn LLVMDisposeModule(M: &Module); pub(crate) fn LLVMModuleCreateWithNameInContext( ModuleID: *const c_char, C: &Context, ) -> &Module; pub(crate) fn LLVMCloneModule(M: &Module) -> &Module; + pub(crate) fn LLVMGetTarget(M: &Module) -> *const c_char; + pub(crate) fn LLVMSetTarget(M: &Module, Name: *const c_char); + /// Data layout. See Module::getDataLayout. pub(crate) fn LLVMGetDataLayoutStr(M: &Module) -> *const c_char; pub(crate) fn LLVMSetDataLayout(M: &Module, Triple: *const c_char); @@ -1138,6 +1147,11 @@ unsafe extern "C" { Count: c_uint, Packed: Bool, ) -> &'a Value; + pub(crate) fn LLVMConstNamedStruct<'a>( + StructTy: &'a Type, + ConstantVals: *const &'a Value, + Count: c_uint, + ) -> &'a Value; pub(crate) fn LLVMConstVector(ScalarConstantVals: *const &Value, Size: c_uint) -> &Value; // Constant expressions @@ -1217,6 +1231,8 @@ unsafe extern "C" { ) -> &'a BasicBlock; // Operations on instructions + pub(crate) fn LLVMGetInstructionParent(Inst: &Value) -> &BasicBlock; + pub(crate) fn LLVMGetCalledValue(CallInst: &Value) -> Option<&Value>; pub(crate) fn LLVMIsAInstruction(Val: &Value) -> Option<&Value>; pub(crate) fn LLVMGetFirstBasicBlock(Fn: &Value) -> &BasicBlock; pub(crate) fn LLVMGetOperand(Val: &Value, Index: c_uint) -> Option<&Value>; @@ -2562,6 +2578,7 @@ unsafe extern "C" { pub(crate) fn LLVMRustSetDataLayoutFromTargetMachine<'a>(M: &'a Module, TM: &'a TargetMachine); + pub(crate) fn LLVMRustPositionBuilderPastAllocas<'a>(B: &Builder<'a>, Fn: &'a Value); pub(crate) fn LLVMRustPositionBuilderAtStart<'a>(B: &Builder<'a>, BB: &'a BasicBlock); pub(crate) fn LLVMRustSetModulePICLevel(M: &Module); diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 90aa9188c8300..3064f89f8c178 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -1591,12 +1591,50 @@ extern "C" LLVMValueRef LLVMRustBuildMemSet(LLVMBuilderRef B, LLVMValueRef Dst, MaybeAlign(DstAlign), IsVolatile)); } +extern "C" void LLVMRustPositionBuilderPastAllocas(LLVMBuilderRef B, + LLVMValueRef Fn) { + Function *F = unwrap(Fn); + unwrap(B)->SetInsertPointPastAllocas(F); +} extern "C" void LLVMRustPositionBuilderAtStart(LLVMBuilderRef B, LLVMBasicBlockRef BB) { auto Point = unwrap(BB)->getFirstInsertionPt(); unwrap(B)->SetInsertPoint(unwrap(BB), Point); } +extern "C" void LLVMRustPositionBefore(LLVMBuilderRef B, LLVMValueRef Instr) { + if (auto I = dyn_cast(unwrap(Instr))) { + unwrap(B)->SetInsertPoint(I); + } +} + +extern "C" void LLVMRustPositionAfter(LLVMBuilderRef B, LLVMValueRef Instr) { + if (auto I = dyn_cast(unwrap(Instr))) { + auto J = I->getNextNonDebugInstruction(); + unwrap(B)->SetInsertPoint(J); + } +} + +extern "C" LLVMValueRef +LLVMRustGetFunctionCall(LLVMValueRef Fn, const char *Name, size_t NameLen) { + auto targetName = StringRef(Name, NameLen); + Function *F = unwrap(Fn); + for (auto &BB : *F) { + for (auto &I : BB) { + if (auto *callInst = llvm::dyn_cast(&I)) { + const llvm::Function *calledFunc = callInst->getCalledFunction(); + if (calledFunc && calledFunc->getName() == targetName) { + // Found a call to the target function + llvm::errs() << "Found call: " << *callInst << "\n"; + return wrap(callInst); + } + } + } + } + + return nullptr; +} + extern "C" bool LLVMRustConstIntGetZExtValue(LLVMValueRef CV, uint64_t *value) { auto C = unwrap(CV); if (C->getBitWidth() > 64) From c1e5961fc96d84418787bd8ed74aa252f22f3222 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 18 Jun 2025 15:44:15 -0700 Subject: [PATCH 4/5] gpu host code generation --- compiler/rustc_codegen_llvm/src/back/lto.rs | 8 + compiler/rustc_codegen_llvm/src/builder.rs | 1 + .../src/builder/gpu_offload.rs | 612 ++++++++++++++++++ 3 files changed, 621 insertions(+) create mode 100644 compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index ee46b49a094c6..5e462dc38d145 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -653,6 +653,7 @@ pub(crate) fn run_pass_manager( // We then run the llvm_optimize function a second time, to optimize the code which we generated // in the enzyme differentiation pass. let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable); + let enable_gpu = config.offload.contains(&config::Offload::Enable); let stage = if thin { write::AutodiffStage::PreAD } else { @@ -667,6 +668,13 @@ pub(crate) fn run_pass_manager( write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?; } + if cfg!(llvm_enzyme) && enable_gpu && !thin { + dbg!(&enable_gpu); + let cx = + SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size); + crate::builder::gpu_offload::handle_gpu_code(cgcx, &cx); + } + if cfg!(llvm_enzyme) && enable_ad && !thin { let cx = SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size); diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 976f021251bb2..06f1e31416cd5 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -3,6 +3,7 @@ use std::ops::Deref; use std::{iter, ptr}; pub(crate) mod autodiff; +pub(crate) mod gpu_offload; use libc::{c_char, c_uint, size_t}; use rustc_abi as abi; diff --git a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs new file mode 100644 index 0000000000000..e60b3d6c9e02b --- /dev/null +++ b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs @@ -0,0 +1,612 @@ +use std::ffi::CString; + +use llvm::Linkage::*; +use rustc_abi::Align; +use rustc_codegen_ssa::back::write::CodegenContext; +use rustc_codegen_ssa::traits::BaseTypeCodegenMethods; + +use crate::builder::SBuilder; +use crate::common::AsCCharPtr; +use crate::llvm::AttributePlace::Function; +use crate::llvm::{self, Linkage, Visibility}; +use crate::{LlvmCodegenBackend, SimpleCx, attributes}; + +fn create_struct_ty<'ll>( + cx: &'ll SimpleCx<'_>, + name: &str, + tys: &[&'ll llvm::Type], +) -> &'ll llvm::Type { + let entry_struct_name = CString::new(name).unwrap(); + unsafe { + let entry_struct = llvm::LLVMStructCreateNamed(cx.llcx, entry_struct_name.as_ptr()); + llvm::LLVMStructSetBody(entry_struct, tys.as_ptr(), tys.len() as u32, 0); + entry_struct + } +} + +// We don't copy types from other functions because we generate a new module and context. +// Bringing in types from other contexts would likely cause issues. +pub(crate) fn gen_image_wrapper_module<'ll>( + cgcx: &CodegenContext, + old_cx: &SimpleCx<'ll>, +) { + unsafe { + let llcx = llvm::LLVMRustContextCreate(false); + let module_name = CString::new("offload.wrapper.module").unwrap(); + let llmod = llvm::LLVMModuleCreateWithNameInContext(module_name.as_ptr(), llcx); + let cx = SimpleCx::new(llmod, llcx, cgcx.pointer_size); + let tptr = cx.type_ptr(); + let ti64 = cx.type_i64(); + let ti32 = cx.type_i32(); + let ti16 = cx.type_i16(); + let dl_cstr = llvm::LLVMGetDataLayoutStr(old_cx.llmod); + llvm::LLVMSetDataLayout(llmod, dl_cstr); + let target_cstr = llvm::LLVMGetTarget(old_cx.llmod); + llvm::LLVMSetTarget(llmod, target_cstr); + + let entry_fields = [ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr]; + create_struct_ty(&cx, "__tgt_offload_entry", &entry_fields); + create_struct_ty(&cx, "__tgt_device_image", &[tptr, tptr, tptr, tptr]); + create_struct_ty(&cx, "__tgt_bin_desc", &[ti32, tptr, tptr, tptr]); + + let offload_entry_ty = add_tgt_offload_entry(&cx); + let offload_entry_arr = cx.type_array(offload_entry_ty, 0); + + let c_name = CString::new("__start_omp_offloading_entries").unwrap(); + let llglobal = llvm::add_global(cx.llmod, offload_entry_arr, &c_name); + llvm::set_global_constant(llglobal, true); + llvm::set_linkage(llglobal, ExternalLinkage); + llvm::set_visibility(llglobal, Visibility::Hidden); + let c_name = CString::new("__stop_omp_offloading_entries").unwrap(); + let llglobal = llvm::add_global(cx.llmod, offload_entry_arr, &c_name); + llvm::set_global_constant(llglobal, true); + llvm::set_linkage(llglobal, ExternalLinkage); + llvm::set_visibility(llglobal, Visibility::Hidden); + + let c_name = CString::new("__dummy.omp_offloading_entries").unwrap(); + let llglobal = llvm::add_global(cx.llmod, offload_entry_arr, &c_name); + llvm::set_global_constant(llglobal, true); + llvm::set_linkage(llglobal, InternalLinkage); + let c_section_name = CString::new("omp_offloading_entries").unwrap(); + llvm::set_section(llglobal, &c_section_name); + let zeroinit = cx.const_null(offload_entry_arr); + llvm::set_initializer(llglobal, zeroinit); + + CString::new("llvm.compiler.used").unwrap(); + let arr_val = cx.const_array(tptr, &[llglobal]); + let c_section_name = CString::new("llvm.metadata").unwrap(); + let llglobal = add_global(&cx, "llvm.compiler.used", arr_val, AppendingLinkage); + llvm::set_section(llglobal, &c_section_name); + llvm::set_global_constant(llglobal, false); + + //@llvm.compiler.used = appending global [1 x ptr] [ptr @__dummy.omp_offloading_entries], section "llvm.metadata" + + let mapper_fn_ty = cx.type_func(&[tptr], cx.type_void()); + crate::declare::declare_simple_fn( + &cx, + &"__tgt_unregister_lib", + llvm::CallConv::CCallConv, + llvm::UnnamedAddr::No, + llvm::Visibility::Default, + mapper_fn_ty, + ); + crate::declare::declare_simple_fn( + &cx, + &"__tgt_register_lib", + llvm::CallConv::CCallConv, + llvm::UnnamedAddr::No, + llvm::Visibility::Default, + mapper_fn_ty, + ); + crate::declare::declare_simple_fn( + &cx, + &"atexit", + llvm::CallConv::CCallConv, + llvm::UnnamedAddr::No, + llvm::Visibility::Default, + cx.type_func(&[tptr], ti32), + ); + + let unknown_txt = "11111111111111"; + let c_entry_name = CString::new(unknown_txt).unwrap(); + let c_val = c_entry_name.as_bytes_with_nul(); + let initializer = crate::common::bytes_in_context(cx.llcx, c_val); + let llglobal = + add_unnamed_global(&cx, &".omp_offloading.device_image", initializer, InternalLinkage); + let c_section_name = CString::new(".llvm.offloading").unwrap(); + llvm::set_section(llglobal, &c_section_name); + llvm::set_alignment(llglobal, Align::EIGHT); + + //@.omp_offloading.device_image = internal unnamed_addr constant [4040 x i8] c"111111111", section ".llvm.offloading", align 8 + // TODO + // @.omp_offloading.device_images = internal unnamed_addr constant [1 x %__tgt_device_image] [%__tgt_device_image { ptr getelementptr ([4040 x i8], ptr @.omp_offloading.device_image, i64 0, i64 144), ptr getelementptr ([4040 x i8], ptr @.omp_offloading.device_image, i64 0, i64 4040), ptr @__start_omp_offloading_entries, ptr @__stop_omp_offloading_entries }] + // @.omp_offloading.descriptor = internal constant %__tgt_bin_desc { i32 1, ptr @.omp_offloading.device_images, ptr @__start_omp_offloading_entries, ptr @__stop_omp_offloading_entries } + // @llvm.global_ctors = appending global [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 101, ptr @.omp_offloading.descriptor_reg, ptr null }] + // + // define internal void @.omp_offloading.descriptor_reg() section ".text.startup" { + // entry: + // call void @__tgt_register_lib(ptr @.omp_offloading.descriptor) + // %0 = call i32 @atexit(ptr @.omp_offloading.descriptor_unreg) + // ret void + // } + // + // define internal void @.omp_offloading.descriptor_unreg() section ".text.startup" { + // entry: + // call void @__tgt_unregister_lib(ptr @.omp_offloading.descriptor) + // ret void + // } + + llvm::LLVMPrintModuleToFile( + llmod, + CString::new("rustmagic.openmp.image.wrapper.ll").unwrap().as_ptr(), + std::ptr::null_mut(), + ); + + // Clean up + llvm::LLVMDisposeModule(llmod); + llvm::LLVMContextDispose(llcx); + } +} + +pub(crate) fn handle_gpu_code<'ll>( + cgcx: &CodegenContext, + cx: &'ll SimpleCx<'_>, +) { + let (offload_entry_ty, at_one, begin, update, end, tgt_bin_desc, fn_ty) = gen_globals(&cx); + + dbg!("created struct"); + let mut o_types = vec![]; + let mut kernels = vec![]; + for num in 0..9 { + let kernel = cx.get_function(&format!("kernel_{num}")); + if let Some(kernel) = kernel { + o_types.push(gen_define_handling(&cx, kernel, offload_entry_ty, num)); + kernels.push(kernel); + } + } + dbg!("gen_call_handling"); + gen_call_handling(&cx, &kernels, at_one, begin, update, end, tgt_bin_desc, fn_ty, &o_types); + gen_image_wrapper_module(&cgcx, &cx); + // In a follow-up PR, we will enable gpu device code generation. + //gen_asdf(&cgcx, &cx); +} + +fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type { + let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry"); + let tptr = cx.type_ptr(); + let ti64 = cx.type_i64(); + let ti32 = cx.type_i32(); + let ti16 = cx.type_i16(); + let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr]; + cx.set_struct_body(offload_entry_ty, &entry_elements, false); + offload_entry_ty +} + +fn gen_globals<'ll>( + cx: &'ll SimpleCx<'_>, +) -> ( + &'ll llvm::Type, + &'ll llvm::Value, + &'ll llvm::Value, + &'ll llvm::Value, + &'ll llvm::Value, + &'ll llvm::Type, + &'ll llvm::Type, +) { + let offload_entry_ty = add_tgt_offload_entry(&cx); + let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments"); + let tptr = cx.type_ptr(); + let ti64 = cx.type_i64(); + let ti32 = cx.type_i32(); + let tarr = cx.type_array(ti32, 3); + + // @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1 + let unknown_txt = ";unknown;unknown;0;0;;"; + let c_entry_name = CString::new(unknown_txt).unwrap(); + let c_val = c_entry_name.as_bytes_with_nul(); + let initializer = crate::common::bytes_in_context(cx.llcx, c_val); + let at_zero = add_unnamed_global(&cx, &"", initializer, PrivateLinkage); + llvm::set_alignment(at_zero, Align::ONE); + + // @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8 + let struct_ident_ty = cx.type_named_struct("struct.ident_t"); + let struct_elems: Vec<&llvm::Value> = vec![ + cx.get_const_i32(0), + cx.get_const_i32(2), + cx.get_const_i32(0), + cx.get_const_i32(22), + at_zero, + ]; + let struct_elems_ty: Vec<_> = struct_elems.iter().map(|&x| cx.val_ty(x)).collect(); + let initializer = crate::common::named_struct(struct_ident_ty, &struct_elems); + cx.set_struct_body(struct_ident_ty, &struct_elems_ty, false); + let at_one = add_unnamed_global(&cx, &"", initializer, PrivateLinkage); + llvm::set_alignment(at_one, Align::EIGHT); + + // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr } + let tgt_bin_desc_ty = vec![ti32, tptr, tptr, tptr]; + let tgt_bin_desc_name = cx.type_named_struct("struct.__tgt_bin_desc"); + cx.set_struct_body(tgt_bin_desc_name, &tgt_bin_desc_ty, false); + + // coppied from LLVM + // typedef struct { + // uint64_t Reserved; + // uint16_t Version; + // uint16_t Kind; + // uint32_t Flags; + // void *Address; + // char *SymbolName; + // uint64_t Size; + // uint64_t Data; + // void *AuxAddr; + // } __tgt_offload_entry; + let kernel_elements = + vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32]; + + cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false); + cx.declare_global("my_struct_global2", kernel_arguments_ty); + //@my_struct_global = external global %struct.__tgt_offload_entry + //@my_struct_global2 = external global %struct.__tgt_kernel_arguments + dbg!(&kernel_arguments_ty); + //LLVMTypeRef elements[9] = {i64Ty, i16Ty, i16Ty, i32Ty, ptrTy, ptrTy, i64Ty, i64Ty, ptrTy}; + //LLVMStructSetBody(structTy, elements, 9, 0); + + // New, to test memtransfer + // ; Function Attrs: nounwind + // declare void @__tgt_target_data_begin_mapper(ptr, i64, i32, ptr, ptr, ptr, ptr, ptr, ptr) #3 + // + // ; Function Attrs: nounwind + // declare void @__tgt_target_data_update_mapper(ptr, i64, i32, ptr, ptr, ptr, ptr, ptr, ptr) #3 + // + // ; Function Attrs: nounwind + // declare void @__tgt_target_data_end_mapper(ptr, i64, i32, ptr, ptr, ptr, ptr, ptr, ptr) #3 + + let mapper_begin = "__tgt_target_data_begin_mapper"; + let mapper_update = String::from("__tgt_target_data_update_mapper"); + let mapper_end = String::from("__tgt_target_data_end_mapper"); + let args = vec![tptr, ti64, ti32, tptr, tptr, tptr, tptr, tptr, tptr]; + let mapper_fn_ty = cx.type_func(&args, cx.type_void()); + let foo = crate::declare::declare_simple_fn( + &cx, + &mapper_begin, + llvm::CallConv::CCallConv, + llvm::UnnamedAddr::No, + llvm::Visibility::Default, + mapper_fn_ty, + ); + let bar = crate::declare::declare_simple_fn( + &cx, + &mapper_update, + llvm::CallConv::CCallConv, + llvm::UnnamedAddr::No, + llvm::Visibility::Default, + mapper_fn_ty, + ); + let baz = crate::declare::declare_simple_fn( + &cx, + &mapper_end, + llvm::CallConv::CCallConv, + llvm::UnnamedAddr::No, + llvm::Visibility::Default, + mapper_fn_ty, + ); + let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx); + attributes::apply_to_llfn(foo, Function, &[nounwind]); + attributes::apply_to_llfn(bar, Function, &[nounwind]); + attributes::apply_to_llfn(baz, Function, &[nounwind]); + + (offload_entry_ty, at_one, foo, bar, baz, tgt_bin_desc_name, mapper_fn_ty) +} + +fn add_priv_unnamed_arr<'ll>(cx: &SimpleCx<'ll>, name: &str, vals: &[u64]) -> &'ll llvm::Value { + let ti64 = cx.type_i64(); + let mut size_val = Vec::with_capacity(vals.len()); + for &val in vals { + size_val.push(cx.get_const_i64(val)); + } + let initializer = cx.const_array(ti64, &size_val); + add_unnamed_global(cx, name, initializer, PrivateLinkage) +} + +fn add_unnamed_global<'ll>( + cx: &SimpleCx<'ll>, + name: &str, + initializer: &'ll llvm::Value, + l: Linkage, +) -> &'ll llvm::Value { + let llglobal = add_global(cx, name, initializer, l); + unsafe { llvm::LLVMSetUnnamedAddress(llglobal, llvm::UnnamedAddr::Global) }; + llglobal +} + +fn add_global<'ll>( + cx: &SimpleCx<'ll>, + name: &str, + initializer: &'ll llvm::Value, + l: Linkage, +) -> &'ll llvm::Value { + let c_name = CString::new(name).unwrap(); + let llglobal: &'ll llvm::Value = llvm::add_global(cx.llmod, cx.val_ty(initializer), &c_name); + llvm::set_global_constant(llglobal, true); + llvm::set_linkage(llglobal, l); + llvm::set_initializer(llglobal, initializer); + llglobal +} + +fn gen_define_handling<'ll>( + cx: &'ll SimpleCx<'_>, + kernel: &'ll llvm::Value, + offload_entry_ty: &'ll llvm::Type, + num: i64, +) -> &'ll llvm::Value { + let types = cx.func_params_types(cx.get_type_of_global(kernel)); + // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or + // reference) types. + let num_ptr_types = types + .iter() + .map(|&x| matches!(cx.type_kind(x), rustc_codegen_ssa::common::TypeKind::Pointer)) + .count(); + + // We do not know their size anymore at this level, so hardcode a placeholder. + // A follow-up pr will track these from the frontend, where we still have Rust types. + // Then, we will be able to figure out that e.g. `&[f32;1024]` will result in 32*1024 bytes. + // I decided that 1024 bytes is a great placeholder value for now. + add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{num}"), &vec![1024; num_ptr_types]); + // Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2), + // or both to and from the gpu (=3). Other values shouldn't affect us for now. + // A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten + // will be 2. For now, everything is 3, untill we have our frontend set up. + let o_types = + add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{num}"), &vec![3; num_ptr_types]); + // Next: For each function, generate these three entries. A weak constant, + // the llvm.rodata entry name, and the omp_offloading_entries value + + // @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id = weak constant i8 0 + // @.offloading.entry_name = internal unnamed_addr constant [66 x i8] c"__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7\00", section ".llvm.rodata.offloading", align 1 + let name = format!(".kernel_{num}.region_id"); + let initializer = cx.get_const_i8(0); + let region_id = add_unnamed_global(&cx, &name, initializer, WeakAnyLinkage); + + let c_entry_name = CString::new(format!("kernel_{num}")).unwrap(); + let c_val = c_entry_name.as_bytes_with_nul(); + let foo = format!(".offloading.entry_name.{num}"); + + let initializer = crate::common::bytes_in_context(cx.llcx, c_val); + let llglobal = add_unnamed_global(&cx, &foo, initializer, InternalLinkage); + llvm::set_alignment(llglobal, Align::ONE); + let c_section_name = CString::new(".llvm.rodata.offloading").unwrap(); + llvm::set_section(llglobal, &c_section_name); + + // Not actively used yet, for calling real kernels + let name = format!(".offloading.entry.kernel_{num}"); + let ci64_0 = cx.get_const_i64(0); + let ci16_1 = cx.get_const_i16(1); + let elems: Vec<&llvm::Value> = vec![ + ci64_0, + ci16_1, + ci16_1, + cx.get_const_i32(0), + region_id, + llglobal, + ci64_0, + ci64_0, + cx.const_null(cx.type_ptr()), + ]; + + let initializer = crate::common::named_struct(offload_entry_ty, &elems); + let c_name = CString::new(name).unwrap(); + let llglobal = llvm::add_global(cx.llmod, offload_entry_ty, &c_name); + llvm::set_global_constant(llglobal, true); + llvm::set_linkage(llglobal, WeakAnyLinkage); + llvm::set_initializer(llglobal, initializer); + llvm::set_alignment(llglobal, Align::ONE); + let c_section_name = CString::new(".omp_offloading_entries").unwrap(); + llvm::set_section(llglobal, &c_section_name); + // rustc + // @.offloading.entry.kernel_3 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.kernel_3.region_id, ptr @.offloading.entry_name.3, i64 0, i64 0, ptr null }, section ".omp_offloading_entries", align 1 + // clang + // @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id, ptr @.offloading.entry_name, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1 + + // + // 1. @.offload_sizes.{num} = private unnamed_addr constant [4 x i64] [i64 8, i64 0, i64 16, i64 0] + // 2. @.offload_maptypes + // 3. @.__omp_offloading__fnc_name_ = weak constant i8 0 + // 4. @.offloading.entry_name = internal unnamed_addr constant [66 x i8] c"__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7\00", section ".llvm.rodata.offloading", align 1 + // 5. @.offloading.entry.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.__omp_offloading_86fafab6_c40006a1__Z3fooPSt7complexIdES1_S0_m_l7.region_id, ptr @.offloading.entry_name, i64 0, i64 0, ptr null }, section "omp_offloading_entries", align 1 + o_types +} + +// For each kernel *call*, we now use some of our previous declared globals to move data to and from +// the gpu. We don't have a proper frontend yet, so we assume that every call to a kernel function +// from main is intended to run on the GPU. For now, we only handle the data transfer part of it. +// If two consecutive kernels use the same memory, we still move it to the host and back to the gpu. +// Since in our frontend users (by default) don't have to specify data transfer, this is something +// we should optimize in the future! We also assume that everything should be copied back and forth, +// but sometimes we can directly zero-allocate on the device and only move back, or if something is +// immutable, we might only copy it to the device, but not back. +// +// Current steps: +// 0. Alloca some variables for the following steps +// 1. set insert point before kernel call. +// 2. generate all the GEPS and stores, to be used in 3) +// 3. generate __tgt_target_data_begin calls to move data to the GPU +// +// unchanged: keep kernel call. Later move the kernel to the GPU +// +// 4. set insert point after kernel call. +// 5. generate all the GEPS and stores, to be used in 6) +// 6. generate __tgt_target_data_end calls to move data from the GPU +fn gen_call_handling<'ll>( + cx: &'ll SimpleCx<'_>, + _kernels: &[&'ll llvm::Value], + s_ident_t: &'ll llvm::Value, + begin: &'ll llvm::Value, + _update: &'ll llvm::Value, + end: &'ll llvm::Value, + tgt_bin_desc: &'ll llvm::Type, + fn_ty: &'ll llvm::Type, + o_types: &[&'ll llvm::Value], +) { + let main_fn = cx.get_function("main"); + if let Some(main_fn) = main_fn { + let kernel_name = "kernel_1"; + let call = unsafe { + llvm::LLVMRustGetFunctionCall(main_fn, kernel_name.as_c_char_ptr(), kernel_name.len()) + }; + let kernel_call = if call.is_some() { + dbg!("found kernel call"); + call.unwrap() + } else { + return; + }; + let kernel_call_bb = unsafe { llvm::LLVMGetInstructionParent(kernel_call) }; + let called = unsafe { llvm::LLVMGetCalledValue(kernel_call).unwrap() }; + let mut builder = SBuilder::build(cx, kernel_call_bb); + + let types = cx.func_params_types(cx.get_type_of_global(called)); + dbg!(&types); + let num_args = types.len() as u64; + + // Step 0) + // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr } + // %6 = alloca %struct.__tgt_bin_desc, align 8 + unsafe { llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, main_fn) }; + + let tgt_bin_desc_alloca = builder.my_alloca2(tgt_bin_desc, Align::EIGHT, "EmptyDesc"); + //fill_byte: &'ll Value, + //size: &'ll Value, + //align: Align, + //flags: MemFlags, + // call void @llvm.memset.p0.i64(ptr align 8 %EmptyDesc, i8 0, i64 32, i1 false) + // mem + + let ty = cx.type_array(cx.type_ptr(), num_args); + // Baseptr are just the input pointer to the kernel, stored in a local alloca + let a1 = builder.my_alloca2(ty, Align::EIGHT, ".offload_baseptrs"); + // Ptrs are the result of a gep into the baseptr, at least for our trivial types. + let a2 = builder.my_alloca2(ty, Align::EIGHT, ".offload_ptrs"); + // These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16. + let ty2 = cx.type_array(cx.type_i64(), num_args); + let a4 = builder.my_alloca2(ty2, Align::EIGHT, ".offload_sizes"); + // Now we allocate once per function param, a copy to be passed to one of our maps. + let mut vals = vec![]; + let mut geps = vec![]; + let i32_0 = cx.get_const_i32(0); + for (index, in_ty) in types.iter().enumerate() { + // get function arg, store it into the alloca, and read it. + let p = llvm::get_param(called, index as u32); + let name = llvm::get_value_name(p); + let name = str::from_utf8(name).unwrap(); + let arg_name = CString::new(format!("{name}.addr")).unwrap(); + let alloca = + unsafe { llvm::LLVMBuildAlloca(builder.llbuilder, in_ty, arg_name.as_ptr()) }; + builder.store(p, alloca, Align::EIGHT); + let val = builder.load(in_ty, alloca, Align::EIGHT); + let gep = builder.inbounds_gep(cx.type_f32(), val, &[i32_0]); + vals.push(val); + geps.push(gep); + } + + // Step 1) + unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) }; + builder.memset( + tgt_bin_desc_alloca, + cx.get_const_i8(0), + cx.get_const_i64(32), + Align::from_bytes(8).unwrap(), + ); + + let tptr = cx.type_ptr(); + let mapper_fn_ty = cx.type_func(&[tptr], cx.type_void()); + let foo = crate::declare::declare_simple_fn( + &cx, + &"__tgt_register_lib", + llvm::CallConv::CCallConv, + llvm::UnnamedAddr::No, + llvm::Visibility::Default, + mapper_fn_ty, + ); + let bar = crate::declare::declare_simple_fn( + &cx, + &"__tgt_unregister_lib", + llvm::CallConv::CCallConv, + llvm::UnnamedAddr::No, + llvm::Visibility::Default, + mapper_fn_ty, + ); + let init_ty = cx.type_func(&[], cx.type_void()); + let baz = crate::declare::declare_simple_fn( + &cx, + &"__tgt_init_all_rtls", + llvm::CallConv::CCallConv, + llvm::UnnamedAddr::No, + llvm::Visibility::Default, + init_ty, + ); + + builder.call(mapper_fn_ty, foo, &[tgt_bin_desc_alloca], None); + builder.call(init_ty, baz, &[], None); + + // call void @__tgt_register_lib(ptr noundef %6) + // call void @__tgt_init_all_rtls() + for i in 0..num_args { + let idx = cx.get_const_i32(i); + let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, idx]); + builder.store(vals[i as usize], gep1, Align::EIGHT); + let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, idx]); + builder.store(geps[i as usize], gep2, Align::EIGHT); + let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, idx]); + builder.store(cx.get_const_i64(1024), gep3, Align::EIGHT); + } + + // Step 2) + let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]); + let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]); + let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]); + + let nullptr = cx.const_null(cx.type_ptr()); + let o_type = o_types[0]; + let args = vec![ + s_ident_t, + cx.get_const_i64(u64::MAX), + cx.get_const_i32(num_args), + gep1, + gep2, + gep3, + o_type, + nullptr, + nullptr, + ]; + builder.call(fn_ty, begin, &args, None); + + // Step 4) + unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) }; + + let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]); + let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]); + let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]); + + let nullptr = cx.const_null(cx.type_ptr()); + let o_type = o_types[0]; + let args = vec![ + s_ident_t, + cx.get_const_i64(u64::MAX), + cx.get_const_i32(num_args), + gep1, + gep2, + gep3, + o_type, + nullptr, + nullptr, + ]; + builder.call(fn_ty, end, &args, None); + builder.call(mapper_fn_ty, bar, &[tgt_bin_desc_alloca], None); + + // call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null) + // call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null) + // call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null) + // What is @1? Random but fixed: + // @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1 + // @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8 + } +} From 5d3ce244fd2d653f95fa83f78b4cf9558d8e852c Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Wed, 18 Jun 2025 16:47:15 -0700 Subject: [PATCH 5/5] add gpu offload codegen host side test --- tests/codegen/gpu_offload/gpu_host.rs | 80 +++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 tests/codegen/gpu_offload/gpu_host.rs diff --git a/tests/codegen/gpu_offload/gpu_host.rs b/tests/codegen/gpu_offload/gpu_host.rs new file mode 100644 index 0000000000000..b5f6d2e50b681 --- /dev/null +++ b/tests/codegen/gpu_offload/gpu_host.rs @@ -0,0 +1,80 @@ +//@ compile-flags: -Zoffload=Enable -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme + +// This test is verifying that we generate __tgt_target_data_*_mapper before and after a call to the +// kernel_1. Better documentation to what each global or variable means is available in the gpu +// offlaod code, or the LLVM offload documentation. This code does not launch any GPU kernels yet, +// and will be rewritten once a proper offload frontend has landed. +// +// We currently only handle memory transfer for specific calls to functions named `kernel_{num}`, +// when inside of a function called main. This, too, is a temporary workaround for not having a +// frontend. + +#![no_main] + +#[unsafe(no_mangle)] +fn main() { + let mut x = [3.0; 256]; + kernel_1(&mut x); + core::hint::black_box(&x); +} + +// CHECK: %struct.ident_t = type { i32, i32, i32, i32, ptr } +// CHECK: %struct.__tgt_kernel_arguments = type { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr, i64, i64, [3 x i32], [3 x i32], i32 } +// CHECK: %struct.__tgt_offload_entry = type { i64, i16, i16, i32, ptr, ptr, i64, i64, ptr } +// CHECK: %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr } + +// CHECK: @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1 +// CHECK: @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8 +// CHECK: @my_struct_global2 = external global %struct.__tgt_kernel_arguments +// CHECK: @.offload_sizes.1 = private unnamed_addr constant [1 x i64] [i64 1024] +// CHECK: @.offload_maptypes.1 = private unnamed_addr constant [1 x i64] [i64 3] +// CHECK: @.kernel_1.region_id = weak unnamed_addr constant i8 0 +// CHECK: @.offloading.entry_name.1 = internal unnamed_addr constant [9 x i8] c"kernel_1\00", section ".llvm.rodata.offloading", align 1 +// CHECK: @.offloading.entry.kernel_1 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.kernel_1.region_id, ptr @.offloading.entry_name.1, i64 0, i64 0, ptr null }, section ".omp_offloading_entries", align 1 + +// CHECK: Function Attrs: +// CHECK-NEXT: define{{( dso_local)?}} void @main() +// CHECK-NEXT: start: +// CHECK-NEXT: %0 = alloca [8 x i8], align 8 +// CHECK-NEXT: %x = alloca [1024 x i8], align 16 +// CHECK-NEXT: %EmptyDesc = alloca %struct.__tgt_bin_desc, align 8 +// CHECK-NEXT: %.offload_baseptrs = alloca [1 x ptr], align 8 +// CHECK-NEXT: %.offload_ptrs = alloca [1 x ptr], align 8 +// CHECK-NEXT: %.offload_sizes = alloca [1 x i64], align 8 +// CHECK-NEXT: %x.addr = alloca ptr, align 8 +// CHECK-NEXT: store ptr %x, ptr %x.addr, align 8 +// CHECK-NEXT: %1 = load ptr, ptr %x.addr, align 8 +// CHECK-NEXT: %2 = getelementptr inbounds float, ptr %1, i32 0 +// CHECK: call void @llvm.memset.p0.i64(ptr align 8 %EmptyDesc, i8 0, i64 32, i1 false) +// CHECK-NEXT: call void @__tgt_register_lib(ptr %EmptyDesc) +// CHECK-NEXT: call void @__tgt_init_all_rtls() +// CHECK-NEXT: %3 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 +// CHECK-NEXT: store ptr %1, ptr %3, align 8 +// CHECK-NEXT: %4 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 +// CHECK-NEXT: store ptr %2, ptr %4, align 8 +// CHECK-NEXT: %5 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0 +// CHECK-NEXT: store i64 1024, ptr %5, align 8 +// CHECK-NEXT: %6 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 +// CHECK-NEXT: %7 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 +// CHECK-NEXT: %8 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0 +// CHECK-NEXT: call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 1, ptr %6, ptr %7, ptr %8, ptr @.offload_maptypes.1, ptr null, ptr null) +// CHECK-NEXT: call void @kernel_1(ptr noalias noundef nonnull align 4 dereferenceable(1024) %x) +// CHECK-NEXT: %9 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0 +// CHECK-NEXT: %10 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0 +// CHECK-NEXT: %11 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0 +// CHECK-NEXT: call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 1, ptr %9, ptr %10, ptr %11, ptr @.offload_maptypes.1, ptr null, ptr null) +// CHECK-NEXT: call void @__tgt_unregister_lib(ptr %EmptyDesc) +// CHECK: store ptr %x, ptr %0, align 8 +// CHECK-NEXT: call void asm sideeffect "", "r,~{memory}"(ptr nonnull %0) +// CHECK: ret void +// CHECK-NEXT: } + +#[unsafe(no_mangle)] +#[inline(never)] +pub fn kernel_1(x: &mut [f32; 256]) { + for i in 0..256 { + x[i] = 21.0; + } +}