Skip to content

Commit d57b4c3

Browse files
committed
wip
1 parent 215bcfc commit d57b4c3

File tree

3 files changed

+58
-2
lines changed

3 files changed

+58
-2
lines changed

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,59 @@ pub(crate) unsafe fn llvm_optimize(
665665

666666
let llvm_plugins = config.llvm_plugins.join(",");
667667

668+
fn adjust_offload_kernel_abis(m: &llvm::Module, llcx: &llvm::Context) {
669+
unsafe {
670+
// We just add a `ptr %dyn_ptr, ` as the first arg to every kernel_{i} function.
671+
// for function in function
672+
for num in 0..9 {
673+
let name = format!("kernel_{num}");
674+
let c_name = CString::new(name).unwrap();
675+
let kernel = llvm::LLVMGetNamedFunction(m, c_name.as_ptr());
676+
if let Some(old_fn) = kernel {
677+
dbg!("found kernel");
678+
//let old_fn_ty = llvm::LLVMTypeOf(old_fn);
679+
//let ptr_ty = llvm::LLVMPointerTypeInContext(llcx, 0);
680+
//
681+
//let n = unsafe { llvm::LLVMCountParamTypes(old_fn_ty) } as usize;
682+
//
683+
//let mut old_param_tys = Vec::with_capacity(n);
684+
//unsafe { llvm::LLVMGetParamTypes(old_fn_ty, old_param_tys.as_mut_ptr()) };
685+
//let ret_ty = unsafe { llvm::LLVMGetReturnType(old_fn_ty) };// new param list = [ptr] + old params
686+
//let mut new_params = Vec::with_capacity(n + 1);
687+
//new_params.push(ptr_ty);
688+
//for elem in &old_param_tys {
689+
// new_params.push(elem);
690+
//}
691+
//let new_fn_ty = unsafe {
692+
// llvm::LLVMFunctionType(ret_ty, new_params.as_mut_ptr(), new_params.len() as u32, llvm::False)
693+
//};
694+
//let new_fn = unsafe { llvm::LLVMAddFunction(c_name.as_ptr(), new_fn_ty) };
695+
//let a0 = unsafe { llvm::LLVMGetParam(new_fn, 0) };
696+
//unsafe { llvm::LLVMSetValueName2(a0, b"dyn_ptr\0".as_ptr(), "dyn_ptr".len()) };// Move basic blocks
697+
//let mut bb = unsafe { llvm::LLVMGetFirstBasicBlock(old_fn) };
698+
//while !bb.is_null() {
699+
// let next = unsafe { llvm::LLVMGetNextBasicBlock(bb) };
700+
// unsafe { llvm::LLVMAppendExistingBasicBlock(new_fn, bb) };
701+
// bb = next;
702+
//}// Shift argument uses: old %0 -> new %1, old %1 -> new %2, ...
703+
//let old_n = unsafe { llvm::LLVMCountParams(old_fn) };
704+
//for i in 0..old_n {
705+
// let old_arg = unsafe { llvm::LLVMGetParam(old_fn, i) };
706+
// let new_arg = unsafe { llvm::LLVMGetParam(new_fn, i + 1) };
707+
// unsafe { llvm::LLVMReplaceAllUsesWith(old_arg, new_arg) };
708+
//}
709+
//unsafe { llvm::LLVMReplaceAllUsesWith(old_fn, new_fn) };
710+
}
711+
}
712+
}
713+
714+
}
715+
if cgcx.target_arch == "amdgpu" {
716+
adjust_offload_kernel_abis(module.module_llvm.llmod(), &*module.module_llvm.llc);
717+
} else {
718+
dbg!(&cgcx.target_arch);
719+
}
720+
668721
let result = unsafe {
669722
llvm::LLVMRustOptimize(
670723
module.module_llvm.llmod(),

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,6 +1201,8 @@ unsafe extern "C" {
12011201

12021202
// Operations on functions
12031203
pub(crate) fn LLVMSetFunctionCallConv(Fn: &Value, CC: c_uint);
1204+
//pub(crate) fn LLVMGetNamedFunction<'a>(Mod: &'a Module, Name: *const char) -> Option<&'a Value>;
1205+
12041206
pub(crate) fn LLVMDeleteFunction(Fn: &Value);
12051207

12061208
// Operations about llvm intrinsics

compiler/rustc_codegen_llvm/src/mono_item.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,10 @@ impl<'tcx> PreDefineCodegenMethods<'tcx> for CodegenCx<'_, 'tcx> {
6868
let fn_abi = self.fn_abi_of_instance(instance, ty::List::empty());
6969
let fn_abi = if fn_abi.conv == rustc_abi::CanonAbi::GpuKernel {
7070
dbg!("found gpu fn!");
71-
my_fn_abi(fn_abi)
71+
fn_abi.clone()
72+
//my_fn_abi(fn_abi)
7273
} else {
73-
dbg!("asdf!");
74+
//dbg!("asdf!");
7475
fn_abi.clone()
7576
};
7677
let lldecl = self.declare_fn(symbol_name, &fn_abi, Some(instance));

0 commit comments

Comments
 (0)