diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 657db58831c1b..f437e52355ef0 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -55,6 +55,7 @@ use llvm::{ LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols, LLVMGetNextBasicBlock, }; use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity, DiffMode}; +use rustc_ast::expand::typetree::FncTree; use rustc_codegen_ssa::back::link::ensure_removed; use rustc_codegen_ssa::back::write::{ BitcodeSection, CodegenContext, EmitObj, ModuleConfig, TargetMachineFactoryConfig, @@ -1091,6 +1092,24 @@ pub(crate) unsafe fn differentiate( llvm::set_loose_types(true); } + // Before dumping the module, we want all the tt to become part of the module. + for item in &diff_items { + let llvm_data_layout = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; + let llvm_data_layout = + std::str::from_utf8(unsafe { CStr::from_ptr(llvm_data_layout) }.to_bytes()) + .expect("got a non-UTF8 data-layout from LLVM"); + //let input_tts: Vec = + // item.inputs.iter().map(|x| to_enzyme_typetree(x.clone(), llvm_data_layout, llcx)).collect(); + //let output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx); + let tt: FncTree = FncTree { + args: item.inputs.clone(), + ret: item.output.clone(), + }; + let name = CString::new(item.source.clone()).unwrap(); + let fn_def: &llvm::Value = llvm::LLVMGetNamedFunction(llmod, name.as_ptr()).unwrap(); + crate::builder::add_tt2(llmod, llcx, fn_def, tt); + } + if std::env::var("ENZYME_PRINT_MOD_BEFORE").is_ok() { unsafe { LLVMDumpModule(llmod); diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index f7afe9cbefb7a..733dd76ed105c 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -136,6 +136,49 @@ macro_rules! builder_methods_for_value_instructions { })+ } } +pub fn add_tt2<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, fn_def: &'ll Value, tt: FncTree) { + let inputs = tt.args; + let ret_tt: TypeTree = tt.ret; + let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; + let llvm_data_layout = + std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes()) + .expect("got a non-UTF8 data-layout from LLVM"); + let attr_name = "enzyme_type"; + let c_attr_name = std::ffi::CString::new(attr_name).unwrap(); + for (i, &ref input) in inputs.iter().enumerate() { + let c_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx); + let c_str = unsafe { llvm::EnzymeTypeTreeToString(c_tt.inner) }; + let c_str = unsafe { std::ffi::CStr::from_ptr(c_str) }; + unsafe { + let attr = llvm::LLVMCreateStringAttribute( + llcx, + c_attr_name.as_ptr(), + c_attr_name.as_bytes().len() as c_uint, + c_str.as_ptr(), + c_str.to_bytes().len() as c_uint, + ); + llvm::LLVMRustAddFncParamAttr(fn_def, i as u32, attr); + } + unsafe { llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()) }; + } + let ret_attr = unsafe { + let c_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx); + let c_str = llvm::EnzymeTypeTreeToString(c_tt.inner); + let c_str = std::ffi::CStr::from_ptr(c_str); + let attr = llvm::LLVMCreateStringAttribute( + llcx, + c_attr_name.as_ptr(), + c_attr_name.as_bytes().len() as c_uint, + c_str.as_ptr(), + c_str.to_bytes().len() as c_uint, + ); + llvm::EnzymeTypeTreeToStringFree(c_str.as_ptr()); + attr + }; + unsafe { + llvm::LLVMRustAddRetFncAttr(fn_def, ret_attr); + } +} fn add_tt<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context,val: &'ll Value, tt: FncTree) { let inputs = tt.args; diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 260e66d2ce92a..0bf2eb5fb3a2f 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -848,8 +848,8 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( fnc: &Value, input_diffactivity: Vec, ret_diffactivity: DiffActivity, - input_tts: Vec, - output_tt: TypeTree, + _input_tts: Vec, + _output_tt: TypeTree, void_ret: bool, ) -> (&Value, Vec) { let ret_activity = cdiffe_from(ret_diffactivity); @@ -878,13 +878,12 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( }; trace!("ret_primary_ret: {}", &ret_primary_ret); - let mut args_tree = input_tts.iter().map(|x| x.inner).collect::>(); + //let mut args_tree = input_tts.iter().map(|x| x.inner).collect::>(); //let mut args_tree = vec![TypeTree::new().inner; typetree.input_tt.len()]; // We don't support volatile / extern / (global?) values. // Just because I didn't had time to test them, and it seems less urgent. - let args_uncacheable = vec![0; input_tts.len()]; - assert!(args_uncacheable.len() == input_activity.len()); + let args_uncacheable = vec![0; input_activity.len()]; let num_fnc_args = LLVMCountParams(fnc); trace!("num_fnc_args: {}", num_fnc_args); trace!("input_activity.len(): {}", input_activity.len()); @@ -894,9 +893,16 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( let mut known_values = vec![kv_tmp; input_activity.len()]; + let tree_tmp = TypeTree::new(); + let mut args_tree = vec![tree_tmp.inner; input_activity.len()]; + + //let mut args_tree = vec![std::ptr::null_mut(); input_activity.len()]; + //let ret_tt = std::ptr::null_mut(); + //let mut args_tree = vec![TypeTree::new().inner; input_tts.len()]; + let ret_tt = TypeTree::new(); let dummy_type = CFnTypeInfo { Arguments: args_tree.as_mut_ptr(), - Return: output_tt.inner.clone(), + Return: ret_tt.inner, KnownValues: known_values.as_mut_ptr(), }; @@ -935,7 +941,7 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( rust_input_activity: Vec, ret_activity: DiffActivity, input_tts: Vec, - output_tt: TypeTree, + _output_tt: TypeTree, ) -> (&Value, Vec) { let (primary_ret, ret_activity) = match ret_activity { DiffActivity::Const => (true, CDIFFE_TYPE::DFT_CONSTANT), @@ -961,16 +967,11 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( input_activity.push(cdiffe_from(x)); } - let mut args_tree = input_tts.iter().map(|x| x.inner).collect::>(); + //let args_tree = input_tts.iter().map(|x| x.inner).collect::>(); // We don't support volatile / extern / (global?) values. // Just because I didn't had time to test them, and it seems less urgent. - let args_uncacheable = vec![0; input_tts.len()]; - if args_uncacheable.len() != input_activity.len() { - dbg!("args_uncacheable.len(): {}", args_uncacheable.len()); - dbg!("input_activity.len(): {}", input_activity.len()); - } - assert!(args_uncacheable.len() == input_activity.len()); + let args_uncacheable = vec![0; input_activity.len()]; let num_fnc_args = LLVMCountParams(fnc); println!("num_fnc_args: {}", num_fnc_args); println!("input_activity.len(): {}", input_activity.len()); @@ -979,9 +980,15 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( let mut known_values = vec![kv_tmp; input_tts.len()]; + let tree_tmp = TypeTree::new(); + let mut args_tree = vec![tree_tmp.inner; input_tts.len()]; + //let mut args_tree = vec![TypeTree::new().inner; input_tts.len()]; + let ret_tt = TypeTree::new(); + //let mut args_tree = vec![std::ptr::null_mut(); input_tts.len()]; + //let ret_tt = std::ptr::null_mut(); let dummy_type = CFnTypeInfo { Arguments: args_tree.as_mut_ptr(), - Return: output_tt.inner.clone(), + Return: ret_tt.inner, KnownValues: known_values.as_mut_ptr(), }; @@ -1023,12 +1030,12 @@ extern "C" { //pub fn LLVMEraseFromParent(BB: &BasicBlock) -> &Value; // Enzyme pub fn LLVMRustAddFncParamAttr<'a>( - Instr: &'a Value, + F: &'a Value, index: c_uint, Attr: &'a Attribute ); - pub fn LLVMRustAddRetAttr(V: &Value, attr: AttributeKind); + pub fn LLVMRustAddRetFncAttr(F: &Value, attr: &Attribute); pub fn LLVMRustRemoveFncAttr(V: &Value, attr: AttributeKind); pub fn LLVMRustHasDbgMetadata(I: &Value) -> bool; pub fn LLVMRustHasMetadata(I: &Value, KindID: c_uint) -> bool; diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index 8f17c26f10177..43ed4c50613a3 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -865,9 +865,9 @@ extern "C" void LLVMRustAddFncParamAttr(LLVMValueRef F, unsigned i, } extern "C" void LLVMRustAddRetFncAttr(LLVMValueRef F, - LLVMRustAttribute RustAttr) { + LLVMAttributeRef RustAttr) { if (auto *Fn = dyn_cast(unwrap(F))) { - Fn->addRetAttr(fromRust(RustAttr)); + Fn->addRetAttr(unwrap(RustAttr)); } }